diff --git a/src/data_generator.py b/src/data_generator.py
index 0b092a0023461bb010361c21d38f6d919b87b753..a43d4bf9772ba7d311b4502f53963e72d5ff98e4 100644
--- a/src/data_generator.py
+++ b/src/data_generator.py
@@ -1,9 +1,9 @@
 __author__ = 'Felix Kleinert, Lukas Leufen'
 __date__ = '2019-11-07'
 
-
 import keras
 from src import helpers
+from src.data_preparation import DataPrep
 import os
 from typing import Union, List
 import decimal
@@ -11,7 +11,6 @@ import numpy as np
 
 
 class DataGenerator(keras.utils.Sequence):
-
     """
     This class is a generator to handle large arrays for machine learning. This class can be used with keras'
     fit_generator and predict_generator. Individual stations are the iterables. This class uses class Dataprep and
@@ -20,15 +19,22 @@ class DataGenerator(keras.utils.Sequence):
     one entry of integer or string
     """
 
-    def __init__(self, path: str, network: str, stations: Union[str, List[str]], variables: List[str], dim: str,
-                 target_dim: str, target_var: str, **kwargs):
+    def __init__(self, path: str, network: str, stations: Union[str, List[str]], variables: List[str],
+                 interpolate_dim: str, target_dim: str, target_var: str, interpolate_method: str = "linear",
+                 limit_nan_fill: int = 1, window_history: int = 7, window_lead_time: int = 4,
+                 transform_method: str = "standardise", **kwargs):
         self.path = os.path.abspath(path)
         self.network = network
         self.stations = helpers.to_list(stations)
         self.variables = variables
-        self.dim = dim
+        self.interpolate_dim = interpolate_dim
         self.target_dim = target_dim
         self.target_var = target_var
+        self.interpolate_method = interpolate_method
+        self.limit_nan_fill = limit_nan_fill
+        self.window_history = window_history
+        self.window_lead_time = window_lead_time
+        self.transform_method = transform_method
         self.kwargs = kwargs
         self.threshold = self.threshold_setup()
 
@@ -36,9 +42,9 @@ class DataGenerator(keras.utils.Sequence):
         """
         display all class attributes
         """
-        return f"DataGenerator(path='{self.path}', network='{self.network}', stations={self.stations}, "\
-               f"variables={self.variables}, dim='{self.dim}', target_dim='{self.target_dim}', target_var='" \
-               f"{self.target_var}', **{self.kwargs})"
+        return f"DataGenerator(path='{self.path}', network='{self.network}', stations={self.stations}, " \
+               f"variables={self.variables}, interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}'" \
+               f", target_var='{self.target_var}', **{self.kwargs})"
 
     def __len__(self):
         """
@@ -51,10 +57,21 @@ class DataGenerator(keras.utils.Sequence):
         return self
 
     def __next__(self):
-        raise NotImplementedError
+        if self.iterator < self.__len__():
+            data = self.get_data_generator()
+            self.iterator += 1
+            if data.history is not None and data.label is not None:
+                return data.history.transpose("datetime", "window", "Stations", "variables"), \
+                       data.label.squeeze("Stations").transpose("datetime", "window")
+            else:
+                return self.__next__()
+        else:
+            raise StopIteration
 
-    def __getitem__(self, item):
-        raise NotImplementedError
+    def __getitem__(self, item: Union[str, int]):
+        data = self.get_data_generator(key=item)
+        return data.history.transpose("datetime", "window", "Stations", "variables"), \
+               data.label.squeeze("Stations").transpose("datetime", "window")
 
     def threshold_setup(self) -> List[str]:
         """
@@ -66,3 +83,47 @@ class DataGenerator(keras.utils.Sequence):
         thr_number_of_steps = self.kwargs.get('thr_number_of_steps', 200)
         return [str(decimal.Decimal("%.4f" % e)) for e in np.linspace(thr_min, thr_max, thr_number_of_steps)]
 
