From d65b8b783b5a51dadbadca22a5f4684815841ef6 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 5 Feb 2020 11:22:37 +0100
Subject: [PATCH 1/6] change map resolution to accelerate map plot

---
 src/plotting/postprocessing_plotting.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index cd49ddd5..97d326bc 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -141,11 +141,11 @@ class PlotStationMap(RunEnvironment):
         """
         Draw coastline, lakes, ocean, rivers and country borders as background on the map.
         """
-        self._ax.add_feature(cfeature.COASTLINE.with_scale("10m"), edgecolor='black')
+        self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
         self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
         self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
-        self._ax.add_feature(cfeature.RIVERS.with_scale("10m"))
-        self._ax.add_feature(cfeature.BORDERS.with_scale("10m"), facecolor='none', edgecolor='black')
+        self._ax.add_feature(cfeature.RIVERS.with_scale("50m"))
+        self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')

     def _plot_stations(self, generators):
         """
--
GitLab
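
Patch 1 is purely a resolution swap: cartopy ships Natural Earth features in "10m", "50m" and "110m" scales, and the coarser "50m" shapefiles are smaller to download and faster to render, which is where the acceleration comes from (LAKES and OCEAN were already at "50m"; only coastline, rivers and borders change). A minimal standalone sketch of the same background-drawing pattern, assuming only that cartopy and matplotlib are installed; the figure and output file name are placeholders, not part of the patch:

    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    import matplotlib.pyplot as plt

    scale = "50m"  # one Natural Earth scale for all features; the compromise this patch settles on
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.add_feature(cfeature.COASTLINE.with_scale(scale), edgecolor="black")
    ax.add_feature(cfeature.LAKES.with_scale(scale))
    ax.add_feature(cfeature.OCEAN.with_scale(scale))
    ax.add_feature(cfeature.RIVERS.with_scale(scale))
    ax.add_feature(cfeature.BORDERS.with_scale(scale), facecolor="none", edgecolor="black")
    plt.savefig("station_map.png")
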
From a6dccb6d2c9899f215e20b397d5169089654c8e6 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 5 Feb 2020 11:23:40 +0100
Subject: [PATCH 2/6] first implementation of local tmp storage using pickle

---
 src/data_handling/data_generator.py | 40 +++++++++++++++++++++++------
 src/run_modules/pre_processing.py   |  9 ++++---
 2 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 1de0ab20..f259c403 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -7,6 +7,8 @@ from src.data_handling.data_preparation import DataPrep
 import os
 from typing import Union, List, Tuple
 import xarray as xr
+import pickle
+import logging
 
 
 class DataGenerator(keras.utils.Sequence):
@@ -23,6 +25,9 @@ class DataGenerator(keras.utils.Sequence):
                  interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
                  window_lead_time: int = 4, transform_method: str = "standardise", **kwargs):
         self.data_path = os.path.abspath(data_path)
+        self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
+        if not os.path.exists(self.data_path_tmp):
+            os.makedirs(self.data_path_tmp)
         self.network = network
         self.stations = helpers.to_list(stations)
         self.variables = variables
@@ -88,7 +93,7 @@ class DataGenerator(keras.utils.Sequence):
         return data.history.transpose("datetime", "window", "Stations", "variables"), \
                data.label.squeeze("Stations").transpose("datetime", "window")
 
-    def get_data_generator(self, key: Union[str, int] = None) -> DataPrep:
+    def get_data_generator(self, key: Union[str, int] = None, load_tmp: bool = True) -> DataPrep:
         """
         Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
         remove nans.
@@ -96,13 +101,32 @@ class DataGenerator(keras.utils.Sequence):
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
-        data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
-                        **self.kwargs)
-        data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
-        data.transform("datetime", method=self.transform_method)
-        data.make_history_window(self.interpolate_dim, self.window_history_size)
-        data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
-        data.history_label_nan_remove(self.interpolate_dim)
+        try:
+            if not load_tmp:
+                raise FileNotFoundError
+            data = self._load_pickle_data(station, self.variables)
+        except FileNotFoundError:
+            logging.info(f"could not load pickle data for {station}, processing from scratch")
+            data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                            **self.kwargs)
+            data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
+            data.transform("datetime", method=self.transform_method)
+            data.make_history_window(self.interpolate_dim, self.window_history_size)
+            data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
+            data.history_label_nan_remove(self.interpolate_dim)
+            self._save_pickle_data(data)
         return data
+
+    def _save_pickle_data(self, data):
+        file = os.path.join(self.data_path_tmp, f"{''.join(data.station)}_{'_'.join(sorted(data.variables))}.pickle")
+        with open(file, "wb") as f:
+            pickle.dump(data, f)
+        logging.debug(f"save pickle data to {file}")
+
+    def _load_pickle_data(self, station, variables):
+        file = os.path.join(self.data_path_tmp, f"{''.join(station)}_{'_'.join(sorted(variables))}.pickle")
+        data = pickle.load(open(file, "rb"))
+        logging.debug(f"load pickle data from {file}")
+        return data
 
     def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str:
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 2a4632d5..2f8b2777 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -36,7 +36,7 @@ class PreProcessing(RunEnvironment):
     def _run(self):
         args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
         kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
-        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"))
+        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False)
         self.data_store.set("stations", valid_stations, "general")
         self.split_train_val_test()
         self.report_pre_processing()
@@ -97,7 +97,7 @@ class PreProcessing(RunEnvironment):
         self.data_store.set("generator", data_set, scope)
 
     @staticmethod
-    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]):
+    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True):
         """
         Check if all given stations in `all_stations` are valid. Valid means that there is data available for the
         given time range (is included in `kwargs`). The shape and the loading time are logged in debug mode.
@@ -118,9 +118,10 @@ class PreProcessing(RunEnvironment): for station in all_stations: t_inner.run() try: - (history, label) = data_gen[station] + # (history, label) = data_gen[station] + data = data_gen.get_data_generator(key=station, load_tmp=load_tmp) valid_stations.append(station) - logging.debug(f"{station}: history_shape = {history.shape}") + logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}') logging.debug(f"{station}: loading time = {t_inner}") except (AttributeError, EmptyQueryResult): continue -- GitLab From af1ecb8a390e48f0a479204b5e1fdcd4e87ace95 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Wed, 5 Feb 2020 11:43:34 +0100 Subject: [PATCH 3/6] minor modifications, add docs --- src/data_handling/data_generator.py | 25 +++++++++++++++++++------ src/run_modules/pre_processing.py | 2 +- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index f259c403..26b12d59 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -5,7 +5,7 @@ import keras from src import helpers from src.data_handling.data_preparation import DataPrep import os -from typing import Union, List, Tuple +from typing import Union, List, Tuple, Any import xarray as xr import pickle import logging @@ -93,16 +93,18 @@ class DataGenerator(keras.utils.Sequence): return data.history.transpose("datetime", "window", "Stations", "variables"), \ data.label.squeeze("Stations").transpose("datetime", "window") - def get_data_generator(self, key: Union[str, int] = None, load_tmp: bool = True) -> DataPrep: + def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep: """ Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and remove nans. :param key: station key to choose the data generator. + :param local_tmp_storage: say if data should be processed from scratch or loaded as already processed data from + tmp pickle file to save computational time (but of course more disk space required). :return: preprocessed data as a DataPrep instance """ station = self.get_station_key(key) try: - if not load_tmp: + if not local_tmp_storage: raise FileNotFoundError data = self._load_pickle_data(station, self.variables) except FileNotFoundError: @@ -117,15 +119,26 @@ class DataGenerator(keras.utils.Sequence): self._save_pickle_data(data) return data - def _save_pickle_data(self, data): + def _save_pickle_data(self, data: Any): + """ + Save given data locally as .pickle in self.data_path_tmp with name '<station>_<var1>_<var2>_..._<varX>.pickle' + :param data: any data, that should be saved + """ file = os.path.join(self.data_path_tmp, f"{''.join(data.station)}_{'_'.join(sorted(data.variables))}.pickle") with open(file, "wb") as f: pickle.dump(data, f) logging.debug(f"save pickle data to {file}") - def _load_pickle_data(self, station, variables): + def _load_pickle_data(self, station: Union[str, List[str]], variables: List[str]) -> Any: + """ + Load locally saved data from self.data_path_tmp and name '<station>_<var1>_<var2>_..._<varX>.pickle'. 
From af1ecb8a390e48f0a479204b5e1fdcd4e87ace95 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 5 Feb 2020 11:43:34 +0100
Subject: [PATCH 3/6] minor modifications, add docs

---
 src/data_handling/data_generator.py | 25 +++++++++++++++++++------
 src/run_modules/pre_processing.py   |  2 +-
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index f259c403..26b12d59 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -5,7 +5,7 @@ import keras
 from src import helpers
 from src.data_handling.data_preparation import DataPrep
 import os
-from typing import Union, List, Tuple
+from typing import Union, List, Tuple, Any
 import xarray as xr
 import pickle
 import logging
@@ -93,16 +93,18 @@ class DataGenerator(keras.utils.Sequence):
         return data.history.transpose("datetime", "window", "Stations", "variables"), \
                data.label.squeeze("Stations").transpose("datetime", "window")
 
-    def get_data_generator(self, key: Union[str, int] = None, load_tmp: bool = True) -> DataPrep:
+    def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
         """
         Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
         remove nans.
         :param key: station key to choose the data generator.
+        :param local_tmp_storage: whether to load already processed data from a tmp pickle file (saves computation
+            time, but requires additional disk space) or to process the data from scratch.
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
         try:
-            if not load_tmp:
+            if not local_tmp_storage:
                 raise FileNotFoundError
             data = self._load_pickle_data(station, self.variables)
         except FileNotFoundError:
@@ -117,15 +119,26 @@ class DataGenerator(keras.utils.Sequence):
             self._save_pickle_data(data)
         return data
 
-    def _save_pickle_data(self, data):
+    def _save_pickle_data(self, data: Any):
+        """
+        Save given data locally as .pickle in self.data_path_tmp with name
+        '<station>_<var1>_<var2>_..._<varX>.pickle'.
+        :param data: any data that should be saved
+        """
         file = os.path.join(self.data_path_tmp, f"{''.join(data.station)}_{'_'.join(sorted(data.variables))}.pickle")
         with open(file, "wb") as f:
             pickle.dump(data, f)
         logging.debug(f"save pickle data to {file}")
 
-    def _load_pickle_data(self, station, variables):
+    def _load_pickle_data(self, station: Union[str, List[str]], variables: List[str]) -> Any:
+        """
+        Load locally saved data from self.data_path_tmp with name '<station>_<var1>_<var2>_..._<varX>.pickle'.
+        :param station: station to load
+        :param variables: list of variables to load
+        :return: loaded data
+        """
         file = os.path.join(self.data_path_tmp, f"{''.join(station)}_{'_'.join(sorted(variables))}.pickle")
-        data = pickle.load(open(file, "rb"))
+        with open(file, "rb") as f:
+            data = pickle.load(f)
         logging.debug(f"load pickle data from {file}")
         return data
 
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 2f8b2777..5dc61738 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -119,7 +119,7 @@ class PreProcessing(RunEnvironment):
             t_inner.run()
             try:
                 # (history, label) = data_gen[station]
-                data = data_gen.get_data_generator(key=station, load_tmp=load_tmp)
+                data = data_gen.get_data_generator(key=station, local_tmp_storage=load_tmp)
                 valid_stations.append(station)
                 logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
                 logging.debug(f"{station}: loading time = {t_inner}")
--
GitLab
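
Besides the rename to local_tmp_storage and the new docstrings, the one behavioral change in patch 3 is swapping pickle.load(open(file, "rb")) for a with block. A small self-contained demonstration of the difference; path and payload are placeholders:

    import os
    import pickle
    import tempfile

    file = os.path.join(tempfile.gettempdir(), "demo.pickle")
    with open(file, "wb") as f:
        pickle.dump({"o3": [1, 2, 3]}, f)

    data = pickle.load(open(file, "rb"))  # works, but the handle is only closed when the GC collects it

    with open(file, "rb") as f:           # closed deterministically on block exit, even if load raises
        data = pickle.load(f)

With many stations loaded in one preprocessing run, the leaked handles of the one-liner variant can accumulate towards the process's file-descriptor limit, so the context manager is the safer idiom.
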
From 1754830cc60e6c8cb5b1bf365f032e17baa359c6 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 6 Feb 2020 09:27:38 +0100
Subject: [PATCH 4/6] update on data prep tests

---
 src/data_handling/data_preparation.py |   2 +-
 .../test_data_preparation.py          | 106 ++++++++++++++++--
 2 files changed, 97 insertions(+), 11 deletions(-)

diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index d0d89438..c39625b1 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -108,7 +108,7 @@ class DataPrep(object):
         check_dict = {"station_type": self.station_type, "network_name": self.network}
         for (k, v) in check_dict.items():
             if self.meta.at[k, self.station[0]] != v:
-                logging.debug(f"meta data does not agree which given request for {k}: {v} (requested) != "
+                logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
                               f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
                               f"grabbing from web.")
                 raise FileNotFoundError
diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py
index 12b619d9..d67b8add 100644
--- a/test/test_data_handling/test_data_preparation.py
+++ b/test/test_data_handling/test_data_preparation.py
@@ -7,6 +7,8 @@ import xarray as xr
 import datetime as dt
 import pandas as pd
 from operator import itemgetter
+import logging
+from src.helpers import PyTestRegex
 
 
 class TestDataPrep:
@@ -17,6 +19,17 @@ class TestDataPrep:
                         station_type='background', test='testKWARGS',
                         statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
 
+    @pytest.fixture
+    def data_prep_no_init(self):
+        d = object.__new__(DataPrep)
+        d.path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
+        d.network = 'UBA'
+        d.station = ['DEBW107']
+        d.variables = ['o3', 'temp']
+        d.station_type = "background"
+        d.kwargs = None
+        return d
+
     def test_init(self, data):
         assert data.path == os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
         assert data.network == 'AIRBASE'
@@ -31,16 +44,79 @@ class TestDataPrep:
         with pytest.raises(NotImplementedError):
             DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'])
 
-    def test_repr(self):
-        d = object.__new__(DataPrep)
-        d.path = 'data/test'
-        d.network = 'dummy'
-        d.station = ['DEBW107']
-        d.variables = ['o3', 'temp']
-        d.station_type = "traffic"
-        d.kwargs = None
-        assert d.__repr__().rstrip() == "Dataprep(path='data/test', network='dummy', station=['DEBW107'], "\
-                                        "variables=['o3', 'temp'], station_type=traffic, **None)".rstrip()
+    def test_download_data(self, data_prep_no_init):
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        data_prep_no_init.download_data(file_name, meta_file)
+        assert isinstance(data_prep_no_init.data, xr.DataArray)
+
+    def test_download_data_from_join(self, data_prep_no_init):
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        xarr, meta = data_prep_no_init.download_data_from_join(file_name, meta_file)
+        assert isinstance(xarr, xr.DataArray)
+        assert isinstance(meta, pd.DataFrame)
+
+    def test_check_station_meta(self, caplog, data_prep_no_init):
+        caplog.set_level(logging.DEBUG)
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        data_prep_no_init.download_data(file_name, meta_file)
+        assert data_prep_no_init.check_station_meta() is None
+        data_prep_no_init.station_type = "traffic"
+        with pytest.raises(FileNotFoundError) as e:
+            data_prep_no_init.check_station_meta()
+        msg = "meta data does not agree with given request for station_type: traffic (requested) != background (local)"
+        assert caplog.record_tuples[-1][:-1] == ('root', 10)
+        assert msg in caplog.record_tuples[-1][-1]
+
+    def test_load_data_overwrite_local_data(self, data_prep_no_init):
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        file_path = data_prep_no_init._set_file_name()
+        meta_file_path = data_prep_no_init._set_meta_file_name()
+        os.remove(file_path)
+        os.remove(meta_file_path)
+        assert not os.path.exists(file_path)
+        assert not os.path.exists(meta_file_path)
+        data_prep_no_init.kwargs = {"overwrite_local_data": True}
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        assert os.path.exists(meta_file_path)
+        t = os.stat(file_path).st_ctime
+        tm = os.stat(meta_file_path).st_ctime
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        assert os.path.exists(meta_file_path)
+        assert os.stat(file_path).st_ctime > t
+        assert os.stat(meta_file_path).st_ctime > tm
+        assert isinstance(data_prep_no_init.data, xr.DataArray)
+        assert isinstance(data_prep_no_init.meta, pd.DataFrame)
+
+    def test_load_data_keep_local_data(self, data_prep_no_init):
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        data_prep_no_init.station_type = None
+        data_prep_no_init.kwargs = {}
+        file_path = data_prep_no_init._set_file_name()
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        t = os.stat(file_path).st_ctime
+        data_prep_no_init.load_data()
+        assert os.path.exists(data_prep_no_init._set_file_name())
+        assert os.stat(file_path).st_ctime == t
+        assert isinstance(data_prep_no_init.data, xr.DataArray)
+        assert isinstance(data_prep_no_init.meta, pd.DataFrame)
+
+    def test_repr(self, data_prep_no_init):
+        path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
+        assert data_prep_no_init.__repr__().rstrip() == f"Dataprep(path='{path}', network='UBA', " \
+                                                        f"station=['DEBW107'], variables=['o3', 'temp'], " \
+                                                        f"station_type=background, **None)".rstrip()
 
     def test_set_file_name_and_meta(self):
         d = object.__new__(DataPrep)
@@ -133,6 +209,16 @@ class TestDataPrep:
         with pytest.raises(NotImplementedError):
             data.inverse_transform()
 
+    def test_get_transformation_information(self, data):
+        assert (None, None, None) == data.get_transformation_information("o3")
+        mean_test = data.data.mean("datetime").sel(variables='o3').values
+        std_test = data.data.std("datetime").sel(variables='o3').values
+        data.transform('datetime')
+        mean, std, info = data.get_transformation_information("o3")
+        assert np.testing.assert_almost_equal(mean, mean_test) is None
+        assert np.testing.assert_almost_equal(std, std_test) is None
+        assert info == "standardise"
+
     def test_nan_remove_no_hist_or_label(self, data):
         assert data.history is None
         assert data.label is None
--
GitLab
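
The new data_prep_no_init fixture is the workhorse of these tests: object.__new__ allocates a DataPrep instance without running __init__, which would otherwise trigger the full data loading (including the web download the other tests exercise), and each test then sets only the attributes it needs. The trick in isolation, with a stand-in class that is not part of the project:

    class Expensive:
        def __init__(self):
            raise RuntimeError("side effect, e.g. a web download")

    obj = object.__new__(Expensive)  # allocate the instance without calling __init__
    obj.station = ["DEBW107"]        # hand-set only the attributes under test
    print(obj.station)

The price of the shortcut is that every test must remember to set kwargs and statistics_per_var itself, which is why those assignments repeat at the top of each new test.
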
From edc5e931c1cd62b5f4487f9a17799738fe8f6c19 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 6 Feb 2020 10:36:04 +0100
Subject: [PATCH 5/6] updated generator tests

---
 .../test_data_handling/test_data_generator.py | 62 ++++++++++++++++---
 1 file changed, 53 insertions(+), 9 deletions(-)

diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py
index 879436af..34cc60d7 100644
--- a/test/test_data_handling/test_data_generator.py
+++ b/test/test_data_handling/test_data_generator.py
@@ -1,7 +1,10 @@
 import pytest
 import os
 import shutil
+import numpy as np
+import pickle
 from src.data_handling.data_generator import DataGenerator
+from src.data_handling.data_preparation import DataPrep
 
 
 class TestDataGenerator:
@@ -17,6 +20,12 @@ class TestDataGenerator:
         return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                              'datetime', 'variables', 'o3')
 
+    class DummyDataPrep:
+        def __init__(self, data):
+            self.station = "DEBW107"
+            self.variables = ["o3", "temp"]
+            self.data = data
+
     def test_init(self, gen):
         assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data')
         assert gen.network == 'AIRBASE'
@@ -44,15 +53,6 @@ class TestDataGenerator:
         gen.stations = ['station1', 'station2', 'station3']
         assert len(gen) == 3
 
-    def test_getitem(self, gen):
-        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
-        station = gen["DEBW107"]
-        assert len(station) == 2
-        assert station[0].Stations.data == "DEBW107"
-        assert station[0].data.shape[1:] == (8, 1, 2)
-        assert station[1].data.shape[-1] == gen.window_lead_time
-        assert station[0].data.shape[1] == gen.window_history_size + 1
-
     def test_iter(self, gen):
         assert hasattr(gen, '_iterator') is False
         iter(gen)
@@ -64,6 +64,15 @@
         for i, d in enumerate(gen, start=1):
             assert i == gen._iterator
 
+    def test_getitem(self, gen):
+        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
+        station = gen["DEBW107"]
+        assert len(station) == 2
+        assert station[0].Stations.data == "DEBW107"
+        assert station[0].data.shape[1:] == (8, 1, 2)
+        assert station[1].data.shape[-1] == gen.window_lead_time
+        assert station[0].data.shape[1] == gen.window_history_size + 1
+
     def test_get_station_key(self, gen):
         gen.stations.append("DEBW108")
         f = gen.get_station_key
@@ -85,3 +94,38 @@ class TestDataGenerator:
         with pytest.raises(KeyError) as e:
             f(6.5)
         assert "key has to be from Union[str, int]. Given was 6.5 (float)"
+
+    def test_get_data_generator(self, gen):
+        gen.kwargs = {"statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}}
+        file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle")
+        if os.path.exists(file):
+            os.remove(file)
+        assert not os.path.exists(file)
+        assert isinstance(gen.get_data_generator("DEBW107", local_tmp_storage=False), DataPrep)
+        t = os.stat(file).st_ctime
+        assert os.path.exists(file)
+        assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
+        assert os.stat(file).st_mtime == t
+        os.remove(file)
+        assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
+        assert os.stat(file).st_ctime > t
+
+    def test_save_pickle_data(self, gen):
+        file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle")
+        if os.path.exists(file):
+            os.remove(file)
+        assert not os.path.exists(file)
+        data = self.DummyDataPrep(np.ones((10, 2)))
+        gen._save_pickle_data(data)
+        assert os.path.exists(file)
+        os.remove(file)
+
+    def test_load_pickle_data(self, gen):
+        file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle")
+        data = self.DummyDataPrep(np.ones((10, 2)))
+        with open(file, "wb") as f:
+            pickle.dump(data, f)
+        assert os.path.exists(file)
+        res = gen._load_pickle_data("DEBW107", ["o3", "temp"]).data
+        assert np.testing.assert_almost_equal(res, np.ones((10, 2))) is None
+        os.remove(file)
--
GitLab
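
Two idioms in the new generator tests are worth noting. The stat-based checks in test_get_data_generator read the cache behaviour off the filesystem: an unchanged st_mtime after the second call shows the pickle was reused, while a larger st_ctime after deleting the file shows it was recreated. And the numpy comparison helper is folded into pytest's assert style, because np.testing.assert_almost_equal raises on mismatch and returns None on success:

    import numpy as np

    res = np.ones((10, 2))
    # raises AssertionError if the arrays differ, returns None if they match,
    # so "is None" turns the helper into an ordinary assert statement
    assert np.testing.assert_almost_equal(res, np.ones((10, 2))) is None
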
From 4c0c54dbcddfcdb51e6ed9915a3d04129ef67189 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 6 Feb 2020 11:13:18 +0100
Subject: [PATCH 6/6] more tests

---
 src/data_handling/data_distributor.py |  2 +-
 src/data_handling/data_generator.py   |  4 ++--
 .../test_keras_extensions.py          | 18 ++++++++++++++++++
 3 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index 74df5f6a..c6f38a6f 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -45,7 +45,7 @@ class Distributor(keras.utils.Sequence):
                 for prev, curr in enumerate(range(1, num_mini_batches+1)):
                     x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
                     y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
-                    if x is not None:
+                    if x is not None:  # pragma: no branch
                         yield (x, y)
                         if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call:
                             return
diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 26b12d59..732a7efd 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -75,11 +75,11 @@ class DataGenerator(keras.utils.Sequence):
         if self._iterator < self.__len__():
             data = self.get_data_generator()
             self._iterator += 1
-            if data.history is not None and data.label is not None:
+            if data.history is not None and data.label is not None:  # pragma: no branch
                 return data.history.transpose("datetime", "window", "Stations", "variables"), \
                        data.label.squeeze("Stations").transpose("datetime", "window")
             else:
-                self.__next__()
+                self.__next__()  # pragma: no cover
         else:
             raise StopIteration
diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py
index c50e5e42..7c32844d 100644
--- a/test/test_model_modules/test_keras_extensions.py
+++ b/test/test_model_modules/test_keras_extensions.py
@@ -5,6 +5,24 @@ import keras
 import numpy as np
 
 
+class TestHistoryAdvanced:
+
+    def test_init(self):
+        hist = HistoryAdvanced()
+        assert hist.validation_data is None
+        assert hist.model is None
+        assert isinstance(hist.epoch, list) and len(hist.epoch) == 0
+        assert isinstance(hist.history, dict) and len(hist.history.keys()) == 0
+
+    def test_on_train_begin(self):
+        hist = HistoryAdvanced()
+        hist.epoch = [1, 2, 3]
+        hist.history = {"mse": [10, 7, 4]}
+        hist.on_train_begin()
+        assert hist.epoch == [1, 2, 3]
+        assert hist.history == {"mse": [10, 7, 4]}
+
+
 class TestLearningRateDecay:
 
     def test_init(self):
--
GitLab
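
The TestHistoryAdvanced cases pin down the callback's defining behaviour: unlike the stock keras History callback, on_train_begin must not wipe a pre-seeded epoch list and history dict, which is what a training-continuation feature needs. The # pragma comments are coverage.py directives, not Python syntax: "no cover" excludes a line (and its block) from the report entirely, while "no branch" marks a conditional as expectedly partial, so branch coverage (e.g. pytest --cov --cov-branch) does not flag the never-taken path. Both markers are recognised by coverage.py's default configuration; a compact illustration with placeholder logic:

    def first_positive(values):
        for v in values:
            if v > 0:  # pragma: no branch (the False path is never taken in the covered tests)
                return v
        raise ValueError("no positive value")  # pragma: no cover (defensive, excluded from the report)

    print(first_positive([3, -1]))
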