From 75c484bb5607bc2a5b8cb4b593587ffce2fad3e0 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 17 Feb 2020 13:43:22 +0100
Subject: [PATCH] new function list_pop, intermediate working step

---
 src/data_handling/bootstraps.py | 80 ++++++++++++++++++++++++++++++++-
 src/helpers.py                  | 10 +++++
 2 files changed, 88 insertions(+), 2 deletions(-)

diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py
index 80867822..9247e6e1 100644
--- a/src/data_handling/bootstraps.py
+++ b/src/data_handling/bootstraps.py
@@ -10,6 +10,76 @@ import dask.array as da
 import xarray as xr
 import os
 import re
+from src import helpers
+
+
+class BootStrapGenerator:
+
+    def __init__(self, orig_generator, boots, chunksize, bootstrap_path):
+        self.orig_generator: DataGenerator = orig_generator
+        self.stations = self.orig_generator.stations
+        self.boots = boots
+        self.chunksize = chunksize
+        self.bootstrap_path = bootstrap_path
+        self._iterator = 0
+        self.__next__()
+        a = 1
+
+    def __len__(self):
+        """
+        display the number of stations
+        """
+        return len(self.orig_generator)*self.boots
+
+    def __iter__(self):
+        """
+        Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute
+        `_iterator` to 0.
+        :return:
+        """
+        self._iterator = 0
+        return self
+
+    def __next__(self):
+        """
+        This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
+        the history and label data of this generator.
+        :return:
+        """
+        if self._iterator < self.__len__():
+            for i, data in enumerate(self.orig_generator):
+                station = self.orig_generator.get_station_key(i)
+                hist, label = data
+                shuffled_data = self.load_boot_data(station)
+                all_variables = self.orig_generator.variables
+                for var in all_variables:
+                    for boot in range(self.boots):
+                        boot_hist: xr.DataArray = hist
+                        boot_hist = boot_hist.sel(variables=helpers.list_pop(all_variables, var))
+                        shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
+                        boot_hist = boot_hist.combine_first(shuffled_var)
+                        boot_hist.sortby("variables")
+                        # boot_hist
+
+
+
+
+            # self._iterator += 1
+            # if data.history is not None and data.label is not None:  # pragma: no branch
+            #     return data.history.transpose("datetime", "window", "Stations", "variables"), \
+            #         data.label.squeeze("Stations").transpose("datetime", "window")
+            else:
+                self.__next__()  # pragma: no cover
+        else:
+            raise StopIteration
+
+    def load_boot_data(self, station):
+        files = os.listdir(self.bootstrap_path)
+        regex = re.compile(rf"{station}_\w*\.nc")
+        file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0])
+        # shuffled_data = xr.open_dataarray(file_name, chunks=self.chunksize)
+        shuffled_data = xr.open_dataarray(file_name, chunks=100)
+        return shuffled_data
 
 
 class BootStraps(RunEnvironment):
@@ -18,9 +88,15 @@ class BootStraps(RunEnvironment):
 
         super().__init__()
         self.test_data: DataGenerator = self.data_store.get("generator", "general.test")
-        self.number_bootstraps = 100
+        self.number_bootstraps = 50
         self.bootstrap_path = self.data_store.get("bootstrap_path", "general")
+        self.chunks = self.get_chunk_size()
         self.create_shuffled_data()
+        BootStrapGenerator(self.test_data, self.number_bootstraps, self.chunks, self.bootstrap_path)
+
+    def get_chunk_size(self):
+        hist, _ = self.test_data[0]
+        return (100, *hist.shape[1:], self.number_bootstraps)
 
     def create_shuffled_data(self):
         """
@@ -61,7 +137,7 @@ class BootStraps(RunEnvironment):
         :param window:
         :return:
         """
-        regex = re.compile(rf"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled*")
+        regex = re.compile(rf"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled")
         max_nboot = self.number_bootstraps
         for file in os.listdir(self.bootstrap_path):
             match = regex.match(file)
diff --git a/src/helpers.py b/src/helpers.py
index 680d3bd1..399804d7 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -195,3 +195,13 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce
     """
     multiplier = 10. ** decimals
     return round_type(number * multiplier) / multiplier
+
+
+def list_pop(list_full: list, pop_items):
+    pop_items = to_list(pop_items)
+    if len(pop_items) > 1:
+        return [e for e in list_full if e not in pop_items]
+    else:
+        list_pop = list_full.copy()
+        list_pop.remove(pop_items[0])
+        return list_pop
-- 
GitLab