Skip to content
Snippets Groups Projects
Commit 0e0f1fe5 authored by lukas leufen's avatar lukas leufen
Browse files

try boot strap generator implementation as keras sequence

parent 02643eb3
Branches
Tags
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
Pipeline #33345 failed
...@@ -16,7 +16,8 @@ def main(parser_args): ...@@ -16,7 +16,8 @@ def main(parser_args):
with RunEnvironment(): with RunEnvironment():
ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
station_type='background', trainable=False, create_new_model=True) station_type='background', trainable=False, create_new_model=False, window_history_size=6,
create_new_bootstraps=True)
PreProcessing() PreProcessing()
ModelSetup() ModelSetup()
......
...@@ -5,6 +5,7 @@ __date__ = '2020-02-07' ...@@ -5,6 +5,7 @@ __date__ = '2020-02-07'
from src.data_handling.data_generator import DataGenerator from src.data_handling.data_generator import DataGenerator
import numpy as np import numpy as np
import logging import logging
import keras
import dask.array as da import dask.array as da
import xarray as xr import xarray as xr
import os import os
...@@ -13,6 +14,142 @@ from src import helpers ...@@ -13,6 +14,142 @@ from src import helpers
from typing import List, Union, Pattern from typing import List, Union, Pattern
class RealBootStrapGenerator(keras.utils.Sequence):
    """
    Keras sequence that serves one bootstrap realisation of the history data per index.

    For each index the values of a single variable are taken from the shuffled data, while all remaining variables
    are taken from the original history.
    """

    def __init__(self, number_of_boots, history, shuffled, variables, shuffled_variable):
        """
        Store the original history and split it into the untouched and the shuffled part.

        :param number_of_boots: number of bootstrap realisations, used as length of this sequence
        :param history: history data; must support `.sel(variables=...)` (presumably an xarray.DataArray)
        :param shuffled: shuffled data; must provide the coordinates `variables` and `boots`
        :param variables: list of all variable names
        :param shuffled_variable: name of the variable that is replaced by its shuffled counterpart
        """
        self.number_of_boots = number_of_boots
        self.variables = variables
        self.history_orig = history
        # presumably helpers.list_pop returns `variables` without `shuffled_variable` -- TODO confirm
        remaining_variables = helpers.list_pop(self.variables, shuffled_variable)
        self.history = history.sel(variables=remaining_variables)
        self.shuffled = shuffled.sel(variables=shuffled_variable)

    def __len__(self):
        """Return the number of bootstrap realisations."""
        return self.number_of_boots

    def __getitem__(self, index):
        """
        Build the bootstrapped history for the given boot index.

        :param index: value of the `boots` coordinate to select from the shuffled data
        :return: history with the shuffled variable filled in, re-indexed like the original history
        """
        logging.debug(f"boot: {index}")
        merged = self.history.copy().combine_first(self.__get_shuffled(index))
        return merged.reindex_like(self.history_orig)

    def __get_shuffled(self, index):
        """
        Select a single boot from the shuffled data and bring it into the dimension order of the history.

        :param index: value of the `boots` coordinate to select
        :return: shuffled data with dimensions (datetime, window, Stations, variables)
        """
        single_boot = self.shuffled.sel(boots=index)
        with_variable_dim = single_boot.expand_dims("variables")
        without_boot_coord = with_variable_dim.drop("boots")
        return without_boot_coord.transpose("datetime", "window", "Stations", "variables")
class BootStrapGeneratorNew:
    """
    Provide bootstrapped data station- and variable-wise on top of an existing data generator.

    The shuffled data is not created here. It is read from files located in `bootstrap_path`.
    """

    def __init__(self, orig_generator, number_of_boots, bootstrap_path):
        """
        :param orig_generator: original data generator; its `stations`, `variables` and `window_history_size`
            attributes are used by this class
        :param number_of_boots: number of bootstrap realisations per station and variable
        :param bootstrap_path: directory containing the shuffled data files
        """
        self.orig_generator: DataGenerator = orig_generator
        self.stations = self.orig_generator.stations
        self.variables = self.orig_generator.variables
        self.number_of_boots = number_of_boots
        self.bootstrap_path = bootstrap_path

    def __len__(self):
        """Return the length of the original generator multiplied by the number of boots."""
        return len(self.orig_generator) * self.number_of_boots

    def get_generator_station_var_wise(self, station, var):
        """
        Create the bootstrap generator for a single station and a single variable.

        :param station: station to load history, label and shuffled data for
        :param var: variable that is replaced by shuffled data inside the returned generator
        :return: original history, original label, bootstrap generator and the number of boots
        """
        hist, label = self.orig_generator[station]
        shuffled_data = self.load_shuffled_data(station, self.variables)
        gen = RealBootStrapGenerator(self.number_of_boots, hist, shuffled_data, self.variables, var)
        return hist, label, gen, self.number_of_boots

    def get_bootstrap_meta_station_var_wise(self, station, var) -> List:
        """
        Create meta data on ordering of variable bootstraps according to ordering from get_generator method.

        :param station: station the meta data belongs to
        :param var: bootstrapped variable
        :return: list with bootstrapped variable first and its corresponding station second, repeated once per
            label entry and boot
        """
        bootstrap_meta = []
        label = self.orig_generator.get_data_generator(station).get_transposed_label()
        for _ in range(self.number_of_boots):
            bootstrap_meta.extend([[var, station]] * len(label))
        return bootstrap_meta

    def get_labels(self, key: Union[str, int]):
        """
        Repeat labels for given key by the number of boots and yield them one by one.

        :param key: key of station (either station name as string or the position in generator as integer)
        :return: yields labels for length of boots
        """
        _, label = self.orig_generator[key]
        for _ in range(self.number_of_boots):
            yield label

    def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN"):
        """
        Repeat predictions from given file(_name) in path by the number of boots.

        The file is opened once; it is not closed explicitly by this method.

        :param path: path to file
        :param file_name: file name
        :param prediction_name: name of the prediction to select from loaded file
        :return: yields predictions for length of boots
        """
        file = os.path.join(path, file_name)
        data = xr.open_dataarray(file)
        for _ in range(self.number_of_boots):
            yield data.sel(type=prediction_name).squeeze()

    def load_shuffled_data(self, station: str, variables: List[str]) -> "xr.DataArray":
        """
        Load shuffled data from bootstrap path.

        Data is stored as '<station>_<var1>_<var2>_..._hist<histsize>_nboots<nboots>_shuffled.nc', e.g.
        'DEBW107_cloudcover_no_no2_temp_u_v_hist13_nboots20_shuffled.nc'. The file is opened with `chunks=100`.

        :param station: station name the file has to start with
        :param variables: variables that have to be part of the file name
        :return: shuffled data as xarray
        :raises FileNotFoundError: if no matching file is available in the bootstrap path
        """
        file_name = self.get_shuffled_data_file(station, variables)
        shuffled_data = xr.open_dataarray(file_name, chunks=100)
        return shuffled_data

    def get_shuffled_data_file(self, station, variables):
        """
        Look up the shuffled data file for given station and variables inside the bootstrap path.

        :param station: station name the file has to start with
        :param variables: variables that have to be part of the file name
        :return: full path of the first matching file
        :raises FileNotFoundError: if no file matches the pattern and satisfies window size and number of boots
        """
        files = os.listdir(self.bootstrap_path)
        regex = self.create_file_regex(station, variables)
        file = self.filter_files(regex, files, self.orig_generator.window_history_size, self.number_of_boots)
        if file:
            return os.path.join(self.bootstrap_path, file)
        else:
            raise FileNotFoundError(f"Could not find a file to match pattern {regex}")

    @staticmethod
    def create_file_regex(station: str, variables: List[str]) -> Pattern:
        """
        Create regex for given station and variables to look for shuffled data with pattern:
        `<station>(_<var>)*_hist(<hist>)_nboots(<nboots>)_shuffled.nc`

        Station and variable names are taken literally, regex special characters inside them are escaped.

        :param station: station name to use as prefix
        :param variables: variables to add after station (used in sorted order)
        :return: compiled regular expression
        """
        var_regex = "".join([rf"(_\w+)*_{re.escape(v)}(_\w+)*" for v in sorted(variables)])
        regex = re.compile(rf"{re.escape(station)}{var_regex}_hist(\d+)_nboots(\d+)_shuffled\.nc")
        return regex

    @staticmethod
    def filter_files(regex: Pattern, files: List[str], window: int, nboot: int) -> Union[str, None]:
        """
        Filter list of files by regex.

        Regex has to be structured to match the following string structure
        `<station>(_<var>)*_hist(<hist>)_nboots(<nboots>)_shuffled.nc`. Hist and nboots values have to be included as
        the last two groups. All matches are compared to given window and nboot parameters. A valid file must have
        the same value (or larger) than these parameters and contain all variables. The complete file name has to
        match the pattern, so names with a trailing suffix (e.g. `..._shuffled.nc.tmp`) are rejected.

        :param regex: compiled regular expression pattern following the style from method description
        :param files: list of file names to filter
        :param window: minimum length of window to look for
        :param nboot: minimal number of boots to search
        :return: first matching file name or None, if no valid file was found
        """
        for f in files:
            match = regex.fullmatch(f)
            if match is None:
                continue
            # hist and nboots are the last two groups of the pattern, independent of the number of variables
            last = match.lastindex
            if (int(match.group(last - 1)) >= window) and (int(match.group(last)) >= nboot):
                return f
        return None
