From 1d619ee1b9c806bd9695271f1987ae6af6bb52b0 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 7 Feb 2020 16:06:27 +0100
Subject: [PATCH] first implementation of create shuffled data

---
 src/data_handling/bootstraps.py | 87 +++++++++++++++++++++++++++++++++
 1 file changed, 87 insertions(+)
 create mode 100644 src/data_handling/bootstraps.py

diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py
new file mode 100644
index 00000000..21fc23d8
--- /dev/null
+++ b/src/data_handling/bootstraps.py
@@ -0,0 +1,87 @@
+__author__ = 'Felix Kleinert, Lukas Leufen'
+__date__ = '2020-02-07'
+
+
+from src.run_modules.run_environment import RunEnvironment
+from src.data_handling.data_generator import DataGenerator
+import numpy as np
+import logging
+import xarray as xr
+import os
+import re
+
+
+class BootStraps(RunEnvironment):
+
+    def __init__(self):
+
+        super().__init__()
+        self.test_data: DataGenerator = self.data_store.get("generator", "general.test")
+        self.number_bootstraps = 200
+        self.bootstrap_path = self.data_store.get("bootstrap_path", "general")
+        self.create_shuffled_data()
+
+    def create_shuffled_data(self):
+        variables_str = '_'.join(sorted(self.test_data.variables))
+        window = self.test_data.window_history_size
+        for station in self.test_data.stations:
+            valid, _, max_nboot = self.valid_bootstrap_file(station, variables_str, window)
+            if not valid:
+                logging.info(f'create bootstap data for {station}')
+                hist, _ = self.test_data[station]
+                data = hist.copy()
+                file_name = f"{station}_{variables_str}_hist{window}_nboots{max_nboot}_shuffled.nc"
+                file_path = os.path.join(self.bootstrap_path, file_name)
+                data = data.expand_dims({'boots': range(max_nboot)}, axis=-1)
+                shuffled_variable = np.full(data.shape, np.nan)
+                for i, var in enumerate(data.coords['variables']):
+                    single_variable = data.sel(variables=var).values
+                    shuffled_variable[..., i, :] = self.shuffle_single_variable(single_variable)
+                shuffled_data = xr.DataArray(shuffled_variable, coords=data.coords, dims=data.dims)
+                shuffled_data.to_netcdf(file_path)
+
+    def valid_bootstrap_file(self, station, variables, window):
+        str_re = re.compile(f"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled*")
+        dir_list = os.listdir(self.bootstrap_path)
+        max_nboot = self.number_bootstraps
+        max_window = self.number_bootstraps
+        for file in dir_list:
+            match = str_re.match(file)
+            if match:
+                window_existing = int(match.group(1))
+                nboot_existing = int(match.group(2))
+                max_window = max([max_window, window_existing])
+                max_nboot = max([max_nboot, nboot_existing])
+                if (window_existing >= window) and (nboot_existing >= self.number_bootstraps):
+                    return True, 0, 0
+                else:
+                    os.remove(os.path.join(self.bootstrap_path, file))
+        return False, max_window, max_nboot
+
+
+
+
+
+    def shuffle_single_variable(self, data):
+        orig_shape = data.shape
+        size = orig_shape
+        # size = (*orig_shape, self.number_bootstraps)
+        return np.random.choice(data.reshape(-1,), size=size)
+
+
+
+if __name__ == "__main__":
+
+    from src.run_modules.experiment_setup import ExperimentSetup
+    from src.run_modules.run_environment import RunEnvironment
+    from src.run_modules.pre_processing import PreProcessing
+
+    formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    with RunEnvironment():
+        ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013'],
+                        station_type='background', trainable=True, window_history_size=9)
+        PreProcessing()
+
+        BootStraps()
-- 
GitLab