From d65b8b783b5a51dadbadca22a5f4684815841ef6 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 5 Feb 2020 11:22:37 +0100
Subject: [PATCH 1/6] change map resolution to accelerate map plot

---
 src/plotting/postprocessing_plotting.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index cd49ddd5..97d326bc 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -141,11 +141,11 @@ class PlotStationMap(RunEnvironment):
         """
         Draw coastline, lakes, ocean, rivers and country borders as background on the map.
         """
-        self._ax.add_feature(cfeature.COASTLINE.with_scale("10m"), edgecolor='black')
+        self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
         self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
         self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
-        self._ax.add_feature(cfeature.RIVERS.with_scale("10m"))
-        self._ax.add_feature(cfeature.BORDERS.with_scale("10m"), facecolor='none', edgecolor='black')
+        self._ax.add_feature(cfeature.RIVERS.with_scale("50m"))
+        self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')

     def _plot_stations(self, generators):
         """
--
GitLab
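
Patch 1 is purely a resolution swap: cartopy ships Natural Earth features in "10m", "50m" and "110m" scales, and the coarser "50m" shapefiles are smaller to download and faster to render, which is where the acceleration comes from (LAKES and OCEAN were already at "50m"; only coastline, rivers and borders change). A minimal standalone sketch of the same background-drawing pattern, assuming only that cartopy and matplotlib are installed; the figure and output file name are placeholders, not part of the patch:

    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    import matplotlib.pyplot as plt

    scale = "50m"  # one Natural Earth scale for all features; the compromise this patch settles on
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.add_feature(cfeature.COASTLINE.with_scale(scale), edgecolor="black")
    ax.add_feature(cfeature.LAKES.with_scale(scale))
    ax.add_feature(cfeature.OCEAN.with_scale(scale))
    ax.add_feature(cfeature.RIVERS.with_scale(scale))
    ax.add_feature(cfeature.BORDERS.with_scale(scale), facecolor="none", edgecolor="black")
    plt.savefig("station_map.png")
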
From a6dccb6d2c9899f215e20b397d5169089654c8e6 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 5 Feb 2020 11:23:40 +0100
Subject: [PATCH 2/6] first implementation of local tmp storage using pickle

---
 src/data_handling/data_generator.py | 40 +++++++++++++++++++++++------
 src/run_modules/pre_processing.py   |  9 ++++---
 2 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 1de0ab20..f259c403 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -7,6 +7,8 @@ from src.data_handling.data_preparation import DataPrep
 import os
 from typing import Union, List, Tuple
 import xarray as xr
+import pickle
+import logging
 
 
 class DataGenerator(keras.utils.Sequence):
@@ -23,6 +25,9 @@ class DataGenerator(keras.utils.Sequence):
                  interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
                  window_lead_time: int = 4, transform_method: str = "standardise", **kwargs):
         self.data_path = os.path.abspath(data_path)
+        self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
+        if not os.path.exists(self.data_path_tmp):
+            os.makedirs(self.data_path_tmp)
         self.network = network
         self.stations = helpers.to_list(stations)
         self.variables = variables
@@ -88,7 +93,7 @@ class DataGenerator(keras.utils.Sequence):
         return data.history.transpose("datetime", "window", "Stations", "variables"), \
                data.label.squeeze("Stations").transpose("datetime", "window")
 
-    def get_data_generator(self, key: Union[str, int] = None) -> DataPrep:
+    def get_data_generator(self, key: Union[str, int] = None, load_tmp: bool = True) -> DataPrep:
         """
         Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
         remove nans.
@@ -96,13 +101,32 @@ class DataGenerator(keras.utils.Sequence):
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
-        data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
-                        **self.kwargs)
-        data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
-        data.transform("datetime", method=self.transform_method)
-        data.make_history_window(self.interpolate_dim, self.window_history_size)
-        data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
-        data.history_label_nan_remove(self.interpolate_dim)
+        try:
+            if not load_tmp:
+                raise FileNotFoundError
+            data = self._load_pickle_data(station, self.variables)
+        except FileNotFoundError:
+            logging.info(f"could not load pickle data for {station}, processing from scratch")
+            data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                            **self.kwargs)
+            data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
+            data.transform("datetime", method=self.transform_method)
+            data.make_history_window(self.interpolate_dim, self.window_history_size)
+            data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
+            data.history_label_nan_remove(self.interpolate_dim)
+            self._save_pickle_data(data)
         return data
+
+    def _save_pickle_data(self, data):
+        file = os.path.join(self.data_path_tmp, f"{''.join(data.station)}_{'_'.join(sorted(data.variables))}.pickle")
+        with open(file, "wb") as f:
+            pickle.dump(data, f)
+        logging.debug(f"save pickle data to {file}")
+
+    def _load_pickle_data(self, station, variables):
+        file = os.path.join(self.data_path_tmp, f"{''.join(station)}_{'_'.join(sorted(variables))}.pickle")
+        data = pickle.load(open(file, "rb"))
+        logging.debug(f"load pickle data from {file}")
+        return data
 
     def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str:
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 2a4632d5..2f8b2777 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -36,7 +36,7 @@ class PreProcessing(RunEnvironment):
     def _run(self):
         args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
         kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
-        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"))
+        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False)
         self.data_store.set("stations", valid_stations, "general")
         self.split_train_val_test()
         self.report_pre_processing()
@@ -97,7 +97,7 @@ class PreProcessing(RunEnvironment):
         self.data_store.set("generator", data_set, scope)
 
     @staticmethod
-    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]):
+    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True):
         """
         Check if all given stations in `all_stations` are valid. Valid means that there is data available for the
         given time range (is included in `kwargs`). The shape and the loading time are logged in debug mode.
@@ -118,9 +118,10 @@ class PreProcessing(RunEnvironment): for station in all_stations: t_inner.run() try: - (history, label) = data_gen[station] + # (history, label) = data_gen[station] + data = data_gen.get_data_generator(key=station, load_tmp=load_tmp) valid_stations.append(station) - logging.debug(f"{station}: history_shape = {history.shape}") + logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}') logging.debug(f"{station}: loading time = {t_inner}") except (AttributeError, EmptyQueryResult): continue -- GitLab From af1ecb8a390e48f0a479204b5e1fdcd4e87ace95 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Wed, 5 Feb 2020 11:43:34 +0100 Subject: [PATCH 3/6] minor modifications, add docs --- src/data_handling/data_generator.py | 25 +++++++++++++++++++------ src/run_modules/pre_processing.py | 2 +- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index f259c403..26b12d59 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -5,7 +5,7 @@ import keras from src import helpers from src.data_handling.data_preparation import DataPrep import os -from typing import Union, List, Tuple +from typing import Union, List, Tuple, Any import xarray as xr import pickle import logging @@ -93,16 +93,18 @@ class DataGenerator(keras.utils.Sequence): return data.history.transpose("datetime", "window", "Stations", "variables"), \ data.label.squeeze("Stations").transpose("datetime", "window") - def get_data_generator(self, key: Union[str, int] = None, load_tmp: bool = True) -> DataPrep: + def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep: """ Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and remove nans. :param key: station key to choose the data generator. + :param local_tmp_storage: say if data should be processed from scratch or loaded as already processed data from + tmp pickle file to save computational time (but of course more disk space required). :return: preprocessed data as a DataPrep instance """ station = self.get_station_key(key) try: - if not load_tmp: + if not local_tmp_storage: raise FileNotFoundError data = self._load_pickle_data(station, self.variables) except FileNotFoundError: @@ -117,15 +119,26 @@ class DataGenerator(keras.utils.Sequence): self._save_pickle_data(data) return data - def _save_pickle_data(self, data): + def _save_pickle_data(self, data: Any): + """ + Save given data locally as .pickle in self.data_path_tmp with name '<station>_<var1>_<var2>_..._<varX>.pickle' + :param data: any data, that should be saved + """ file = os.path.join(self.data_path_tmp, f"{''.join(data.station)}_{'_'.join(sorted(data.variables))}.pickle") with open(file, "wb") as f: pickle.dump(data, f) logging.debug(f"save pickle data to {file}") - def _load_pickle_data(self, station, variables): + def _load_pickle_data(self, station: Union[str, List[str]], variables: List[str]) -> Any: + """ + Load locally saved data from self.data_path_tmp and name '<station>_<var1>_<var2>_..._<varX>.pickle'. 
From af1ecb8a390e48f0a479204b5e1fdcd4e87ace95 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 5 Feb 2020 11:43:34 +0100
Subject: [PATCH 3/6] minor modifications, add docs

---
 src/data_handling/data_generator.py | 25 +++++++++++++++++++------
 src/run_modules/pre_processing.py   |  2 +-
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index f259c403..26b12d59 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -5,7 +5,7 @@ import keras
 from src import helpers
 from src.data_handling.data_preparation import DataPrep
 import os
-from typing import Union, List, Tuple
+from typing import Union, List, Tuple, Any
 import xarray as xr
 import pickle
 import logging
@@ -93,16 +93,18 @@ class DataGenerator(keras.utils.Sequence):
         return data.history.transpose("datetime", "window", "Stations", "variables"), \
                data.label.squeeze("Stations").transpose("datetime", "window")
 
-    def get_data_generator(self, key: Union[str, int] = None, load_tmp: bool = True) -> DataPrep:
+    def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
         """
         Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
         remove nans.
         :param key: station key to choose the data generator.
+        :param local_tmp_storage: whether to load already processed data from a tmp pickle file (saves computation
+            time, but requires additional disk space) or to process the data from scratch.
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
         try:
-            if not load_tmp:
+            if not local_tmp_storage:
                 raise FileNotFoundError
             data = self._load_pickle_data(station, self.variables)
         except FileNotFoundError:
@@ -117,15 +119,26 @@ class DataGenerator(keras.utils.Sequence):
             self._save_pickle_data(data)
         return data
 
-    def _save_pickle_data(self, data):
+    def _save_pickle_data(self, data: Any):
+        """
+        Save given data locally as .pickle in self.data_path_tmp with name
+        '<station>_<var1>_<var2>_..._<varX>.pickle'.
+        :param data: any data that should be saved
+        """
         file = os.path.join(self.data_path_tmp, f"{''.join(data.station)}_{'_'.join(sorted(data.variables))}.pickle")
         with open(file, "wb") as f:
             pickle.dump(data, f)
         logging.debug(f"save pickle data to {file}")
 
-    def _load_pickle_data(self, station, variables):
+    def _load_pickle_data(self, station: Union[str, List[str]], variables: List[str]) -> Any:
+        """
+        Load locally saved data from self.data_path_tmp with name '<station>_<var1>_<var2>_..._<varX>.pickle'.
+        :param station: station to load
+        :param variables: list of variables to load
+        :return: loaded data
+        """
         file = os.path.join(self.data_path_tmp, f"{''.join(station)}_{'_'.join(sorted(variables))}.pickle")
-        data = pickle.load(open(file, "rb"))
+        with open(file, "rb") as f:
+            data = pickle.load(f)
         logging.debug(f"load pickle data from {file}")
         return data
 
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 2f8b2777..5dc61738 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -119,7 +119,7 @@ class PreProcessing(RunEnvironment):
             t_inner.run()
             try:
                 # (history, label) = data_gen[station]
-                data = data_gen.get_data_generator(key=station, load_tmp=load_tmp)
+                data = data_gen.get_data_generator(key=station, local_tmp_storage=load_tmp)
                 valid_stations.append(station)
                 logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
                 logging.debug(f"{station}: loading time = {t_inner}")
--
GitLab
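
Besides the rename to local_tmp_storage and the new docstrings, the one behavioral change in patch 3 is swapping pickle.load(open(file, "rb")) for a with block. A small self-contained demonstration of the difference; path and payload are placeholders:

    import os
    import pickle
    import tempfile

    file = os.path.join(tempfile.gettempdir(), "demo.pickle")
    with open(file, "wb") as f:
        pickle.dump({"o3": [1, 2, 3]}, f)

    data = pickle.load(open(file, "rb"))  # works, but the handle is only closed when the GC collects it

    with open(file, "rb") as f:           # closed deterministically on block exit, even if load raises
        data = pickle.load(f)

With many stations loaded in one preprocessing run, the leaked handles of the one-liner variant can accumulate towards the process's file-descriptor limit, so the context manager is the safer idiom.
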
From 1754830cc60e6c8cb5b1bf365f032e17baa359c6 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 6 Feb 2020 09:27:38 +0100
Subject: [PATCH 4/6] update on data prep tests

---
 src/data_handling/data_preparation.py |   2 +-
 .../test_data_preparation.py          | 106 ++++++++++++++++--
 2 files changed, 97 insertions(+), 11 deletions(-)

diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index d0d89438..c39625b1 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -108,7 +108,7 @@ class DataPrep(object):
         check_dict = {"station_type": self.station_type, "network_name": self.network}
         for (k, v) in check_dict.items():
             if self.meta.at[k, self.station[0]] != v:
-                logging.debug(f"meta data does not agree which given request for {k}: {v} (requested) != "
+                logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
                               f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
                               f"grabbing from web.")
                 raise FileNotFoundError
diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py
index 12b619d9..d67b8add 100644
--- a/test/test_data_handling/test_data_preparation.py
+++ b/test/test_data_handling/test_data_preparation.py
@@ -7,6 +7,8 @@ import xarray as xr
 import datetime as dt
 import pandas as pd
 from operator import itemgetter
+import logging
+from src.helpers import PyTestRegex
 
 
 class TestDataPrep:
@@ -17,6 +19,17 @@ class TestDataPrep:
                         station_type='background', test='testKWARGS',
                         statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
 
+    @pytest.fixture
+    def data_prep_no_init(self):
+        d = object.__new__(DataPrep)
+        d.path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
+        d.network = 'UBA'
+        d.station = ['DEBW107']
+        d.variables = ['o3', 'temp']
+        d.station_type = "background"
+        d.kwargs = None
+        return d
+
     def test_init(self, data):
         assert data.path == os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
         assert data.network == 'AIRBASE'
@@ -31,16 +44,79 @@ class TestDataPrep:
         with pytest.raises(NotImplementedError):
             DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'])
 
-    def test_repr(self):
-        d = object.__new__(DataPrep)
-        d.path = 'data/test'
-        d.network = 'dummy'
-        d.station = ['DEBW107']
-        d.variables = ['o3', 'temp']
-        d.station_type = "traffic"
-        d.kwargs = None
-        assert d.__repr__().rstrip() == "Dataprep(path='data/test', network='dummy', station=['DEBW107'], "\
-                                        "variables=['o3', 'temp'], station_type=traffic, **None)".rstrip()
+    def test_download_data(self, data_prep_no_init):
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        data_prep_no_init.download_data(file_name, meta_file)
+        assert isinstance(data_prep_no_init.data, xr.DataArray)
+
+    def test_download_data_from_join(self, data_prep_no_init):
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        xarr, meta = data_prep_no_init.download_data_from_join(file_name, meta_file)
+        assert isinstance(xarr, xr.DataArray)
+        assert isinstance(meta, pd.DataFrame)
+
+    def test_check_station_meta(self, caplog, data_prep_no_init):
+        caplog.set_level(logging.DEBUG)
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        data_prep_no_init.download_data(file_name, meta_file)
+        assert data_prep_no_init.check_station_meta() is None
+        data_prep_no_init.station_type = "traffic"
+        with pytest.raises(FileNotFoundError) as e:
+            data_prep_no_init.check_station_meta()
+        msg = "meta data does not agree with given request for station_type: traffic (requested) != background (local)"
+        assert caplog.record_tuples[-1][:-1] == ('root', 10)
+        assert msg in caplog.record_tuples[-1][-1]
+
+    def test_load_data_overwrite_local_data(self, data_prep_no_init):
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        file_path = data_prep_no_init._set_file_name()
+        meta_file_path = data_prep_no_init._set_meta_file_name()
+        os.remove(file_path)
+        os.remove(meta_file_path)
+        assert not os.path.exists(file_path)
+        assert not os.path.exists(meta_file_path)
+        data_prep_no_init.kwargs = {"overwrite_local_data": True}
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        assert os.path.exists(meta_file_path)
+        t = os.stat(file_path).st_ctime
+        tm = os.stat(meta_file_path).st_ctime
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        assert os.path.exists(meta_file_path)
+        assert os.stat(file_path).st_ctime > t
+        assert os.stat(meta_file_path).st_ctime > tm
+        assert isinstance(data_prep_no_init.data, xr.DataArray)
+        assert isinstance(data_prep_no_init.meta, pd.DataFrame)
+
+    def test_load_data_keep_local_data(self, data_prep_no_init):
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        data_prep_no_init.station_type = None
+        data_prep_no_init.kwargs = {}
+        file_path = data_prep_no_init._set_file_name()
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        t = os.stat(file_path).st_ctime
+        data_prep_no_init.load_data()
+        assert os.path.exists(data_prep_no_init._set_file_name())
+        assert os.stat(file_path).st_ctime == t
+        assert isinstance(data_prep_no_init.data, xr.DataArray)
+        assert isinstance(data_prep_no_init.meta, pd.DataFrame)
+
+    def test_repr(self, data_prep_no_init):
+        path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
+        assert data_prep_no_init.__repr__().rstrip() == f"Dataprep(path='{path}', network='UBA', " \
+                                                        f"station=['DEBW107'], variables=['o3', 'temp'], " \
+                                                        f"station_type=background, **None)".rstrip()
 
     def test_set_file_name_and_meta(self):
         d = object.__new__(DataPrep)
@@ -133,6 +209,16 @@ class TestDataPrep:
         with pytest.raises(NotImplementedError):
             data.inverse_transform()
 
+    def test_get_transformation_information(self, data):
+        assert (None, None, None) == data.get_transformation_information("o3")
+        mean_test = data.data.mean("datetime").sel(variables='o3').values
+        std_test = data.data.std("datetime").sel(variables='o3').values
+        data.transform('datetime')
+        mean, std, info = data.get_transformation_information("o3")
+        assert np.testing.assert_almost_equal(mean, mean_test) is None
+        assert np.testing.assert_almost_equal(std, std_test) is None
+        assert info == "standardise"
+
     def test_nan_remove_no_hist_or_label(self, data):
         assert data.history is None
         assert data.label is None
--
GitLab
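
The new data_prep_no_init fixture is the workhorse of these tests: object.__new__ allocates a DataPrep instance without running __init__, which would otherwise trigger the full data loading (including the web download the other tests exercise), and each test then sets only the attributes it needs. The trick in isolation, with a stand-in class that is not part of the project:

    class Expensive:
        def __init__(self):
            raise RuntimeError("side effect, e.g. a web download")

    obj = object.__new__(Expensive)  # allocate the instance without calling __init__
    obj.station = ["DEBW107"]        # hand-set only the attributes under test
    print(obj.station)

The price of the shortcut is that every test must remember to set kwargs and statistics_per_var itself, which is why those assignments repeat at the top of each new test.
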
From edc5e931c1cd62b5f4487f9a17799738fe8f6c19 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 6 Feb 2020 10:36:04 +0100
Subject: [PATCH 5/6] updated generator tests

---
 .../test_data_handling/test_data_generator.py | 62 ++++++++++++++++---
 1 file changed, 53 insertions(+), 9 deletions(-)

diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py
index 879436af..34cc60d7 100644
--- a/test/test_data_handling/test_data_generator.py
+++ b/test/test_data_handling/test_data_generator.py
@@ -1,7 +1,10 @@
 import pytest
 import os
 import shutil
+import numpy as np
+import pickle
 from src.data_handling.data_generator import DataGenerator
+from src.data_handling.data_preparation import DataPrep
 
 
 class TestDataGenerator:
@@ -17,6 +20,12 @@ class TestDataGenerator:
         return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                              'datetime', 'variables', 'o3')
 
+    class DummyDataPrep:
+        def __init__(self, data):
+            self.station = "DEBW107"
+            self.variables = ["o3", "temp"]
+            self.data = data
+
     def test_init(self, gen):
         assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data')
         assert gen.network == 'AIRBASE'
@@ -44,15 +53,6 @@ class TestDataGenerator:
         gen.stations = ['station1', 'station2', 'station3']
         assert len(gen) == 3
 
-    def test_getitem(self, gen):
-        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
-        station = gen["DEBW107"]
-        assert len(station) == 2
-        assert station[0].Stations.data == "DEBW107"
-        assert station[0].data.shape[1:] == (8, 1, 2)
-        assert station[1].data.shape[-1] == gen.window_lead_time
-        assert station[0].data.shape[1] == gen.window_history_size + 1
-
     def test_iter(self, gen):
         assert hasattr(gen, '_iterator') is False
         iter(gen)
@@ -64,6 +64,15 @@
         for i, d in enumerate(gen, start=1):
             assert i == gen._iterator
 
+    def test_getitem(self, gen):
+        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
+        station = gen["DEBW107"]
+        assert len(station) == 2
+        assert station[0].Stations.data == "DEBW107"
+        assert station[0].data.shape[1:] == (8, 1, 2)
+        assert station[1].data.shape[-1] == gen.window_lead_time
+        assert station[0].data.shape[1] == gen.window_history_size + 1
+
     def test_get_station_key(self, gen):
         gen.stations.append("DEBW108")
         f = gen.get_station_key
@@ -85,3 +94,38 @@ class TestDataGenerator:
         with pytest.raises(KeyError) as e:
             f(6.5)
         assert "key has to be from Union[str, int]. Given was 6.5 (float)"
+
+    def test_get_data_generator(self, gen):
+        gen.kwargs = {"statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}}
+        file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle")
+        if os.path.exists(file):
+            os.remove(file)
+        assert not os.path.exists(file)
+        assert isinstance(gen.get_data_generator("DEBW107", local_tmp_storage=False), DataPrep)
+        t = os.stat(file).st_ctime
+        assert os.path.exists(file)
+        assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
+        assert os.stat(file).st_mtime == t
+        os.remove(file)
+        assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
+        assert os.stat(file).st_ctime > t
+
+    def test_save_pickle_data(self, gen):
+        file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle")
+        if os.path.exists(file):
+            os.remove(file)
+        assert not os.path.exists(file)
+        data = self.DummyDataPrep(np.ones((10, 2)))
+        gen._save_pickle_data(data)
+        assert os.path.exists(file)
+        os.remove(file)
+
+    def test_load_pickle_data(self, gen):
+        file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle")
+        data = self.DummyDataPrep(np.ones((10, 2)))
+        with open(file, "wb") as f:
+            pickle.dump(data, f)
+        assert os.path.exists(file)
+        res = gen._load_pickle_data("DEBW107", ["o3", "temp"]).data
+        assert np.testing.assert_almost_equal(res, np.ones((10, 2))) is None
+        os.remove(file)
--
GitLab
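
Two idioms in the new generator tests are worth noting. The stat-based checks in test_get_data_generator read the cache behaviour off the filesystem: an unchanged st_mtime after the second call shows the pickle was reused, while a larger st_ctime after deleting the file shows it was recreated. And the numpy comparison helper is folded into pytest's assert style, because np.testing.assert_almost_equal raises on mismatch and returns None on success:

    import numpy as np

    res = np.ones((10, 2))
    # raises AssertionError if the arrays differ, returns None if they match,
    # so "is None" turns the helper into an ordinary assert statement
    assert np.testing.assert_almost_equal(res, np.ones((10, 2))) is None
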
From 4c0c54dbcddfcdb51e6ed9915a3d04129ef67189 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 6 Feb 2020 11:13:18 +0100
Subject: [PATCH 6/6] more tests

---
 src/data_handling/data_distributor.py |  2 +-
 src/data_handling/data_generator.py   |  4 ++--
 .../test_keras_extensions.py          | 18 ++++++++++++++++++
 3 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index 74df5f6a..c6f38a6f 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -45,7 +45,7 @@ class Distributor(keras.utils.Sequence):
                 for prev, curr in enumerate(range(1, num_mini_batches+1)):
                     x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
                     y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
-                    if x is not None:
+                    if x is not None:  # pragma: no branch
                         yield (x, y)
                         if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call:
                             return
diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 26b12d59..732a7efd 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -75,11 +75,11 @@ class DataGenerator(keras.utils.Sequence):
         if self._iterator < self.__len__():
             data = self.get_data_generator()
             self._iterator += 1
-            if data.history is not None and data.label is not None:
+            if data.history is not None and data.label is not None:  # pragma: no branch
                 return data.history.transpose("datetime", "window", "Stations", "variables"), \
                        data.label.squeeze("Stations").transpose("datetime", "window")
             else:
-                self.__next__()
+                self.__next__()  # pragma: no cover
         else:
             raise StopIteration
diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py
index c50e5e42..7c32844d 100644
--- a/test/test_model_modules/test_keras_extensions.py
+++ b/test/test_model_modules/test_keras_extensions.py
@@ -5,6 +5,24 @@ import keras
 import numpy as np
 
 
+class TestHistoryAdvanced:
+
+    def test_init(self):
+        hist = HistoryAdvanced()
+        assert hist.validation_data is None
+        assert hist.model is None
+        assert isinstance(hist.epoch, list) and len(hist.epoch) == 0
+        assert isinstance(hist.history, dict) and len(hist.history.keys()) == 0
+
+    def test_on_train_begin(self):
+        hist = HistoryAdvanced()
+        hist.epoch = [1, 2, 3]
+        hist.history = {"mse": [10, 7, 4]}
+        hist.on_train_begin()
+        assert hist.epoch == [1, 2, 3]
+        assert hist.history == {"mse": [10, 7, 4]}
+
+
 class TestLearningRateDecay:
 
     def test_init(self):
--
GitLab
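
The TestHistoryAdvanced cases pin down the callback's defining behaviour: unlike the stock keras History callback, on_train_begin must not wipe a pre-seeded epoch list and history dict, which is what a training-continuation feature needs. The # pragma comments are coverage.py directives, not Python syntax: "no cover" excludes a line (and its block) from the report entirely, while "no branch" marks a conditional as expectedly partial, so branch coverage (e.g. pytest --cov --cov-branch) does not flag the never-taken path. Both markers are recognised by coverage.py's default configuration; a compact illustration with placeholder logic:

    def first_positive(values):
        for v in values:
            if v > 0:  # pragma: no branch (the False path is never taken in the covered tests)
                return v
        raise ValueError("no positive value")  # pragma: no cover (defensive, excluded from the report)

    print(first_positive([3, -1]))
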