class BootStrapGenerator: class BootStrapGenerator:
def __init__(self, orig_generator, number_of_boots, bootstrap_path): def __init__(self, orig_generator, number_of_boots, bootstrap_path):
...@@ -23,7 +160,7 @@ class BootStrapGenerator: ...@@ -23,7 +160,7 @@ class BootStrapGenerator:
self.bootstrap_path = bootstrap_path self.bootstrap_path = bootstrap_path
def __len__(self): def __len__(self):
return len(self.orig_generator) * self.number_of_boots * len(self.variables) return len(self.orig_generator) * self.number_of_boots
def get_generator_station_var_wise(self, station, var): def get_generator_station_var_wise(self, station, var):
""" """
...@@ -146,7 +283,7 @@ class BootStraps: ...@@ -146,7 +283,7 @@ class BootStraps:
self.bootstrap_path = bootstrap_path self.bootstrap_path = bootstrap_path
self.chunks = self.get_chunk_size() self.chunks = self.get_chunk_size()
self.create_shuffled_data() self.create_shuffled_data()
self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.bootstrap_path) self._boot_strap_generator = BootStrapGeneratorNew(self.data, self.number_bootstraps, self.bootstrap_path)
@property @property
def stations(self): def stations(self):
......
...@@ -111,8 +111,9 @@ class PostProcessing(RunEnvironment): ...@@ -111,8 +111,9 @@ class PostProcessing(RunEnvironment):
hist, label, station_bootstrap, length = bootstraps.get_generator_station_var_wise(station, var) hist, label, station_bootstrap, length = bootstraps.get_generator_station_var_wise(station, var)
# make bootstrap predictions # make bootstrap predictions
bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap(), bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap,
steps=length, steps=length,
workers=4,
use_multiprocessing=True) use_multiprocessing=True)
if isinstance(bootstrap_predictions, list): # if model is branched model if isinstance(bootstrap_predictions, list): # if model is branched model
bootstrap_predictions = bootstrap_predictions[-1] bootstrap_predictions = bootstrap_predictions[-1]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment