From 6365184637832ed7552a42eb66d16f80216fbdf0 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Tue, 17 Mar 2020 15:13:18 +0100 Subject: [PATCH] data generator and distributor can use upsampling --- src/data_handling/data_distributor.py | 21 ++++++++++++++----- src/data_handling/data_generator.py | 9 +++++++- .../test_data_distributor.py | 2 +- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index b1624410..c6015ed7 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -8,15 +8,18 @@ import math import keras import numpy as np +from src.data_handling.data_generator import DataGenerator + class Distributor(keras.utils.Sequence): - def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256, - permute_data: bool = False): + def __init__(self, generator: DataGenerator, model: keras.models, batch_size: int = 256, + permute_data: bool = False, upsampling: bool = False): self.generator = generator self.model = model self.batch_size = batch_size self.do_data_permutation = permute_data + self.upsampling = upsampling def _get_model_rank(self): mod_out = self.model.output_shape @@ -31,7 +34,7 @@ class Distributor(keras.utils.Sequence): return mod_rank def _get_number_of_mini_batches(self, values): - return math.ceil(values[0].shape[0] / self.batch_size) + return math.ceil(values.shape[0] / self.batch_size) def _permute_data(self, x, y): """ @@ -48,10 +51,18 @@ class Distributor(keras.utils.Sequence): for k, v in enumerate(self.generator): # get rank of output mod_rank = self._get_model_rank() - # get number of mini batches - num_mini_batches = self._get_number_of_mini_batches(v) + # get data x_total = np.copy(v[0]) y_total = np.copy(v[1]) + if self.upsampling: + try: + s = self.generator.get_data_generator(k) + x_total = np.concatenate([x_total, np.copy(s.extremes_history.copy())], axis=0) + y_total = np.concatenate([y_total, np.copy(s.extremes_labels.copy())], axis=0) + except AttributeError: # no extremes history / labels available, copy will fail + pass + # get number of mini batches + num_mini_batches = self._get_number_of_mini_batches(x_total) # permute order for mini-batches x_total, y_total = self._permute_data(x_total, y_total) for prev, curr in enumerate(range(1, num_mini_batches+1)): diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 24c9ada6..0bf0bc35 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -14,6 +14,9 @@ from src import helpers from src.data_handling.data_preparation import DataPrep from src.join import EmptyQueryResult +number = Union[float, int] +num_or_list = Union[number, List[number]] + class DataGenerator(keras.utils.Sequence): """ @@ -27,7 +30,7 @@ class DataGenerator(keras.utils.Sequence): def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str], interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None, interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7, - window_lead_time: int = 4, transformation: Dict = None, **kwargs): + window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, **kwargs): self.data_path = os.path.abspath(data_path) self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp") if not os.path.exists(self.data_path_tmp): @@ -43,6 +46,7 @@ class DataGenerator(keras.utils.Sequence): self.limit_nan_fill = limit_nan_fill self.window_history_size = window_history_size self.window_lead_time = window_lead_time + self.extreme_values = extreme_values self.kwargs = kwargs self.transformation = self.setup_transformation(transformation) @@ -188,6 +192,9 @@ class DataGenerator(keras.utils.Sequence): data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time) data.make_observation(self.target_dim, self.target_var, self.interpolate_dim) data.remove_nan(self.interpolate_dim) + if self.extreme_values: + kwargs = {"extremes_on_right_tail_only": self.kwargs.get("extremes_on_right_tail_only", False)} + data.multiply_extremes(self.extreme_values, **kwargs) if save_local_tmp_storage: self._save_pickle_data(data) return data diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index a26e76a0..dd0ca99d 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -46,7 +46,7 @@ class TestDistributor: distributor.model = 1 def test_get_number_of_mini_batches(self, distributor): - values = np.zeros((2, 2311, 19)) + values = np.zeros((2311, 19)) assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size) def test_distribute_on_batches_single_loop(self, generator_two_stations, model): -- GitLab