diff --git a/src/data_generator.py b/src/data_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b092a0023461bb010361c21d38f6d919b87b753
--- /dev/null
+++ b/src/data_generator.py
@@ -0,0 +1,91 @@
+__author__ = 'Felix Kleinert, Lukas Leufen'
+__date__ = '2019-11-07'
+
+
+import keras
+from src import helpers
+import os
+from typing import Union, List
+import decimal
+import numpy as np
+
+
+class DataGenerator(keras.utils.Sequence):
+
+    """
+    This class is a generator to handle large arrays for machine learning. This class can be used with keras'
+    fit_generator and predict_generator. Individual stations are the iterables. This class uses class Dataprep and
+    returns X, y when an item is called.
+    Item can be called manually by position (integer) or station id (string). Methods also accept lists with exactly
+    one entry of integer or string.
+    NOTE: __next__ and __getitem__ are not implemented yet and raise NotImplementedError.
+    """
+
+    def __init__(self, path: str, network: str, stations: Union[str, List[str]], variables: List[str], dim: str,
+                 target_dim: str, target_var: str, **kwargs):
+        """
+        store all settings and pre-compute the list of thresholds
+        :param path: path to the data, stored as absolute path
+        :param network: network identifier, stored unchanged
+        :param stations: a single station id or a list of station ids, passed through helpers.to_list
+        :param variables: list of variable names, stored unchanged
+        :param dim: stored unchanged
+        :param target_dim: stored unchanged
+        :param target_var: stored unchanged
+        :param kwargs: optional settings, thr_min, thr_max and thr_number_of_steps are used by threshold_setup
+        """
+        # initialise the keras base class, recent keras versions expect this call from subclasses of Sequence
+        super().__init__()
+        self.path = os.path.abspath(path)
+        self.network = network
+        self.stations = helpers.to_list(stations)
+        self.variables = variables
+        self.dim = dim
+        self.target_dim = target_dim
+        self.target_var = target_var
+        self.kwargs = kwargs
+        self.threshold = self.threshold_setup()
+
+    def __repr__(self):
+        """
+        display all class attributes
+        """
+        return f"DataGenerator(path='{self.path}', network='{self.network}', stations={self.stations}, "\
+               f"variables={self.variables}, dim='{self.dim}', target_dim='{self.target_dim}', target_var='" \
+               f"{self.target_var}', **{self.kwargs})"
+
+    def __len__(self):
+        """
+        return the number of stations
+        """
+        return len(self.stations)
+
+    def __iter__(self):
+        """
+        reset the iteration counter and return the generator itself as iterator
+        """
+        self.iterator = 0
+        return self
+
+    def __next__(self):
+        """
+        not implemented yet
+        """
+        raise NotImplementedError
+
+    def __getitem__(self, item):
+        """
+        not implemented yet
+        """
+        raise NotImplementedError
+
+    def threshold_setup(self) -> List[str]:
+        """
+        set threshold for given min/max and number of steps. defaults are [0, 100] with n=200 steps
+        :return: list of thresholds, each formatted as string with 4 decimal places
+        """
+        thr_min = self.kwargs.get('thr_min', 0)
+        thr_max = self.kwargs.get('thr_max', 100)
+        thr_number_of_steps = self.kwargs.get('thr_number_of_steps', 200)
+        return [str(decimal.Decimal("%.4f" % e)) for e in np.linspace(thr_min, thr_max, thr_number_of_steps)]
+
diff --git a/test/test_data_generator.py b/test/test_data_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6be4982307ffb8498a474510b0c04baef0c637b
--- /dev/null
+++ b/test/test_data_generator.py
@@ -0,0 +1,53 @@
+import pytest
+import os
+from src.data_generator import DataGenerator
+import logging
+import numpy as np
+import xarray as xr
+import datetime as dt
+import pandas as pd
+from operator import itemgetter
+
+
+class TestDataGenerator:
+
+    @pytest.fixture
+    def gen(self):
+        return DataGenerator('data', 'UBA', 'DEBW107', ['o3', 'temp'], 'datetime', 'datetime', 'o3')
+
+    def test_init(self, gen):
+        assert gen.path == os.path.abspath('data')
+        assert gen.network == 'UBA'
+        assert gen.stations == ['DEBW107']
+        assert gen.variables == ['o3', 'temp']
+        assert gen.dim == 'datetime'
+        assert gen.target_dim == 'datetime'
+        assert gen.target_var == 'o3'
+        assert gen.threshold is not None
+
+    def test_repr(self, gen):
+        # the generator resolves its path against the current working directory (os.path.abspath), so the expected
+        # value has to be built the same way, otherwise the test only passes when pytest is started inside test/
+        path = os.path.abspath('data')
+        assert repr(gen) == f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "\
+                            f"variables=['o3', 'temp'], dim='datetime', target_dim='datetime', " \
+                            f"target_var='o3', **{{}})"
+
+    def test_len(self, gen):
+        assert len(gen) == 1
+        gen.stations = ['station1', 'station2', 'station3']
+        assert len(gen) == 3
+
+    def test_threshold_setup(self, gen):
+        def res(**update):
+            # settings accumulate in gen.kwargs from one call to the next
+            gen.kwargs.update(update)
+            return list(map(float, gen.threshold_setup()))
+        compare = np.testing.assert_array_almost_equal
+        # assert_array_almost_equal raises on mismatch, no additional assert required
+        compare(res(), np.linspace(0, 100, 200), decimal=3)
+        compare(res(thr_min=10), np.linspace(10, 100, 200), decimal=3)
+        compare(res(thr_max=40), np.linspace(10, 40, 200), decimal=3)
+        compare(res(thr_number_of_steps=10), np.linspace(10, 40, 10), decimal=3)
+
+