# Recovered from a git patch whose line breaks had been lost:
#   commit : 1d619ee1b9c806bd9695271f1987ae6af6bb52b0
#   author : lukas leufen <l.leufen@fz-juelich.de>
#   date   : Fri, 7 Feb 2020 16:06:27 +0100
#   subject: [PATCH] first implementation of create shuffled data
#   target : src/data_handling/bootstraps.py (new file)

__author__ = 'Felix Kleinert, Lukas Leufen'
__date__ = '2020-02-07'


from src.run_modules.run_environment import RunEnvironment
from src.data_handling.data_generator import DataGenerator
import numpy as np
import logging
import xarray as xr
import os
import re


class BootStraps(RunEnvironment):
    """
    Create resampled ("shuffled") copies of the test data history and store them as netCDF files.

    On construction, the test data generator and the bootstrap output path are read from the data store and
    `create_shuffled_data` is triggered immediately. One file per station is written to `bootstrap_path`, named
    `<station>_<variables>_hist<window>_nboots<nboot>_shuffled.nc`.
    """

    def __init__(self):

        super().__init__()
        self.test_data: DataGenerator = self.data_store.get("generator", "general.test")
        # number of bootstrap realisations requested per station
        self.number_bootstraps = 200
        self.bootstrap_path = self.data_store.get("bootstrap_path", "general")
        self.create_shuffled_data()

    def create_shuffled_data(self):
        """
        Write a shuffled data file for every station that does not yet have a sufficient one.

        For each station in the test data, `valid_bootstrap_file` decides whether an existing file can be reused.
        Otherwise the station's history is copied, extended by a trailing `boots` dimension and each variable is
        resampled independently with `shuffle_single_variable` before the result is written to disk.
        """
        variables_str = '_'.join(sorted(self.test_data.variables))
        window = self.test_data.window_history_size
        for station in self.test_data.stations:
            valid, _, max_nboot = self.valid_bootstrap_file(station, variables_str, window)
            if valid:
                continue  # a sufficient file already exists, nothing to do for this station
            logging.info(f'create bootstap data for {station}')
            hist, _ = self.test_data[station]
            file_name = f"{station}_{variables_str}_hist{window}_nboots{max_nboot}_shuffled.nc"
            file_path = os.path.join(self.bootstrap_path, file_name)
            # work on a copy so the generator's data is not modified; `boots` becomes the last axis
            data = hist.copy().expand_dims({'boots': range(max_nboot)}, axis=-1)
            shuffled_variable = np.full(data.shape, np.nan)
            # NOTE(review): the `[..., i, :]` indexing assumes `variables` is the second-to-last axis after
            # appending `boots`, i.e. the last axis of `hist` - confirm against DataGenerator.
            for i, var in enumerate(data.coords['variables']):
                single_variable = data.sel(variables=var).values
                shuffled_variable[..., i, :] = self.shuffle_single_variable(single_variable)
            shuffled_data = xr.DataArray(shuffled_variable, coords=data.coords, dims=data.dims)
            shuffled_data.to_netcdf(file_path)

    def valid_bootstrap_file(self, station, variables, window):
        """
        Check whether `bootstrap_path` already holds a usable shuffled file for the given station and variables.

        A file is usable if its window (`hist<N>`) is at least `window` and its number of bootstraps (`nboots<N>`)
        is at least `self.number_bootstraps`. Matching files that are not usable are removed from disk. Only files
        named exactly `<station>_<variables>_hist<N>_nboots<M>_shuffled.nc` are considered, so unrelated files that
        merely share this prefix are never deleted.

        :param station: station name used as file name prefix
        :param variables: variables part of the file name (names joined by `_`)
        :param window: minimum required window size
        :return: `(True, 0, 0)` if a usable file was found, else `(False, max_window, max_nboot)` with the largest
            window / number of bootstraps among the request and all matching files seen
        """
        # raw string so that `\d` reaches the regex engine; escape the interpolated parts so they match literally
        str_re = re.compile(
            rf"{re.escape(str(station))}_{re.escape(str(variables))}_hist(\d+)_nboots(\d+)_shuffled\.nc")
        dir_list = os.listdir(self.bootstrap_path)
        max_nboot = self.number_bootstraps
        max_window = window  # start from the requested window, not from the number of bootstraps
        for file in dir_list:
            match = str_re.fullmatch(file)
            if match is None:
                continue
            window_existing = int(match.group(1))
            nboot_existing = int(match.group(2))
            max_window = max(max_window, window_existing)
            max_nboot = max(max_nboot, nboot_existing)
            if (window_existing >= window) and (nboot_existing >= self.number_bootstraps):
                return True, 0, 0
            # existing file is too small: drop it, a replacement is created by the caller
            os.remove(os.path.join(self.bootstrap_path, file))
        return False, max_window, max_nboot

    def shuffle_single_variable(self, data):
        """
        Resample `data` by drawing values (with replacement) from all of its elements.

        :param data: numpy array to draw from
        :return: numpy array with the same shape as `data`
        """
        return np.random.choice(data.reshape(-1), size=data.shape)


if __name__ == "__main__":

    from src.run_modules.experiment_setup import ExperimentSetup
    from src.run_modules.pre_processing import PreProcessing

    formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
    logging.basicConfig(format=formatter, level=logging.INFO)

    with RunEnvironment():
        ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013'],
                        station_type='background', trainable=True, window_history_size=9)
        PreProcessing()

        # NOTE(review): the original indentation was lost; BootStraps() is kept inside the `with` block on the
        # assumption that it needs the data store filled by the steps above - confirm.
        BootStraps()