+    def get_data_generator(self, key: Union[str, int] = None) -> DataPrep:
+        """
+        Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
+        remove nans.
+        :param key: station name or index into self.stations; None selects the current iterator position
+        :return: preprocessed data as a DataPrep instance
+        """
+        station = self.get_station_key(key)
+        data = DataPrep(self.path, self.network, station, self.variables, **self.kwargs)
+        data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
+        data.transform("datetime", method=self.transform_method)
+        data.make_history_window(self.interpolate_dim, self.window_history)
+        data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
+        data.history_label_nan_remove(self.interpolate_dim)
+        return data
+
+    def get_station_key(self, key: Union[str, int, List[Union[str, int]]]) -> str:
+        """
+        Return a valid station key or raise KeyError if this wasn't possible
+        :param key: station name, index into self.stations, a one-element list of either, or None (current iterator)
+        :return: name of the selected station
+        """
+        # extract value if given as list
+        if isinstance(key, list):
+            if len(key) == 1:
+                key = key[0]
+            else:
+                raise KeyError(f"More than one key was given: {key}")
+        # return station name either from key or the recent element from iterator
+        if key is None:
+            return self.stations[self.iterator]
+        else:
+            if isinstance(key, int):
+                if key < self.__len__():
+                    return self.stations[key]
+                else:
+                    raise KeyError(f"{key} is not in range(0, {self.__len__()})")
+            elif isinstance(key, str):
+                if key in self.stations:
+                    return key
+                else:
+                    raise KeyError(f"{key} is not in stations")
+            else:
+                raise KeyError(f"Key has to be from Union[str, int]. Given was {key} ({type(key)})")
diff --git a/test/test_data_generator.py b/test/test_data_generator.py
index e6be4982307ffb8498a474510b0c04baef0c637b..b316fa887ec925b25e3a362100218e1f7ddbfe89 100644
--- a/test/test_data_generator.py
+++ b/test/test_data_generator.py
@@ -13,29 +13,55 @@ class TestDataGenerator:
 
     @pytest.fixture
     def gen(self):
-        return DataGenerator('data', 'UBA', 'DEBW107', ['o3', 'temp'], 'datetime', 'datetime', 'o3')
+        return DataGenerator('data', 'UBA', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3')
 
     def test_init(self, gen):
         assert gen.path == os.path.abspath('data')
         assert gen.network == 'UBA'
         assert gen.stations == ['DEBW107']
         assert gen.variables == ['o3', 'temp']
-        assert gen.dim == 'datetime'
-        assert gen.target_dim == 'datetime'
+        assert gen.interpolate_dim == 'datetime'
+        assert gen.target_dim == 'variables'
         assert gen.target_var == 'o3'
+        assert gen.interpolate_method == "linear"
+        assert gen.limit_nan_fill == 1
+        assert gen.window_history == 7
+        assert gen.window_lead_time == 4
+        assert gen.transform_method == "standardise"
+        assert gen.kwargs == {}
         assert gen.threshold is not None
 
     def test_repr(self, gen):
         path = os.path.join(os.path.dirname(__file__), 'data')
         assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "\
-                                          f"variables=['o3', 'temp'], dim='datetime', target_dim='datetime', " \
-                                          f"target_var='o3', **{{}})".rstrip()
+                                          f"variables=['o3', 'temp'], interpolate_dim='datetime', " \
+                                          f"target_dim='variables', target_var='o3', **{{}})".rstrip()
 
     def test_len(self, gen):
         assert len(gen) == 1
         gen.stations = ['station1', 'station2', 'station3']
         assert len(gen) == 3
 
+    def test_iter(self, gen):
+        assert hasattr(gen, 'iterator') is False
+        iter(gen)
+        assert hasattr(gen, 'iterator')
+        assert gen.iterator == 0
+
+    def test_next(self, gen):
+        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
+        for i, d in enumerate(gen, start=1):
+            assert i == gen.iterator
+
+    def test_getitem(self, gen):
+        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
+        station = gen["DEBW107"]
+        assert len(station) == 2
+        assert station[0].Stations.data == "DEBW107"
+        assert station[0].data.shape[1:] == (8, 1, 2)
+        assert station[1].data.shape[-1] == gen.window_lead_time
+        assert station[0].data.shape[1] == gen.window_history + 1
+
     def test_threshold_setup(self, gen):
         def res(arg, val):
             gen.kwargs[arg] = val
@@ -46,4 +72,24 @@ class TestDataGenerator:
         assert compare(res('thr_max', 40), np.linspace(10, 40, 200), decimal=3) is None
         assert compare(res('thr_number_of_steps', 10), np.linspace(10, 40, 10), decimal=3) is None
 
-
+    def test_get_key_representation(self, gen):
+        gen.stations.append("DEBW108")
+        f = gen.__iter__().get_station_key
+        assert f(None) == "DEBW107"
+        assert f([None]) == "DEBW107"
+        with pytest.raises(KeyError) as e:
+            f([None, None])
+        assert "More than one key was given: [None, None]" in e.value.args[0]
+        assert f(1) == "DEBW108"
+        assert f([1]) == "DEBW108"
+        with pytest.raises(KeyError) as e:
+            f(3)
+        assert "3 is not in range(0, 2)" in e.value.args[0]
+        assert f("DEBW107") == "DEBW107"
+        assert f(["DEBW108"]) == "DEBW108"
+        with pytest.raises(KeyError) as e:
+            f("DEBW999")
+        assert "DEBW999 is not in stations" in e.value.args[0]
+        with pytest.raises(KeyError) as e:
+            f(6.5)
+        assert "Key has to be from Union[str, int]. Given was 6.5 (<class 'float'>)" in e.value.args[0]