diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 412cb4448180e89462a270930ec710936a203db3..6b66b87552e4962ec921e6c37d509c90121f3b9b 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -10,30 +10,20 @@ import xarray as xr import os import re from src import helpers -from typing import List +from typing import List, Union class BootStrapGenerator: - def __init__(self, orig_generator, boots, chunksize, bootstrap_path): + def __init__(self, orig_generator, number_of_boots, bootstrap_path): self.orig_generator: DataGenerator = orig_generator self.stations = self.orig_generator.stations self.variables = self.orig_generator.variables - self.boots = boots - self.chunksize = chunksize + self.number_of_boots = number_of_boots self.bootstrap_path = bootstrap_path - self._iterator = 0 def __len__(self): - """ - display the number of stations - """ - return len(self.orig_generator)*self.boots*len(self.variables) - - def get_labels(self, key): - _, label = self.orig_generator[key] - for _ in range(self.boots): - yield label + return len(self.orig_generator) * self.number_of_boots * len(self.variables) def get_generator(self): """ @@ -46,10 +36,10 @@ class BootStrapGenerator: station = self.orig_generator.get_station_key(i) logging.info(f"station: {station}") hist, label = data - shuffled_data = self.load_boot_data(station) + shuffled_data = self.load_shuffled_data(station, self.variables) for var in self.variables: - logging.info(f" var: {var}") - for boot in range(self.boots): + logging.debug(f" var: {var}") + for boot in range(self.number_of_boots): logging.debug(f"boot: {boot}") boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var)) shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables") @@ -67,23 +57,54 @@ class BootStrapGenerator: for station in self.stations: label = 
self.orig_generator.get_data_generator(station).get_transposed_label() for var in self.variables: - for boot in range(self.boots): + for boot in range(self.number_of_boots): bootstrap_meta.extend([[var, station]] * len(label)) return bootstrap_meta - def get_orig_prediction(self, path, file_name, prediction_name="CNN"): + def get_labels(self, key: Union[str, int]): + """ + Repeats labels for given key by the number of boots and yields them one by one. + :param key: key of station (either station name as string or the position in generator as integer) + :return: yields labels for length of boots + """ + _, label = self.orig_generator[key] + for _ in range(self.number_of_boots): + yield label + + def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN"): + """ + Repeats predictions from given file(_name) in path by the number of boots. + :param path: path to file + :param file_name: file name + :param prediction_name: name of the prediction to select from loaded file + :return: yields predictions for length of boots + """ file = os.path.join(path, file_name) data = xr.open_dataarray(file) - for _ in range(self.boots): + for _ in range(self.number_of_boots): yield data.sel(type=prediction_name).squeeze() - def load_boot_data(self, station): + def load_shuffled_data(self, station: str, variables: List[str]) -> xr.DataArray: + """ + Load shuffled data from bootstrap path. Data is stored as + '<station>_<var1>_<var2>_..._hist<histsize>_nboots<nboots>_shuffled.nc', e.g. 
+ 'DEBW107_cloudcover_no_no2_temp_u_v_hist13_nboots20_shuffled.nc' + :param station: + :param variables: + :return: shuffled data as xarray + """ files = os.listdir(self.bootstrap_path) - regex = re.compile(rf"{station}_\w*\.nc") + regex = self.create_file_regex(station, variables) file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0]) shuffled_data = xr.open_dataarray(file_name, chunks=100) return shuffled_data + @staticmethod + def create_file_regex(station, variables): + var_regex = "".join([rf'(_\w+)*_{v}(_\w+)*' for v in sorted(variables)]) + regex = re.compile(rf"{station}{var_regex}_shuffled\.nc") + return regex + class BootStraps: @@ -93,7 +114,7 @@ class BootStraps: self.bootstrap_path = bootstrap_path self.chunks = self.get_chunk_size() self.create_shuffled_data() - self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.chunks, self.bootstrap_path) + self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.bootstrap_path) def get_boot_strap_meta(self): return self._boot_strap_generator.get_bootstrap_meta() @@ -135,7 +156,7 @@ class BootStraps: randomly selected variables. If there is a suitable local file for requested window size and number of bootstraps, no additional file will be created inside this function. 
""" - logging.info("create shuffled bootstrap data") + logging.info("create / check shuffled bootstrap data") variables_str = '_'.join(sorted(self.data.variables)) window = self.data.window_history_size for station in self.data.stations: diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py index 9dd23893ef903bfbd0595a482dceb32724c3b437..e66c13e43e7e311e0eef36b4c7561ca01b8a5d86 100644 --- a/test/test_data_handling/test_bootstraps.py +++ b/test/test_data_handling/test_bootstraps.py @@ -1,10 +1,12 @@ -from src.data_handling.bootstraps import BootStraps +from src.data_handling.bootstraps import BootStraps, BootStrapGenerator +from src.data_handling.data_generator import DataGenerator import pytest import os import numpy as np +import xarray as xr class TestBootstraps: @@ -61,4 +63,57 @@ class TestBootstraps: assert set(np.unique(res)).issubset({1, 2, 3}) def test_create_shuffled_data(self): - pass \ No newline at end of file + pass + + +class TestBootstrapGenerator: + + @pytest.fixture + def orig_generator(self): + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'], + ['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014) + + @pytest.fixture + def boot_gen(self, orig_generator): + path = os.path.join(os.path.dirname(__file__), 'data') + dummy_content = xr.DataArray([1, 2, 3], dims="dummy") + dummy_content.to_netcdf(os.path.join(path, "DEBW107_o3_temp_shuffled.nc")) + dummy_content.to_netcdf(os.path.join(path, "DEBW013_o3_temp_shuffled.nc")) + return BootStrapGenerator(orig_generator, 20, path) + + def test_init(self, orig_generator): + gen = BootStrapGenerator(orig_generator, 20, os.path.join(os.path.dirname(__file__), 'data')) + assert gen.stations == ["DEBW107", "DEBW013"] + assert gen.variables == ["o3", "temp"] + assert gen.number_of_boots == 20 + assert gen.bootstrap_path == os.path.join(os.path.dirname(__file__), 'data') + + def test_len(self, 
boot_gen): + assert len(boot_gen) == 80 + + def test_get_generator(self, boot_gen): + pass + + def test_get_bootstrap_meta(self, boot_gen): + pass + + def test_get_labels(self, boot_gen): + pass + + def test_get_orig_prediction(self, boot_gen): + pass + + def test_load_shuffled_data(self, boot_gen): + shuffled_data = boot_gen.load_shuffled_data("DEBW107", ["o3", "temp"]) + assert isinstance(shuffled_data, xr.DataArray) + assert all(shuffled_data.compute().values == [1, 2, 3]) + + def test_create_file_regex(self, boot_gen): + regex = boot_gen.create_file_regex("DEBW108", ["o3", "temp", "h2o"]) + test_list = ["DEBW108_o3_test23_test_shuffled.nc", + "DEBW107_o3_test23_test_shuffled.nc", + "DEBW108_o3_test23_test.nc", + "DEBW108_h2o_o3_temp_test_shuffled.nc", + "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc"] + assert list(filter(regex.search, test_list)) == ["DEBW108_h2o_o3_temp_test_shuffled.nc", + "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc"]