Skip to content
Snippets Groups Projects
Commit 63651846 authored by lukas leufen's avatar lukas leufen
Browse files

data generator and distributor can use upsampling

parent 00d65a75
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!77Resolve "Upsample "extremes" in standardised data space"
Pipeline #32183 passed
......@@ -8,15 +8,18 @@ import math
import keras
import numpy as np
from src.data_handling.data_generator import DataGenerator
class Distributor(keras.utils.Sequence):
def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
permute_data: bool = False):
def __init__(self, generator: DataGenerator, model: keras.models, batch_size: int = 256,
permute_data: bool = False, upsampling: bool = False):
self.generator = generator
self.model = model
self.batch_size = batch_size
self.do_data_permutation = permute_data
self.upsampling = upsampling
def _get_model_rank(self):
mod_out = self.model.output_shape
......@@ -31,7 +34,7 @@ class Distributor(keras.utils.Sequence):
return mod_rank
def _get_number_of_mini_batches(self, values):
return math.ceil(values[0].shape[0] / self.batch_size)
return math.ceil(values.shape[0] / self.batch_size)
def _permute_data(self, x, y):
"""
......@@ -48,10 +51,18 @@ class Distributor(keras.utils.Sequence):
for k, v in enumerate(self.generator):
# get rank of output
mod_rank = self._get_model_rank()
# get number of mini batches
num_mini_batches = self._get_number_of_mini_batches(v)
# get data
x_total = np.copy(v[0])
y_total = np.copy(v[1])
if self.upsampling:
try:
s = self.generator.get_data_generator(k)
x_total = np.concatenate([x_total, np.copy(s.extremes_history.copy())], axis=0)
y_total = np.concatenate([y_total, np.copy(s.extremes_labels.copy())], axis=0)
except AttributeError: # no extremes history / labels available, copy will fail
pass
# get number of mini batches
num_mini_batches = self._get_number_of_mini_batches(x_total)
# permute order for mini-batches
x_total, y_total = self._permute_data(x_total, y_total)
for prev, curr in enumerate(range(1, num_mini_batches+1)):
......
......@@ -14,6 +14,9 @@ from src import helpers
from src.data_handling.data_preparation import DataPrep
from src.join import EmptyQueryResult
number = Union[float, int]
num_or_list = Union[number, List[number]]
class DataGenerator(keras.utils.Sequence):
"""
......@@ -27,7 +30,7 @@ class DataGenerator(keras.utils.Sequence):
def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
window_lead_time: int = 4, transformation: Dict = None, **kwargs):
window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, **kwargs):
self.data_path = os.path.abspath(data_path)
self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
if not os.path.exists(self.data_path_tmp):
......@@ -43,6 +46,7 @@ class DataGenerator(keras.utils.Sequence):
self.limit_nan_fill = limit_nan_fill
self.window_history_size = window_history_size
self.window_lead_time = window_lead_time
self.extreme_values = extreme_values
self.kwargs = kwargs
self.transformation = self.setup_transformation(transformation)
......@@ -188,6 +192,9 @@ class DataGenerator(keras.utils.Sequence):
data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
data.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
data.remove_nan(self.interpolate_dim)
if self.extreme_values:
kwargs = {"extremes_on_right_tail_only": self.kwargs.get("extremes_on_right_tail_only", False)}
data.multiply_extremes(self.extreme_values, **kwargs)
if save_local_tmp_storage:
self._save_pickle_data(data)
return data
......
......@@ -46,7 +46,7 @@ class TestDistributor:
distributor.model = 1
def test_get_number_of_mini_batches(self, distributor):
values = np.zeros((2, 2311, 19))
values = np.zeros((2311, 19))
assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size)
def test_distribute_on_batches_single_loop(self, generator_two_stations, model):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment