From 6365184637832ed7552a42eb66d16f80216fbdf0 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 17 Mar 2020 15:13:18 +0100
Subject: [PATCH] data generator and distributor can use upsampling

---
 src/data_handling/data_distributor.py         | 21 ++++++++++++++-----
 src/data_handling/data_generator.py           |  9 +++++++-
 .../test_data_distributor.py                  |  2 +-
 3 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index b1624410..c6015ed7 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -8,15 +8,18 @@ import math
 import keras
 import numpy as np
 
+from src.data_handling.data_generator import DataGenerator
+
 
 class Distributor(keras.utils.Sequence):
 
-    def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
-                 permute_data: bool = False):
+    def __init__(self, generator: DataGenerator, model: keras.models, batch_size: int = 256,
+                 permute_data: bool = False, upsampling: bool = False):
         self.generator = generator
         self.model = model
         self.batch_size = batch_size
         self.do_data_permutation = permute_data
+        self.upsampling = upsampling
 
     def _get_model_rank(self):
         mod_out = self.model.output_shape
@@ -31,7 +34,7 @@ class Distributor(keras.utils.Sequence):
         return mod_rank
 
     def _get_number_of_mini_batches(self, values):
-        return math.ceil(values[0].shape[0] / self.batch_size)
+        return math.ceil(values.shape[0] / self.batch_size)
 
     def _permute_data(self, x, y):
         """
@@ -48,10 +51,18 @@ class Distributor(keras.utils.Sequence):
             for k, v in enumerate(self.generator):
                 # get rank of output
                 mod_rank = self._get_model_rank()
-                # get number of mini batches
-                num_mini_batches = self._get_number_of_mini_batches(v)
+                # get data
                 x_total = np.copy(v[0])
                 y_total = np.copy(v[1])
+                if self.upsampling:
+                    try:
+                        s = self.generator.get_data_generator(k)
+                        x_total = np.concatenate([x_total, np.copy(s.extremes_history.copy())], axis=0)
+                        y_total = np.concatenate([y_total, np.copy(s.extremes_labels.copy())], axis=0)
+                    except AttributeError:  # no extremes history / labels available, copy will fail
+                        pass
+                # get number of mini batches
+                num_mini_batches = self._get_number_of_mini_batches(x_total)
                 # permute order for mini-batches
                 x_total, y_total = self._permute_data(x_total, y_total)
                 for prev, curr in enumerate(range(1, num_mini_batches+1)):
diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 24c9ada6..0bf0bc35 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -14,6 +14,9 @@ from src import helpers
 from src.data_handling.data_preparation import DataPrep
 from src.join import EmptyQueryResult
 
+number = Union[float, int]
+num_or_list = Union[number, List[number]]
+
 
 class DataGenerator(keras.utils.Sequence):
     """
@@ -27,7 +30,7 @@ class DataGenerator(keras.utils.Sequence):
     def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
                  interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
                  interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
-                 window_lead_time: int = 4, transformation: Dict = None, **kwargs):
+                 window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, **kwargs):
         self.data_path = os.path.abspath(data_path)
         self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
         if not os.path.exists(self.data_path_tmp):
@@ -43,6 +46,7 @@ class DataGenerator(keras.utils.Sequence):
         self.limit_nan_fill = limit_nan_fill
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
+        self.extreme_values = extreme_values
         self.kwargs = kwargs
         self.transformation = self.setup_transformation(transformation)
 
@@ -188,6 +192,9 @@ class DataGenerator(keras.utils.Sequence):
             data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
             data.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
             data.remove_nan(self.interpolate_dim)
+            if self.extreme_values:
+                kwargs = {"extremes_on_right_tail_only": self.kwargs.get("extremes_on_right_tail_only", False)}
+                data.multiply_extremes(self.extreme_values, **kwargs)
             if save_local_tmp_storage:
                 self._save_pickle_data(data)
         return data
diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
index a26e76a0..dd0ca99d 100644
--- a/test/test_data_handling/test_data_distributor.py
+++ b/test/test_data_handling/test_data_distributor.py
@@ -46,7 +46,7 @@ class TestDistributor:
         distributor.model = 1
 
     def test_get_number_of_mini_batches(self, distributor):
-        values = np.zeros((2, 2311, 19))
+        values = np.zeros((2311, 19))
         assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size)
 
     def test_distribute_on_batches_single_loop(self,  generator_two_stations, model):
-- 
GitLab