Skip to content
Snippets Groups Projects
Commit 75c484bb authored by lukas leufen's avatar lukas leufen
Browse files

new function list_pop, intermediate working step

parent 6ae330b2
No related branches found
No related tags found
2 merge requests!59Develop,!52implemented bootstraps
Pipeline #29743 passed
......@@ -10,6 +10,76 @@ import dask.array as da
import xarray as xr
import os
import re
from src import helpers
class BootStrapGenerator:
def __init__(self, orig_generator, boots, chunksize, bootstrap_path):
self.orig_generator: DataGenerator = orig_generator
self.stations = self.orig_generator.stations
self.boots = boots
self.chunksize = chunksize
self.bootstrap_path = bootstrap_path
self._iterator = 0
self.__next__()
a = 1
def __len__(self):
"""
display the number of stations
"""
return len(self.orig_generator)*self.boots
def __iter__(self):
"""
Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute
`_iterator` to 0.
:return:
"""
self._iterator = 0
return self
def __next__(self):
"""
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
the history and label data of this generator.
:return:
"""
if self._iterator < self.__len__():
for i, data in enumerate(self.orig_generator):
station = self.orig_generator.get_station_key(i)
hist, label = data
shuffled_data = self.load_boot_data(station)
all_variables = self.orig_generator.variables
for var in all_variables:
for boot in range(self.boots):
boot_hist: xr.DataArray = hist
boot_hist = boot_hist.sel(variables=helpers.list_pop(all_variables, var))
shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
boot_hist = boot_hist.combine_first(shuffled_var)
boot_hist.sortby("variables")
# boot_hist
# self._iterator += 1
# if data.history is not None and data.label is not None: # pragma: no branch
# return data.history.transpose("datetime", "window", "Stations", "variables"), \
# data.label.squeeze("Stations").transpose("datetime", "window")
else:
self.__next__() # pragma: no cover
else:
raise StopIteration
def load_boot_data(self, station):
files = os.listdir(self.bootstrap_path)
regex = re.compile(rf"{station}_\w*\.nc")
file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0])
# shuffled_data = xr.open_dataarray(file_name, chunks=self.chunksize)
shuffled_data = xr.open_dataarray(file_name, chunks=100)
return shuffled_data
class BootStraps(RunEnvironment):
......@@ -18,9 +88,15 @@ class BootStraps(RunEnvironment):
super().__init__()
self.test_data: DataGenerator = self.data_store.get("generator", "general.test")
self.number_bootstraps = 100
self.number_bootstraps = 50
self.bootstrap_path = self.data_store.get("bootstrap_path", "general")
self.chunks = self.get_chunk_size()
self.create_shuffled_data()
BootStrapGenerator(self.test_data, self.number_bootstraps, self.chunks, self.bootstrap_path)
def get_chunk_size(self):
hist, _ = self.test_data[0]
return (100, *hist.shape[1:], self.number_bootstraps)
def create_shuffled_data(self):
"""
......@@ -61,7 +137,7 @@ class BootStraps(RunEnvironment):
:param window:
:return:
"""
regex = re.compile(rf"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled*")
regex = re.compile(rf"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled")
max_nboot = self.number_bootstraps
for file in os.listdir(self.bootstrap_path):
match = regex.match(file)
......
......@@ -195,3 +195,13 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce
"""
multiplier = 10. ** decimals
return round_type(number * multiplier) / multiplier
def list_pop(list_full: list, pop_items):
pop_items = to_list(pop_items)
if len(pop_items) > 1:
return [e for e in list_full if e not in pop_items]
else:
list_pop = list_full.copy()
list_pop.remove(pop_items[0])
return list_pop
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment