From 49b35cb56d51252d329d914faeb8c0fed7b4d352 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Fri, 30 Oct 2020 10:02:05 +0100
Subject: [PATCH] apply various developments from #79

* corrected the own_args method (now also collects keyword-only arguments)
* corrected the workflow registry (repeated stages are no longer overwritten)
* renamed all data_preparation classes to data handlers
* data handlers now use the class-level data_handler attribute
* added error handling in post processing

---
 mlair/__init__.py                             |  6 +-
 mlair/configuration/.gitignore                |  3 +-
 mlair/configuration/defaults.py               |  6 +-
 mlair/data_handler/abstract_data_handler.py   |  5 +-
 .../data_preparation_neighbors.py             | 11 +--
 mlair/data_handler/default_data_handler.py    | 34 ++++---
 mlair/data_handler/station_preparation.py     | 25 +++---
 mlair/helpers/join.py                         |  1 +
 mlair/model_modules/model_class.py            | 62 ++++++++++++-
 mlair/plotting/postprocessing_plotting.py     |  7 +-
 mlair/run_modules/post_processing.py          |  2 +-
 mlair/run_modules/pre_processing.py           | 18 ++--
 mlair/workflows/abstract_workflow.py          | 10 ++-
 mlair/workflows/default_workflow.py           | 90 ++++++++++---------
 14 files changed, 174 insertions(+), 106 deletions(-)

diff --git a/mlair/__init__.py b/mlair/__init__.py
index 7097b1f3..5d6f2b67 100644
--- a/mlair/__init__.py
+++ b/mlair/__init__.py
@@ -1,7 +1,7 @@
 __version_info__ = {
-    'major': 1,
-    'minor': 0,
-    'micro': 0,
+    'major': 0,
+    'minor': 12,
+    'micro': 2,
 }
 
 from mlair.run_modules import RunEnvironment, ExperimentSetup, PreProcessing, ModelSetup, Training, PostProcessing
diff --git a/mlair/configuration/.gitignore b/mlair/configuration/.gitignore
index 8e2358dc..91eccc69 100644
--- a/mlair/configuration/.gitignore
+++ b/mlair/configuration/.gitignore
@@ -1 +1,2 @@
-join_settings.py
\ No newline at end of file
+join_settings.py
+join_rest
\ No newline at end of file
diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index d191af2e..85cf4334 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -46,15 +46,11 @@ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS = True
 DEFAULT_EVALUATE_BOOTSTRAPS = True
 DEFAULT_CREATE_NEW_BOOTSTRAPS = False
 DEFAULT_NUMBER_OF_BOOTSTRAPS = 20
-#DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
-#                     "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles",
-#                     "PlotAvailability"]
-DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", 
+DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
                      "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles",
                      "PlotAvailability"]
 
 
-
 def get_defaults():
     """Return all default parameters set in defaults.py"""
     return {key: value for key, value in globals().items() if key.startswith('DEFAULT')}
diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py
index 04b3d465..26ccf69c 100644
--- a/mlair/data_handler/abstract_data_handler.py
+++ b/mlair/data_handler/abstract_data_handler.py
@@ -27,7 +27,10 @@ class AbstractDataHandler:
 
     @classmethod
     def own_args(cls, *args):
-        return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args))
+        """Return all arguments (including kwonlyargs)."""
+        arg_spec = inspect.getfullargspec(cls)
+        list_of_args = arg_spec.args + arg_spec.kwonlyargs
+        return remove_items(list_of_args, ["self"] + list(args))
 
     @classmethod
     def transformation(cls, *args, **kwargs):
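
The corrected `own_args` matters as soon as a handler defines keyword-only arguments: `inspect.getfullargspec` lists those separately under `kwonlyargs`, so the old implementation silently dropped them. A minimal sketch (the `Demo` class is hypothetical):

```python
import inspect

class Demo:
    def __init__(self, station, window=13, *, extreme_values=None):
        pass

spec = inspect.getfullargspec(Demo)
print(spec.args)        # ['self', 'station', 'window']
print(spec.kwonlyargs)  # ['extreme_values'] -- invisible to the old own_args
```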
diff --git a/mlair/data_handler/data_preparation_neighbors.py b/mlair/data_handler/data_preparation_neighbors.py
index 1482bb9f..a004e659 100644
--- a/mlair/data_handler/data_preparation_neighbors.py
+++ b/mlair/data_handler/data_preparation_neighbors.py
@@ -4,9 +4,9 @@ __date__ = '2020-07-17'
 
 
 from mlair.helpers import to_list
-from mlair.data_handler.station_preparation import DataHandlerSingleStation
 from mlair.data_handler import DefaultDataHandler
 import os
+import copy
 
 from typing import Union, List
 
@@ -15,6 +15,7 @@ num_or_list = Union[number, List[number]]
 
 
 class DataHandlerNeighbors(DefaultDataHandler):
+    """Data handler including neighboring stations."""
 
     def __init__(self, id_class, data_path, neighbors=None, min_length=0,
                  extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False):
@@ -24,14 +25,14 @@ class DataHandlerNeighbors(DefaultDataHandler):
 
     @classmethod
     def build(cls, station, **kwargs):
-        sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs}
-        sp = DataHandlerSingleStation(station, **sp_keys)
+        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        sp = cls.data_handler(station, **sp_keys)
         n_list = []
         for neighbor in kwargs.get("neighbors", []):
-            n_list.append(DataHandlerSingleStation(neighbor, **sp_keys))
+            n_list.append(cls.data_handler(neighbor, **sp_keys))
         else:
             kwargs["neighbors"] = n_list if len(n_list) > 0 else None
-        dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs}
+        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
         return cls(sp, **dp_args)
 
     def _create_collection(self):
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 47f63a3e..d8ec3c4f 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -4,6 +4,7 @@ __date__ = '2020-09-21'
 
 import copy
 import inspect
+import gc
 import logging
 import os
 import pickle
@@ -15,7 +16,7 @@ import numpy as np
 import xarray as xr
 
 from mlair.data_handler.abstract_data_handler import AbstractDataHandler
-from mlair.data_handler.station_preparation import DataHandlerSingleStation
+from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
 from mlair.helpers import remove_items, to_list
 from mlair.helpers.join import EmptyQueryResult
 
@@ -25,11 +26,14 @@ num_or_list = Union[number, List[number]]
 
 
 class DefaultDataHandler(AbstractDataHandler):
+    from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler
+    from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler_transformation
 
-    _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
+    _requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"])
 
-    def __init__(self, id_class: DataHandlerSingleStation, data_path: str, min_length: int = 0,
-                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None):
+    def __init__(self, id_class: data_handler, data_path: str, min_length: int = 0,
+                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None,
+                 store_processed_data=True):
         super().__init__()
         self.id_class = id_class
         self.interpolation_dim = "datetime"
@@ -43,12 +47,12 @@ class DefaultDataHandler(AbstractDataHandler):
         self._collection = self._create_collection()
         self.harmonise_X()
         self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolation_dim)
-        self._store(fresh_store=True)
+        self._store(fresh_store=True, store_processed_data=store_processed_data)
 
     @classmethod
     def build(cls, station: str, **kwargs):
         sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
-        sp = DataHandlerSingleStation(station, **sp_keys)
+        sp = cls.data_handler(station, **sp_keys)
         dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
         return cls(sp, **dp_args)
 
@@ -61,6 +65,7 @@ class DefaultDataHandler(AbstractDataHandler):
 
     def _reset_data(self):
         self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
+        gc.collect()
 
     def _cleanup(self):
         directory = os.path.dirname(self._save_file)
@@ -69,13 +74,14 @@ class DefaultDataHandler(AbstractDataHandler):
         if os.path.exists(self._save_file):
             shutil.rmtree(self._save_file, ignore_errors=True)
 
-    def _store(self, fresh_store=False):
-        self._cleanup() if fresh_store is True else None
-        data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
-        with open(self._save_file, "wb") as f:
-            pickle.dump(data, f)
-        logging.debug(f"save pickle data to {self._save_file}")
-        self._reset_data()
+    def _store(self, fresh_store=False, store_processed_data=True):
+        if store_processed_data is True:
+            self._cleanup() if fresh_store is True else None
+            data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
+            with open(self._save_file, "wb") as f:
+                pickle.dump(data, f)
+            logging.debug(f"save pickle data to {self._save_file}")
+            self._reset_data()
 
     def _load(self):
         try:
@@ -223,7 +229,7 @@ class DefaultDataHandler(AbstractDataHandler):
         mean, std = None, None
         for station in set_stations:
             try:
-                sp = DataHandlerSingleStation(station, transformation={"method": method}, **sp_keys)
+                sp = cls.data_handler_transformation(station, transformation={"method": method}, **sp_keys)
                 mean = sp.mean.copy(deep=True) if mean is None else mean.combine_first(sp.mean)
                 std = sp.std.copy(deep=True) if std is None else std.combine_first(sp.std)
             except (AttributeError, EmptyQueryResult):
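
Routing `build` and the transformation setup through `cls.data_handler` and `cls.data_handler_transformation` means a subclass only has to repoint these class attributes to swap the single-station implementation. A hypothetical sketch of the pattern:

```python
from mlair.data_handler import DefaultDataHandler
from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation


class MySingleStationHandler(DataHandlerSingleStation):
    """Hypothetical subclass; custom loading or transformation would go here."""


class MyDataHandler(DefaultDataHandler):
    data_handler = MySingleStationHandler
    data_handler_transformation = MySingleStationHandler
    # note: _requirements is evaluated at class creation, so it would need to be
    # redefined here if the subclass changes the __init__ signature
```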
diff --git a/mlair/data_handler/station_preparation.py b/mlair/data_handler/station_preparation.py
index f3428e91..9ffb89f8 100644
--- a/mlair/data_handler/station_preparation.py
+++ b/mlair/data_handler/station_preparation.py
@@ -15,7 +15,7 @@ import xarray as xr
 
 from mlair.configuration import check_path_and_create
 from mlair import helpers
-from mlair.helpers import join, statistics
+from mlair.helpers import join, statistics, TimeTrackingWrapper
 from mlair.data_handler.abstract_data_handler import AbstractDataHandler
 
 # define a more general date type for type hinting
@@ -166,6 +166,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         self.call_transform()
         self.make_samples()
 
+    @TimeTrackingWrapper
     def setup_samples(self):
         """
         Setup samples. This method prepares and creates samples X, and labels Y.
@@ -508,11 +509,11 @@ class DataHandlerSingleStation(AbstractDataHandler):
         1) `station`: transform data for each station independently (somehow like batch normalisation)
         1) `data`: transform all data of each station with shared metrics
 
-        Transformation must be set by the `transformation` attribute. If `transformation = None` is given to `ExperimentSetup`, 
-        data is not transformed at all. For all other setups, use the following dictionary structure to specify the 
+        Transformation must be set by the `transformation` attribute. If `transformation = None` is given to `ExperimentSetup`,
+        data is not transformed at all. For all other setups, use the following dictionary structure to specify the
         transformation.
         ```
-        transformation = {"scope": <...>, 
+        transformation = {"scope": <...>,
                         "method": <...>,
                         "mean": <...>,
                         "std": <...>}
@@ -523,7 +524,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
 
         **station**: mean and std are not used
 
-        **data**: either provide already calculated values for mean and std (if required by transformation method), or choose 
+        **data**: either provide already calculated values for mean and std (if required by transformation method), or choose
         from different calculation schemes, explained in the mean and std section.
 
         ### supported transformation methods
@@ -532,26 +533,26 @@ class DataHandlerSingleStation(AbstractDataHandler):
         * centre
 
         ### mean and std
-        `"mean"="accurate"`: calculate the accurate values of mean and std (depending on method) by using all data. Although, 
-        this method is accurate, it may take some time for the calculation. Furthermore, this could potentially lead to memory 
+        `"mean"="accurate"`: calculate the accurate values of mean and std (depending on method) by using all data. Although,
+        this method is accurate, it may take some time for the calculation. Furthermore, this could potentially lead to memory
         issue (not explored yet, but could appear for a very big amount of data)
 
         `"mean"="estimate"`: estimate mean and std (depending on method). For each station, mean and std are calculated and
-        afterwards aggregated using the mean value over all station-wise metrics. This method is less accurate, especially 
+        afterwards aggregated using the mean value over all station-wise metrics. This method is less accurate, especially
         regarding the std calculation but therefore much faster.
 
         We recommend using the latter method *estimate* for the following reasons:
         * much faster calculation
         * real accuracy of mean and std is less important, because it is "just" a transformation / scaling
-        * accuracy of mean is almost as high as in the *accurate* case, because of 
-        $\bar{x_{ij}} = \bar{\left(\bar{x_i}\right)_j}$. The only difference is, that in the *estimate* case, each mean is 
+        * accuracy of mean is almost as high as in the *accurate* case, because of
+        $\overline{x_{ij}} = \overline{\left(\overline{x_i}\right)_j}$. The only difference is that in the *estimate* case, each mean is
         equally weighted for each station independently of the actual data count of the station.
-        * accuracy of std is lower for *estimate* because of $\var{x_{ij}} \ne \bar{\left(\var{x_i}\right)_j}$, but still the mean of all 
+        * accuracy of std is lower for *estimate* because of $\mathrm{Var}(x_{ij}) \ne \overline{\left(\mathrm{Var}(x_i)\right)_j}$, but still the mean of all
         station-wise std is a decent estimate of the true std.
 
         `"mean"=<value, e.g. xr.DataArray>`: If mean and std are already calculated or shall be set manually, just add the
         scaling values instead of the calculation method. For method *centre*, std can still be None, but is required for the
-        *standardise* method. **Important**: Format of given values **must** match internal data format of DataPreparation 
+        *standardise* method. **Important**: the format of the given values **must** match the internal data format of the DataPreparation
         class: `xr.DataArray` with `dims=["variables"]` and one value for each variable.
 
         """
diff --git a/mlair/helpers/join.py b/mlair/helpers/join.py
index a3c6876e..b1e27830 100644
--- a/mlair/helpers/join.py
+++ b/mlair/helpers/join.py
@@ -138,6 +138,7 @@ def load_series_information(station_name: List[str], station_type: str_or_none,
     opts = {"base": join_url_base, "service": "series", "station_id": station_name[0], "station_type": station_type,
             "network_name": network_name}
     station_vars = get_data(opts, headers)
+    logging.info(f"{station_name}: {station_vars}")
     vars_dict = {item[3].lower(): item[0] for item in station_vars}
     return vars_dict
 
diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py
index c9cc13bd..a603b466 100644
--- a/mlair/model_modules/model_class.py
+++ b/mlair/model_modules/model_class.py
@@ -396,8 +396,66 @@ class MyLittleModel(AbstractModelClass):
     def set_compile_options(self):
         self.initial_lr = 1e-2
         self.optimizer = keras.optimizers.adam(lr=self.initial_lr)
-        self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
-                                                                               epochs_drop=10)
+        # self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
+        #                                                                        epochs_drop=10)
+        self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]}
+
+
+class MyLittleModelHourly(AbstractModelClass):
+    """
+    A customised model with three 1x1 convolutions (128, 64 and 32 filters) followed by Dense layers (128, 64 and
+    window_lead_time), where the last Dense layer is the output layer depending on the window_lead_time parameter.
+    Dropout is used between the Flatten operation and the first Dense layer.
+    """
+
+    def __init__(self, input_shape: list, output_shape: list):
+        """
+        Sets model and loss depending on the given arguments.
+
+        :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
+        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
+        """
+
+        assert len(input_shape) == 1
+        assert len(output_shape) == 1
+        super().__init__(input_shape[0], output_shape[0])
+
+        # settings
+        self.dropout_rate = 0.1
+        self.regularizer = keras.regularizers.l2(0.001)
+        self.activation = keras.layers.PReLU
+
+        # apply to model
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(loss=self.compile_options['loss'])
+
+    def set_model(self):
+        """
+        Build the model.
+        """
+
+        # the history dimension of the input shape already includes the current time step t0
+        x_input = keras.layers.Input(shape=self._input_shape)
+        x_in = keras.layers.Conv2D(128, (1, 1), padding='same', name='{}_Conv_1x1_128'.format("major"))(x_input)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Conv2D(64, (1, 1), padding='same', name='{}_Conv_1x1_64'.format("major"))(x_in)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1_32'.format("major"))(x_in)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in)
+        x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in)
+        x_in = keras.layers.Dense(128, name='{}_Dense_128'.format("major"))(x_in)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Dense(self._output_shape, name='{}_Dense'.format("major"))(x_in)
+        out_main = self.activation()(x_in)
+        self.model = keras.Model(inputs=x_input, outputs=[out_main])
+
+    def set_compile_options(self):
+        self.initial_lr = 1e-2
+        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
         self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]}
 
 
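
A hypothetical way to use the new class, assuming the import paths and workflow parameters that appear later in this patch:

```python
from mlair.workflows import DefaultWorkflow
from mlair.model_modules.model_class import MyLittleModelHourly

# sampling="hourly" matches the hourly resolution this model targets
workflow = DefaultWorkflow(model=MyLittleModelHourly, sampling="hourly")
workflow.run()
```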
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 675e5ade..327dc40d 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -902,6 +902,7 @@ class PlotAvailability(AbstractPlotClass):
         # create standard Gantt plot for all stations (currently in single pdf file with single page)
         super().__init__(plot_folder, "data_availability")
         self.dim = time_dimension
+        self.linewidth = None
         self.sampling = self._get_sampling(sampling)
         plot_dict = self._prepare_data(generators)
         lgd = self._plot(plot_dict)
@@ -917,11 +918,11 @@ class PlotAvailability(AbstractPlotClass):
         lgd = self._plot(plot_dict_summary)
         self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
 
-    @staticmethod
-    def _get_sampling(sampling):
+    def _get_sampling(self, sampling):
         if sampling == "daily":
             return "D"
         elif sampling == "hourly":
+            self.linewidth = 0.001
             return "h"
 
     def _prepare_data(self, generators: Dict[str, DataCollection]):
@@ -982,7 +983,7 @@ class PlotAvailability(AbstractPlotClass):
                 plt_data = d.get(subset)
                 if plt_data is None:
                     continue
-                ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white")
+                ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth)
             yticklabels.append(station)
 
         ax.set_ylim([height, number_of_stations + 1])
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index de43f30d..7b1d1455 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -528,7 +528,7 @@ class PostProcessing(RunEnvironment):
             # external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"])
             external_data = self._create_observation(observation, None, mean, std, transformation_method, normalised=False)
             return external_data.rename({external_data.dims[0]: 'index'})
-        except IndexError:
+        except (IndexError, KeyError):
             return None
 
     def calculate_skill_scores(self) -> Tuple[Dict, Dict]:
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index ed972896..f074863e 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -11,7 +11,7 @@ import numpy as np
 import pandas as pd
 
 from mlair.data_handler import DataCollection, AbstractDataHandler
-from mlair.helpers import TimeTracking
+from mlair.helpers import TimeTracking, to_list
 from mlair.configuration import path_config
 from mlair.helpers.join import EmptyQueryResult
 from mlair.run_modules.run_environment import RunEnvironment
@@ -56,7 +56,8 @@ class PreProcessing(RunEnvironment):
     def _run(self):
         stations = self.data_store.get("stations")
         data_handler = self.data_store.get("data_handler")
-        _, valid_stations = self.validate_station(data_handler, stations, "preprocessing", overwrite_local_data=True)
+        _, valid_stations = self.validate_station(data_handler, stations,
+                                                  "preprocessing")  # , store_processed_data=False)
         if len(valid_stations) == 0:
             raise ValueError("Couldn't find any valid data according to given parameters. Abort experiment run.")
         self.data_store.set("stations", valid_stations)
@@ -192,20 +193,14 @@ class PreProcessing(RunEnvironment):
         self.data_store.set("stations", valid_stations, scope=set_name)
         self.data_store.set("data_collection", collection, scope=set_name)
 
-    def validate_station(self, data_handler: AbstractDataHandler, set_stations, set_name=None, overwrite_local_data=False):
+    def validate_station(self, data_handler: AbstractDataHandler, set_stations, set_name=None,
+                         store_processed_data=True):
         """
         Check if all given stations in `all_stations` are valid.
 
         Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the
         loading time are logged in debug mode.
 
-        :param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`,
-            `variables`, `time_dim`, `target_dim`, `target_var`).
-        :param kwargs: positional parameters for the DataGenerator class (e.g. `start`, `interpolation_method`,
-            `window_lead_time`).
-        :param all_stations: All stations to check.
-        :param name: name to display in the logging info message
-
         :return: Corrected list containing only valid station IDs.
         """
         t_outer = TimeTracking()
@@ -219,7 +214,8 @@ class PreProcessing(RunEnvironment):
         kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name)
         for station in set_stations:
             try:
-                dp = data_handler.build(station, name_affix=set_name, **kwargs)
+                dp = data_handler.build(station, name_affix=set_name, store_processed_data=store_processed_data,
+                                        **kwargs)
                 collection.add(dp)
                 valid_stations.append(station)
             except (AttributeError, EmptyQueryResult):
diff --git a/mlair/workflows/abstract_workflow.py b/mlair/workflows/abstract_workflow.py
index bced90bb..3a627d9f 100644
--- a/mlair/workflows/abstract_workflow.py
+++ b/mlair/workflows/abstract_workflow.py
@@ -16,15 +16,17 @@ class Workflow:
     execution but not the dependencies (workflow would probably fail in this case)."""
 
     def __init__(self, name=None):
-        self._registry = OrderedDict()
+        self._registry_kwargs = {}
+        self._registry = []
         self._name = name if name is not None else self.__class__.__name__
 
     def add(self, stage, **kwargs):
         """Add a new stage with optional kwargs."""
-        self._registry[stage] = kwargs
+        self._registry.append(stage)
+        self._registry_kwargs[len(self._registry) - 1] = kwargs
 
     def run(self):
         """Run workflow embedded in a run environment and according to the stage's ordering."""
         with RunEnvironment(name=self._name):
-            for stage, kwargs in self._registry.items():
-                stage(**kwargs)
+            for pos, stage in enumerate(self._registry):
+                stage(**self._registry_kwargs[pos])
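
The registry change is easiest to see in isolation: keying an `OrderedDict` by the stage object overwrites earlier registrations of the same stage, while the list keeps each one in order. A standalone illustration:

```python
from collections import OrderedDict

def stage(): ...

# old behaviour: the second registration replaces the first
registry = OrderedDict()
registry[stage] = {"run": 1}
registry[stage] = {"run": 2}
print(len(registry))  # 1

# new behaviour: every registration survives; kwargs are looked up by position
registry_list, registry_kwargs = [], {}
for kwargs in ({"run": 1}, {"run": 2}):
    registry_list.append(stage)
    registry_kwargs[len(registry_list) - 1] = kwargs
print(len(registry_list))  # 2
```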
diff --git a/mlair/workflows/default_workflow.py b/mlair/workflows/default_workflow.py
index 85d6726b..4d113190 100644
--- a/mlair/workflows/default_workflow.py
+++ b/mlair/workflows/default_workflow.py
@@ -14,28 +14,29 @@ class DefaultWorkflow(Workflow):
     the mentioned ordering."""
 
     def __init__(self, stations=None,
-        train_model=None, create_new_model=None,
-        window_history_size=None,
-        experiment_date="testrun",
-        variables=None, statistics_per_var=None,
-        start=None, end=None,
-        target_var=None, target_dim=None,
-        window_lead_time=None,
-        dimensions=None,
-        interpolation_method=None, time_dim=None, limit_nan_fill=None,
-        train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None,
-        use_all_stations_on_all_data_sets=None, fraction_of_train=None,
-        experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None, overwrite_local_data=None,
-        sampling=None,
-        permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None,
-        transformation=None,
-        train_min_length=None, val_min_length=None, test_min_length=None,
-        evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None,
-        plot_list=None,
-        model=None,
-        batch_size=None,
-        epochs=None,
-        data_preparation=None,
+                 train_model=None, create_new_model=None,
+                 window_history_size=None,
+                 experiment_date="testrun",
+                 variables=None, statistics_per_var=None,
+                 start=None, end=None,
+                 target_var=None, target_dim=None,
+                 window_lead_time=None,
+                 dimensions=None,
+                 interpolation_method=None, time_dim=None, limit_nan_fill=None,
+                 train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None,
+                 use_all_stations_on_all_data_sets=None, fraction_of_train=None,
+                 experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None,
+                 overwrite_local_data=None,
+                 sampling=None,
+                 permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None,
+                 transformation=None,
+                 train_min_length=None, val_min_length=None, test_min_length=None,
+                 evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None,
+                 plot_list=None,
+                 model=None,
+                 batch_size=None,
+                 epochs=None,
+                 data_handler=None,
                  **kwargs):
         super().__init__()
 
@@ -58,28 +59,29 @@ class DefaultWorkflowHPC(Workflow):
     Training and PostProcessing in exactly the mentioned ordering."""
 
     def __init__(self, stations=None,
-        train_model=None, create_new_model=None,
-        window_history_size=None,
-        experiment_date="testrun",
-        variables=None, statistics_per_var=None,
-        start=None, end=None,
-        target_var=None, target_dim=None,
-        window_lead_time=None,
-        dimensions=None,
-        interpolation_method=None, time_dim=None, limit_nan_fill=None,
-        train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None,
-        use_all_stations_on_all_data_sets=None, fraction_of_train=None,
-        experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None, overwrite_local_data=None,
-        sampling=None,
-        permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None,
-        transformation=None,
-        train_min_length=None, val_min_length=None, test_min_length=None,
-        evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None,
-        plot_list=None,
-        model=None,
-        batch_size=None,
-        epochs=None,
-        data_preparation=None, **kwargs):
+                 train_model=None, create_new_model=None,
+                 window_history_size=None,
+                 experiment_date="testrun",
+                 variables=None, statistics_per_var=None,
+                 start=None, end=None,
+                 target_var=None, target_dim=None,
+                 window_lead_time=None,
+                 dimensions=None,
+                 interpolation_method=None, time_dim=None, limit_nan_fill=None,
+                 train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None,
+                 use_all_stations_on_all_data_sets=None, fraction_of_train=None,
+                 experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None,
+                 overwrite_local_data=None,
+                 sampling=None,
+                 permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None,
+                 transformation=None,
+                 train_min_length=None, val_min_length=None, test_min_length=None,
+                 evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None,
+                 plot_list=None,
+                 model=None,
+                 batch_size=None,
+                 epochs=None,
+                 data_handler=None, **kwargs):
         super().__init__()
 
         # extract all given kwargs arguments
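
Callers have to follow the keyword rename from `data_preparation` to `data_handler`; the rest of the signature is unchanged. A hypothetical before/after:

```python
from mlair.workflows import DefaultWorkflow
from mlair.data_handler import DefaultDataHandler

# before this patch:
# workflow = DefaultWorkflow(data_preparation=DefaultDataHandler)
# after this patch:
workflow = DefaultWorkflow(data_handler=DefaultDataHandler)
workflow.run()
```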
-- 
GitLab