From 75c484bb5607bc2a5b8cb4b593587ffce2fad3e0 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Mon, 17 Feb 2020 13:43:22 +0100 Subject: [PATCH] new function list_pop, intermediate working step --- src/data_handling/bootstraps.py | 80 ++++++++++++++++++++++++++++++++- src/helpers.py | 10 +++++ 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 80867822..9247e6e1 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -10,6 +10,76 @@ import dask.array as da import xarray as xr import os import re +from src import helpers + + +class BootStrapGenerator: + + def __init__(self, orig_generator, boots, chunksize, bootstrap_path): + self.orig_generator: DataGenerator = orig_generator + self.stations = self.orig_generator.stations + self.boots = boots + self.chunksize = chunksize + self.bootstrap_path = bootstrap_path + self._iterator = 0 + self.__next__() + a = 1 + + def __len__(self): + """ + display the number of stations + """ + return len(self.orig_generator)*self.boots + + def __iter__(self): + """ + Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute + `_iterator` to 0. + :return: + """ + self._iterator = 0 + return self + + def __next__(self): + """ + This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return + the history and label data of this generator. + :return: + """ + if self._iterator < self.__len__(): + for i, data in enumerate(self.orig_generator): + station = self.orig_generator.get_station_key(i) + hist, label = data + shuffled_data = self.load_boot_data(station) + all_variables = self.orig_generator.variables + for var in all_variables: + for boot in range(self.boots): + boot_hist: xr.DataArray = hist + boot_hist = boot_hist.sel(variables=helpers.list_pop(all_variables, var)) + shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables") + boot_hist = boot_hist.combine_first(shuffled_var) + boot_hist.sortby("variables") + # boot_hist + + + + + # self._iterator += 1 + # if data.history is not None and data.label is not None: # pragma: no branch + # return data.history.transpose("datetime", "window", "Stations", "variables"), \ + # data.label.squeeze("Stations").transpose("datetime", "window") + else: + self.__next__() # pragma: no cover + else: + raise StopIteration + + def load_boot_data(self, station): + files = os.listdir(self.bootstrap_path) + regex = re.compile(rf"{station}_\w*\.nc") + file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0]) + # shuffled_data = xr.open_dataarray(file_name, chunks=self.chunksize) + shuffled_data = xr.open_dataarray(file_name, chunks=100) + return shuffled_data class BootStraps(RunEnvironment): @@ -18,9 +88,15 @@ class BootStraps(RunEnvironment): super().__init__() self.test_data: DataGenerator = self.data_store.get("generator", "general.test") - self.number_bootstraps = 100 + self.number_bootstraps = 50 self.bootstrap_path = self.data_store.get("bootstrap_path", "general") + self.chunks = self.get_chunk_size() self.create_shuffled_data() + BootStrapGenerator(self.test_data, self.number_bootstraps, self.chunks, self.bootstrap_path) + + def get_chunk_size(self): + hist, _ = self.test_data[0] + return (100, *hist.shape[1:], self.number_bootstraps) def create_shuffled_data(self): """ @@ -61,7 +137,7 @@ class BootStraps(RunEnvironment): :param window: :return: """ - regex = re.compile(rf"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled*") + regex = re.compile(rf"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled") max_nboot = self.number_bootstraps for file in os.listdir(self.bootstrap_path): match = regex.match(file) diff --git a/src/helpers.py b/src/helpers.py index 680d3bd1..399804d7 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -195,3 +195,13 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce """ multiplier = 10. ** decimals return round_type(number * multiplier) / multiplier + + +def list_pop(list_full: list, pop_items): + pop_items = to_list(pop_items) + if len(pop_items) > 1: + return [e for e in list_full if e not in pop_items] + else: + list_pop = list_full.copy() + list_pop.remove(pop_items[0]) + return list_pop -- GitLab