diff --git a/.gitignore b/.gitignore
index e109cec7c7622cee9d9a635cc458b9c662fc4761..305a5d1b9420eb62da24772fc1f4b263c1f3efe1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -60,7 +60,7 @@ Thumbs.db
 htmlcov/
 .pytest_cache
 /test/data/
-/test/test_modules/data/
+/test/test_run_modules/data/
 report.html
 /TestExperiment/
 /testrun_network*/
@@ -73,7 +73,7 @@ report.html
 
 # secret variables #
 ####################
-/src/configuration/join_settings.py
+/mlair/configuration/join_settings.py
 
 # ignore locally build documentation #
 ######################################
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a08bab0068246f8d57b3789e10f6f0f4105817ad..823c37005922ca5b8a621b55f7bdb5528f7f9b76 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,37 @@
 # Changelog
 All notable changes to this project will be documented in this file.
 
+## v0.10.0 -  2020-07-15  -  MLAir is official name, Workflows, easy Model plug-in
+
+### general
+- Official project name is released: MLAir (Machine Learning on Air data)
+- a model class can now easily be plugged into MLAir, #121
+- introduced new concept of workflows, #134
+
+### new features
+- workflows are used to execute a sequence of run modules, #134
+- default workflows for standard setups and for the Juelich HPC systems are available; custom workflows can be defined, #134
+- seasonal decomposition is available for conditional quantile plot, #112
+- map plot is created with coordinates, #108
+- `flatten_tails` are now more general and easier to customise, #114
+- model classes have custom compile options (replaces `set_loss`), #110
+- model can be set in ExperimentSetup from outside, #121
+- default experiment settings can be queried using `get_defaults()`, #123
+- training and model settings are reported as Markdown and TeX tables, #145
+
+### technical
+- Juelich HPC systems are supported and installation scripts are available, #106
+- data store is tracked, I/O is saved and illustrated in a plot, #116
+- batch size and epoch parameters have to be defined in ExperimentSetup, #127, #122
+- automatic documentation with sphinx, #109
+- default experiment settings are updated, #123
+- refactoring of experiment path and its default naming, #124
+- refactoring of some parameter names, #146
+- preparation for package distribution with pip, #119
+- all run scripts are updated to run with workflows, #134
+- the experiment folder is restructured, #130
+
+
 ## v0.9.0  -  2020-04-15  -  faster bootstraps, extreme value upsamling
 ### general
 - improved and faster bootstrap workflow
diff --git a/CI/run_pytest_coverage.sh b/CI/run_pytest_coverage.sh
index 45916427f1521843923fb94e49dc661241dc0369..24d916b1a32da714abc2e5de0ac2b4c2790752a9 100644
--- a/CI/run_pytest_coverage.sh
+++ b/CI/run_pytest_coverage.sh
@@ -1,7 +1,7 @@
 #!/usr/bin/env bash
 
 # run coverage twice, 1) for html deploy 2) for success evaluation
-python3.6 -m pytest --cov=src --cov-report term  --cov-report html test/ | tee coverage_results.out
+python3.6 -m pytest --cov=mlair --cov-report term  --cov-report html test/ | tee coverage_results.out
 
 IS_FAILED=$?
 
diff --git a/LICENSE b/LICENSE
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..a79ea789a5b55f7328d1fd987293376838112048 100644
--- a/LICENSE
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Lukas Leufen
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
index baae0af91036da10ba70f154ac875c18908858c3..5c55b4094232908a56cdcf61ba437976f8714e8b 100644
--- a/README.md
+++ b/README.md
@@ -1,50 +1,217 @@
-# MachineLearningTools
+# MLAir - Machine Learning on Air Data
 
-This is a collection of all relevant functions used for ML stuff in the ESDE group
+MLAir (Machine Learning on Air data) is an environment that simplifies and accelerates the creation of new machine 
+learning (ML) models for the analysis and forecasting of meteorological and air quality time series.
 
-## Inception Model
+# Installation
 
-See a description [here](https://towardsdatascience.com/a-simple-guide-to-the-versions-of-the-inception-network-7fc52b863202)
-or take a look on the papers [Going Deeper with Convolutions (Szegedy et al., 2014)](https://arxiv.org/abs/1409.4842)
-and [Network In Network (Lin et al., 2014)](https://arxiv.org/abs/1312.4400).
+MLAir is based on several Python frameworks. To work properly, you have to install all packages from the 
+`requirements.txt` file. Additionally, to support the geographical plotting part, it is required to install geo
+packages built for your operating system. The names of these packages may differ between systems; we refer
+here to the openSUSE / Leap OS. The geo plot can be removed from the `plot_list`; in this case there is no need to 
+install the geo packages.
 
+* (geo) Install **proj** on your machine using the console, e.g. for openSUSE / Leap: `zypper install proj`
+* (geo) A C++ compiler is required for the installation of the package **cartopy**
+* Install all requirements from [`requirements.txt`](https://gitlab.version.fz-juelich.de/toar/machinelearningtools/-/blob/master/requirements.txt),
+  preferably in a virtual environment
+* (tf) Currently, TensorFlow-1.13 is listed in the requirements. We have already tested TensorFlow-1.15 and couldn't
+  find any compatibility errors. Please note that tf-1.13 and 1.15 each come in two distinct variants: the default variant 
+  for CPU support, and the "-gpu" variant for GPU support. If the GPU version is installed, MLAir will make use of the GPU
+  device.
+* Installation of **MLAir**:
+    * Either clone MLAir from the [gitlab repository](https://gitlab.version.fz-juelich.de/toar/machinelearningtools.git) 
+      and use it without installation (besides the requirements), 
+    * or download the distribution file (?? .whl) and install it via `pip install <??>`. In this case, you can simply
+      import MLAir in any python script inside your virtual environment using `import mlair` (see the quick check below).
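+
+As a quick check of the installation, you can import MLAir and print its version. This is only a minimal sketch; it
+assumes that the package or the cloned repository is importable from your virtual environment.
+```python
+import mlair
+
+# print the installed MLAir version to verify that the import works
+print(mlair.get_version())
+```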
 
-# Installation
+# How to start with MLAir
 
-* Install __proj__ on your machine using the console. E.g. for opensuse / leap `zypper install proj`
-* c++ compiler required for cartopy installation
+In this section, we show three examples of how to work with MLAir.
 
-## HPC - JUWELS and HDFML setup
-The following instruction guide you throug the installation on JUWELS and HDFML. 
-* Clone the repo to HPC system (we recommend to place it in `/p/projects/<project name>`.
-* Setup venv by executing `source setupHPC.sh`. This script loads all pre-installed modules and creates a venv for 
-all other packages. Furthermore, it creates slurm/batch scripts to execute code on compute nodes. <br> 
-You have to enter the HPC project's budget name (--account flag).
-* The default external data path on JUWELS and HDFML is set to `/p/project/deepacf/intelliaq/<user>/DATA/toar_<sampling>`. 
-<br>To choose a different location open `run.py` and add the following keyword argument to `ExperimentSetup`: 
-`data_path=<your>/<custom>/<path>`. 
-* Execute `python run.py` on a login node to download example data. The program will throw an OSerror after downloading.
-* Execute either `sbatch run_juwels_develgpus.bash` or `sbatch run_hdfml_batch.bash` to verify that the setup went well.
-* Currently cartopy is not working on our HPC system, therefore PlotStations does not create any output.
+## Example 1
 
-### HPC JUWELS and HDFML remarks 
-Please note, that the HPC setup is customised for JUWELS and HDFML. When using another HPC system, you can use the HPC setup files as a skeleton and customise it to your needs. 
+We start MLAir in a dry run without any modification. Just import mlair and run it.
+```python
+import mlair
 
-Note: The method `PartitionCheck` currently only checks if the hostname starts with `ju` or `hdfmll`. 
-Therefore, it might be necessary to adopt the `if` statement in `PartitionCheck._run`.
+# just give it a dry run without any modification 
+mlair.run()
+```
+The logging output will show you plenty of information. Additional information (including debug messages) is collected 
+inside the experiment path in the logging folder.
+```log
+INFO: mlair started
+INFO: ExperimentSetup started
+INFO: Experiment path is: /home/<usr>/mlair/testrun_network 
+...
+INFO: load data for DEBW001 from JOIN 
+...
+INFO: Training started
+...
+INFO: mlair finished after 00:00:12 (hh:mm:ss)
+```
 
+## Example 2
 
-# Security
+Now we update the stations and customise the window history size parameter.
 
-* To use hourly data from ToarDB via JOIN interface, a private token is required. Request your personal access token and
-add it to `src/join_settings.py` in the hourly data section. Replace the `TOAR_SERVICE_URL` and the `Authorization` 
-value. To make sure, that this **sensitive** data is not uploaded to the remote server, use the following command to
-prevent git from tracking this file: `git update-index --assume-unchanged src/join_settings.py
-`
+```python
+import mlair
+
+# our new stations to use
+stations = ['DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW107']
+
+# expanded temporal context to 14 (days, because of default sampling="daily")
+window_history_size = 14
+
+# restart the experiment with little customisation
+mlair.run(stations=stations, 
+          window_history_size=window_history_size)
+```
+The output looks similar, but we can see that the new stations are loaded.
+```log
+INFO: mlair started
+INFO: ExperimentSetup started
+...
+INFO: load data for DEBW030 from JOIN 
+INFO: load data for DEBW037 from JOIN 
+...
+INFO: Training started
+...
+INFO: mlair finished after 00:00:24 (hh:mm:ss)
+```
+
+## Example 3
+
+Let's apply our trained model to new data. Therefore, we keep the window history size parameter but change the stations.
+In the run method, we need to disable the `trainable` and `create_new_model` parameters. MLAir will then use the model we
+trained before. Note that this only works if the experiment path has not changed or a suitable trained model is placed 
+inside the experiment path.
+```python
+import mlair
+
+# our new stations to use
+stations = ['DEBY002', 'DEBY079']
+
+# same setting for window_history_size
+window_history_size = 14
+
+# run experiment without training
+mlair.run(stations=stations, 
+          window_history_size=window_history_size, 
+          create_new_model=False, 
+          trainable=False)
+```
+We can see from the terminal output that no training was performed. The analysis is now carried out on the new stations.
+```log
+INFO: mlair started
+...
+INFO: No training has started, because trainable parameter was false. 
+...
+INFO: mlair finished after 00:00:06 (hh:mm:ss)
+```
+
+# Customised workflows and models
+
+## Custom Workflow
+
+MLAir provides a default workflow. If additional steps are to be performed, you have to append custom run modules to 
+the workflow.
+
+```python
+import mlair
+import logging
+
+class CustomStage(mlair.RunEnvironment):
+    """A custom MLAir stage for demonstration."""
+
+    def __init__(self, test_string):
+        super().__init__()  # always call super init method
+        self._run(test_string)  # call a class method
+        
+    def _run(self, test_string):
+        logging.info("Just running a custom stage.")
+        logging.info("test_string = " + test_string)
+        epochs = self.data_store.get("epochs")
+        logging.info("epochs = " + str(epochs))
 
-# Customise your experiment
+        
+# create your custom MLAir workflow
+CustomWorkflow = mlair.Workflow()
+# provide stages without initialisation
+CustomWorkflow.add(mlair.ExperimentSetup, epochs=128)
+# add also keyword arguments for a specific stage
+CustomWorkflow.add(CustomStage, test_string="Hello World")
+# finally execute custom workflow in order of adding
+CustomWorkflow.run()
+```
+```log
+INFO: mlair started
+...
+INFO: ExperimentSetup finished after 00:00:12 (hh:mm:ss)
+INFO: CustomStage started
+INFO: Just running a custom stage.
+INFO: test_string = Hello World
+INFO: epochs = 128
+INFO: CustomStage finished after 00:00:01 (hh:mm:ss)
+INFO: mlair finished after 00:00:13 (hh:mm:ss)
+```
+
+## Custom Model
+
+Each model has to inherit from the abstract model class to ensure smooth training and evaluation behaviour. It is 
+required to implement the `set_model` and `set_compile_options` methods. The latter has to set at least the loss.
+
+```python
+
+import keras
+from keras.losses import mean_squared_error as mse
+from keras.optimizers import SGD
+
+from mlair.model_modules import AbstractModelClass
+
+class MyLittleModel(AbstractModelClass):
+    """
+    A customised model with a 1x1 Conv and 3 Dense layers (32, 16,
+    window_lead_time). Dropout is used after the Conv layer.
+    """
+    def __init__(self, window_history_size, window_lead_time, channels):
+        super().__init__()
+        # settings
+        self.window_history_size = window_history_size
+        self.window_lead_time = window_lead_time
+        self.channels = channels
+        self.dropout_rate = 0.1
+        self.activation = keras.layers.PReLU
+        self.lr = 1e-2
+        # apply to model
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(loss=self.compile_options['loss'])
+
+    def set_model(self):
+        # add 1 to window_size to include current time step t0
+        shape = (self.window_history_size + 1, 1, self.channels)
+        x_input = keras.layers.Input(shape=shape)
+        x_in = keras.layers.Conv2D(32, (1, 1), padding='same')(x_input)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Flatten()(x_in)
+        x_in = keras.layers.Dropout(self.dropout_rate)(x_in)
+        x_in = keras.layers.Dense(32)(x_in)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Dense(16)(x_in)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Dense(self.window_lead_time)(x_in)
+        out = self.activation()(x_in)
+        self.model = keras.Model(inputs=x_input, outputs=[out])
+
+    def set_compile_options(self):
+        self.compile_options = {"optimizer": SGD(lr=self.lr),
+                                "loss": mse, 
+                                "metrics": ["mse"]}
+```
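+
+To actually use such a class in an experiment, the model can be handed over to `ExperimentSetup` from outside (see
+changelog, #121). The following is only a hedged sketch; it assumes the class is passed via a `model` keyword argument.
+```python
+import mlair
+
+# hypothetical: plug the custom model class into the default workflow
+mlair.run(model=MyLittleModel)
+```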
 
-This section summarises which parameters can be customised for a training.
 
 ## Transformation
 
@@ -97,3 +264,36 @@ station-wise std is a decent estimate of the true std.
 scaling values instead of the calculation method. For method *centre*, std can still be None, but is required for the
 *standardise* method. **Important**: Format of given values **must** match internal data format of DataPreparation 
 class: `xr.DataArray` with `dims=["variables"]` and one value for each variable.
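+
+For illustration, a minimal sketch of such externally given scaling values (variable names and numbers are made up;
+only the `xr.DataArray` format with `dims=["variables"]` is prescribed):
+```python
+import xarray as xr
+
+# one mean and one std value per variable, indexed along the "variables" dimension
+mean = xr.DataArray([40., 283.15], coords={"variables": ["o3", "temp"]}, dims=["variables"])
+std = xr.DataArray([15., 8.], coords={"variables": ["o3", "temp"]}, dims=["variables"])
+transformation = {"scope": "data", "method": "standardise", "mean": mean, "std": std}
+```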
+
+
+
+
+
+# Special Remarks
+
+## Special instructions for installation on Jülich HPC systems
+
+_Please note that the HPC setup is customised for JUWELS and HDFML. When using another HPC system, you can use the HPC 
+setup files as a skeleton and customise them to your needs._
+
+The following instructions guide you through the installation on JUWELS and HDFML. 
+* Clone the repo to the HPC system (we recommend placing it in `/p/projects/<project name>`).
+* Set up the venv by executing `source setupHPC.sh`. This script loads all pre-installed modules and creates a venv for 
+all other packages. Furthermore, it creates slurm/batch scripts to execute code on compute nodes. <br> 
+You have to enter the HPC project's budget name (`--account` flag).
+* The default external data path on JUWELS and HDFML is set to `/p/project/deepacf/intelliaq/<user>/DATA/toar_<sampling>`. 
+<br>To choose a different location, open `run.py` and add the following keyword argument to `ExperimentSetup`: 
+`data_path=<your>/<custom>/<path>` (see the sketch after this list). 
+* Execute `python run.py` on a login node to download example data. The program will throw an OSError after downloading.
+* Execute either `sbatch run_juwels_develgpus.bash` or `sbatch run_hdfml_batch.bash` to verify that the setup went well.
+* Currently, cartopy is not working on our HPC systems; therefore, PlotStations does not create any output.
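+
+A minimal sketch of the same customisation when starting MLAir from your own script instead of editing `run.py`,
+assuming the `data_path` keyword is forwarded to `ExperimentSetup` like the other run arguments shown above:
+```python
+import mlair
+
+# point MLAir to a custom data location (hypothetical path)
+mlair.run(data_path="/p/scratch/<project>/<user>/DATA/toar_daily")
+```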
+
+Note: The method `PartitionCheck` currently only checks if the hostname starts with `ju` or `hdfmll`. 
+Therefore, it might be necessary to adapt the `if` statement in `PartitionCheck._run`.
+
+## Security using JOIN
+
+* To use hourly data from ToarDB via the JOIN interface, a private token is required. Request your personal access token and
+add it to `mlair/configuration/join_settings.py` in the hourly data section. Replace the `TOAR_SERVICE_URL` and the `Authorization` 
+value. To make sure that this **sensitive** data is not uploaded to the remote server, use the following command to
+prevent git from tracking this file: `git update-index --assume-unchanged mlair/configuration/join_settings.py`
diff --git a/docs/_source/_plots/padding_example1.png b/docs/_source/_plots/padding_example1.png
new file mode 100755
index 0000000000000000000000000000000000000000..e609cbb9fe22f406c97ceb8637751e484d139409
Binary files /dev/null and b/docs/_source/_plots/padding_example1.png differ
diff --git a/docs/_source/_plots/padding_example2.png b/docs/_source/_plots/padding_example2.png
new file mode 100755
index 0000000000000000000000000000000000000000..cfc84c6961eb6d24aef135d9e8fc5bae74a78f8a
Binary files /dev/null and b/docs/_source/_plots/padding_example2.png differ
diff --git a/docs/_source/conf.py b/docs/_source/conf.py
index 6363f57eb45e686f6f2ef8ab07806e4feba0fe2d..573918ee35e9757b8c0b32b2697fc0cc2bc0b38f 100644
--- a/docs/_source/conf.py
+++ b/docs/_source/conf.py
@@ -17,7 +17,7 @@ sys.path.insert(0, os.path.abspath('../..'))
 
 # -- Project information -----------------------------------------------------
 
-project = 'machinelearningtools'
+project = 'MLAir'
 copyright = '2020, Lukas H Leufen, Felix Kleinert'
 author = 'Lukas H Leufen, Felix Kleinert'
 
@@ -55,7 +55,7 @@ extensions = [
 autosummary_generate = True
 
 autoapi_type = 'python'
-autoapi_dirs = ['../../src/.']
+autoapi_dirs = ['../../mlair/.']
 
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
@@ -118,7 +118,7 @@ latex_elements = {
 # (source start file, target name, title,
 #  author, documentclass [howto, manual, or own class]).
 latex_documents = [
-    (master_doc, 'machinelearningtools.tex', 'MachineLearningTools Documentation',
+    (master_doc, 'mlair.tex', 'MLAir Documentation',
      author, 'manual'),
 ]
 
diff --git a/docs/_source/get-started.rst b/docs/_source/get-started.rst
index e5a82fdcf1d16ca2188a04e3dce76dc7ba9d477a..98a96d43675a0263be5bfc2d452b8af1c2626b60 100644
--- a/docs/_source/get-started.rst
+++ b/docs/_source/get-started.rst
@@ -1,16 +1,232 @@
-Get started with MachineLearningTools
-=====================================
+Get started with MLAir
+======================
 
-<what is machinelearningtools?>
+Install MLAir
+-------------
 
-MLT Module and Funtion Documentation
-------------------------------------
+MLAir is based on several Python frameworks. To work properly, you have to install all packages from the
+`requirements.txt` file. Additionally, to support the geographical plotting part, it is required to install geo
+packages built for your operating system. The names of these packages may differ between systems; we refer
+here to the openSUSE / Leap OS. The geo plot can be removed from the `plot_list`; in this case there is no need to
+install the geo packages.
 
-Install MachineLearningTools
-----------------------------
+* (geo) Install **proj** on your machine using the console, e.g. for openSUSE / Leap: `zypper install proj`
+* (geo) A C++ compiler is required for the installation of the package **cartopy**
+* Install all requirements from `requirements.txt <https://gitlab.version.fz-juelich.de/toar/machinelearningtools/-/blob/master/requirements.txt>`_,
+  preferably in a virtual environment
+* (tf) Currently, TensorFlow-1.13 is listed in the requirements. We have already tested TensorFlow-1.15 and couldn't
+  find any compatibility errors. Please note that tf-1.13 and 1.15 each come in two distinct variants: the default variant
+  for CPU support, and the "-gpu" variant for GPU support. If the GPU version is installed, MLAir will make use of the GPU
+  device.
+* Installation of **MLAir**:
+    * Either clone MLAir from the `gitlab repository <https://gitlab.version.fz-juelich.de/toar/machinelearningtools.git>`_
+      and use it without installation (besides the requirements),
+    * or download the distribution file (?? .whl) and install it via `pip install <??>`. In this case, you can simply
+      import MLAir in any python script inside your virtual environment using `import mlair` (see the quick check below).
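+
+As a quick check of the installation, you can import MLAir and print its version. This is only a minimal sketch; it
+assumes that the package or the cloned repository is importable from your virtual environment.
+
+.. code-block:: python
+
+    import mlair
+
+    # print the installed MLAir version to verify that the import works
+    print(mlair.get_version())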
 
-Dependencies
+
+How to start with MLAir
+-----------------------
+
+In this section, we show three examples of how to work with MLAir.
+
+Example 1
+~~~~~~~~~
+
+We start MLAir in a dry run without any modification. Just import mlair and run it.
+
+.. code-block:: python
+
+    import mlair
+
+    # just give it a dry run without any modification
+    mlair.run()
+
+
+The logging output will show you plenty of information. Additional information (including debug messages) is collected
+inside the experiment path in the logging folder.
+
+.. code-block::
+
+    INFO: mlair started
+    INFO: ExperimentSetup started
+    INFO: Experiment path is: /home/<usr>/mlair/testrun_network
+    ...
+    INFO: load data for DEBW001 from JOIN
+    ...
+    INFO: Training started
+    ...
+    INFO: mlair finished after 00:00:12 (hh:mm:ss)
+
+
+Example 2
+~~~~~~~~~
+
+Now we update the stations and customise the window history size parameter.
+
+.. code-block:: python
+
+    import mlair
+
+    # our new stations to use
+    stations = ['DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW107']
+
+    # expanded temporal context to 14 (days, because of default sampling="daily")
+    window_history_size = 14
+
+    # restart the experiment with little customisation
+    mlair.run(stations=stations,
+              window_history_size=window_history_size)
+
+The output looks similar, but we can see that the new stations are loaded.
+
+.. code-block::
+
+    INFO: mlair started
+    INFO: ExperimentSetup started
+    ...
+    INFO: load data for DEBW030 from JOIN
+    INFO: load data for DEBW037 from JOIN
+    ...
+    INFO: Training started
+    ...
+    INFO: mlair finished after 00:00:24 (hh:mm:ss)
+
+Example 3
+~~~~~~~~~
+
+Let's apply our trained model to new data. Therefore, we keep the window history size parameter but change the stations.
+In the run method, we need to disable the `trainable` and `create_new_model` parameters. MLAir will then use the model we
+trained before. Note that this only works if the experiment path has not changed or a suitable trained model is placed
+inside the experiment path.
+
+.. code-block:: python
+
+    import mlair
+
+    # our new stations to use
+    stations = ['DEBY002', 'DEBY079']
+
+    # same setting for window_history_size
+    window_history_size = 14
+
+    # run experiment without training
+    mlair.run(stations=stations,
+              window_history_size=window_history_size,
+              create_new_model=False,
+              trainable=False)
+
+We can see from the terminal output that no training was performed. The analysis is now carried out on the new stations.
+
+.. code-block::
+
+    INFO: mlair started
+    ...
+    INFO: No training has started, because trainable parameter was false.
+    ...
+    INFO: mlair finished after 00:00:06 (hh:mm:ss)
+
+
+
+Customised workflows and models
+-------------------------------
+
+Custom Workflow
+~~~~~~~~~~~~~~~
+
+MLAir provides a default workflow. If additional steps are to be performed, you have to append custom run modules to
+the workflow.
+
+.. code-block:: python
+
+    import mlair
+    import logging
+
+    class CustomStage(mlair.RunEnvironment):
+        """A custom MLAir stage for demonstration."""
+
+        def __init__(self, test_string):
+            super().__init__()  # always call super init method
+            self._run(test_string)  # call a class method
+
+        def _run(self, test_string):
+            logging.info("Just running a custom stage.")
+            logging.info("test_string = " + test_string)
+            epochs = self.data_store.get("epochs")
+            logging.info("epochs = " + str(epochs))
+
+
+    # create your custom MLAir workflow
+    CustomWorkflow = mlair.Workflow()
+    # provide stages without initialisation
+    CustomWorkflow.add(mlair.ExperimentSetup, epochs=128)
+    # add also keyword arguments for a specific stage
+    CustomWorkflow.add(CustomStage, test_string="Hello World")
+    # finally execute custom workflow in order of adding
+    CustomWorkflow.run()
+
+.. code-block::
+
+    INFO: mlair started
+    ...
+    INFO: ExperimentSetup finished after 00:00:12 (hh:mm:ss)
+    INFO: CustomStage started
+    INFO: Just running a custom stage.
+    INFO: test_string = Hello World
+    INFO: epochs = 128
+    INFO: CustomStage finished after 00:00:01 (hh:mm:ss)
+    INFO: mlair finished after 00:00:13 (hh:mm:ss)
+
+Custom Model
 ~~~~~~~~~~~~
 
-Data
-~~~~
+Each model has to inherit from the abstract model class to ensure smooth training and evaluation behaviour. It is
+required to implement the `set_model` and `set_compile_options` methods. The latter has to set at least the loss.
+
+.. code-block:: python
+
+    import keras
+    from keras.losses import mean_squared_error as mse
+    from keras.optimizers import SGD
+
+    from mlair.model_modules import AbstractModelClass
+
+    class MyLittleModel(AbstractModelClass):
+        """
+        A customised model with a 1x1 Conv and 3 Dense layers (32, 16,
+        window_lead_time). Dropout is used after the Conv layer.
+        """
+        def __init__(self, window_history_size, window_lead_time, channels):
+            super().__init__()
+            # settings
+            self.window_history_size = window_history_size
+            self.window_lead_time = window_lead_time
+            self.channels = channels
+            self.dropout_rate = 0.1
+            self.activation = keras.layers.PReLU
+            self.lr = 1e-2
+            # apply to model
+            self.set_model()
+            self.set_compile_options()
+            self.set_custom_objects(loss=self.compile_options['loss'])
+
+        def set_model(self):
+            # add 1 to window_size to include current time step t0
+            shape = (self.window_history_size + 1, 1, self.channels)
+            x_input = keras.layers.Input(shape=shape)
+            x_in = keras.layers.Conv2D(32, (1, 1), padding='same')(x_input)
+            x_in = self.activation()(x_in)
+            x_in = keras.layers.Flatten()(x_in)
+            x_in = keras.layers.Dropout(self.dropout_rate)(x_in)
+            x_in = keras.layers.Dense(32)(x_in)
+            x_in = self.activation()(x_in)
+            x_in = keras.layers.Dense(16)(x_in)
+            x_in = self.activation()(x_in)
+            x_in = keras.layers.Dense(self.window_lead_time)(x_in)
+            out = self.activation()(x_in)
+            self.model = keras.Model(inputs=x_input, outputs=[out])
+
+        def set_compile_options(self):
+            self.compile_options = {"optimizer": SGD(lr=self.lr),
+                                    "loss": mse,
+                                    "metrics": ["mse"]}
diff --git a/src/__init__.py b/mlair/__init__.py
similarity index 61%
rename from src/__init__.py
rename to mlair/__init__.py
index 5b7073ff042f6173fd78362f55d698eb6745552f..7f55e47abd709d5747bf54d89595fa66f5839c64 100644
--- a/src/__init__.py
+++ b/mlair/__init__.py
@@ -1,12 +1,13 @@
 __version_info__ = {
     'major': 0,
-    'minor': 9,
+    'minor': 10,
     'micro': 0,
 }
 
-from src.run_modules import *
-from src.workflows import DefaultWorkflow, Workflow
-
+from mlair.run_modules import RunEnvironment, ExperimentSetup, PreProcessing, ModelSetup, Training, PostProcessing
+from mlair.workflows import DefaultWorkflow, Workflow
+from mlair.run_script import run
+from mlair.model_modules import AbstractModelClass
 
 
 def get_version():
diff --git a/src/configuration/.gitignore b/mlair/configuration/.gitignore
similarity index 100%
rename from src/configuration/.gitignore
rename to mlair/configuration/.gitignore
diff --git a/src/configuration/__init__.py b/mlair/configuration/__init__.py
similarity index 100%
rename from src/configuration/__init__.py
rename to mlair/configuration/__init__.py
diff --git a/src/configuration/defaults.py b/mlair/configuration/defaults.py
similarity index 91%
rename from src/configuration/defaults.py
rename to mlair/configuration/defaults.py
index 0038bb5512d602150905f6504bcd5e135b127382..31746ec889cc82ebbae8de82a05c5cff02a22ac0 100644
--- a/src/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -13,7 +13,8 @@ DEFAULT_START = "1997-01-01"
 DEFAULT_END = "2017-12-31"
 DEFAULT_WINDOW_HISTORY_SIZE = 13
 DEFAULT_OVERWRITE_LOCAL_DATA = False
-DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}
+# DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}
+DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise"}
 DEFAULT_HPC_LOGIN_LIST = ["ju", "hdfmll"]  # ju[wels} #hdfmll(ogin)
 DEFAULT_HPC_HOST_LIST = ["jw", "hdfmlc"]  # first part of node names for Juwels (jw[comp], hdfmlc(ompute).
 DEFAULT_CREATE_NEW_MODEL = True
@@ -28,9 +29,9 @@ DEFAULT_TARGET_VAR = "o3"
 DEFAULT_TARGET_DIM = "variables"
 DEFAULT_WINDOW_LEAD_TIME = 3
 DEFAULT_DIMENSIONS = {"new_index": ["datetime", "Stations"]}
-DEFAULT_INTERPOLATE_DIM = "datetime"
-DEFAULT_INTERPOLATE_METHOD = "linear"
-DEFAULT_LIMIT_NAN_FILL = 1
+DEFAULT_TIME_DIM = "datetime"
+DEFAULT_INTERPOLATION_METHOD = "linear"
+DEFAULT_INTERPOLATION_LIMIT = 1
 DEFAULT_TRAIN_START = "1997-01-01"
 DEFAULT_TRAIN_END = "2007-12-31"
 DEFAULT_TRAIN_MIN_LENGTH = 90
diff --git a/src/configuration/join_settings.py b/mlair/configuration/join_settings.py
similarity index 100%
rename from src/configuration/join_settings.py
rename to mlair/configuration/join_settings.py
diff --git a/src/configuration/path_config.py b/mlair/configuration/path_config.py
similarity index 69%
rename from src/configuration/path_config.py
rename to mlair/configuration/path_config.py
index 7af25875eea58de081012fc6040a76a04f001d54..9b3d6f250d97d93dd1d06004690885f44de30073 100644
--- a/src/configuration/path_config.py
+++ b/mlair/configuration/path_config.py
@@ -6,7 +6,8 @@ import re
 import socket
 from typing import Tuple
 
-ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+# ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+ROOT_PATH = os.getcwd()
 
 
 def prepare_host(create_new=True, data_path=None, sampling="daily") -> str:
@@ -23,35 +24,38 @@ def prepare_host(create_new=True, data_path=None, sampling="daily") -> str:
 
     :return: full path to data
     """
-    hostname = get_host()
-    user = getpass.getuser()
-    runner_regex = re.compile(r"runner-.*-project-2411-concurrent-\d+")
-    if hostname == "ZAM144":
-        path = f"/home/{user}/Data/toar_{sampling}/"
-    elif hostname == "zam347":
-        path = f"/home/{user}/Data/toar_{sampling}/"
-    elif hostname == "linux-aa9b":
-        path = f"/home/{user}/machinelearningtools/data/toar_{sampling}/"
-    elif (len(hostname) > 2) and (hostname[:2] == "jr"):
-        path = f"/p/project/cjjsc42/{user}/DATA/toar_{sampling}/"
-    elif (len(hostname) > 2) and (hostname[:2] in ['jw', 'ju'] or hostname[:5] in ['hdfml']):
-        path = f"/p/project/deepacf/intelliaq/{user}/DATA/toar_{sampling}/"
-    elif runner_regex.match(hostname) is not None:
-        path = f"/home/{user}/machinelearningtools/data/toar_{sampling}/"
-    else:
-        raise OSError(f"unknown host '{hostname}'")
-    if not os.path.exists(path):
+    if data_path is None:
+        hostname = get_host()
+        user = getpass.getuser()
+        runner_regex = re.compile(r"runner-.*-project-2411-concurrent-\d+")
+        if hostname == "ZAM144":
+            data_path = f"/home/{user}/Data/toar_{sampling}/"
+        elif hostname == "zam347":
+            data_path = f"/home/{user}/Data/toar_{sampling}/"
+        elif hostname == "linux-aa9b":
+            data_path = f"/home/{user}/mlair/data/toar_{sampling}/"
+        elif (len(hostname) > 2) and (hostname[:2] == "jr"):
+            data_path = f"/p/project/cjjsc42/{user}/DATA/toar_{sampling}/"
+        elif (len(hostname) > 2) and (hostname[:2] in ['jw', 'ju'] or hostname[:5] in ['hdfml']):
+            data_path = f"/p/project/deepacf/intelliaq/{user}/DATA/toar_{sampling}/"
+        elif runner_regex.match(hostname) is not None:
+            data_path = f"/home/{user}/mlair/data/toar_{sampling}/"
+        else:
+            data_path = os.path.join(os.getcwd(), "data", sampling)
+            # raise OSError(f"unknown host '{hostname}'")
+
+    if not os.path.exists(data_path):
         try:
             if create_new:
-                check_path_and_create(path)
-                return path
+                check_path_and_create(data_path)
+                return data_path
             else:
                 raise PermissionError
         except PermissionError:
-            raise NotADirectoryError(f"path '{path}' does not exist for host '{hostname}'.")
+            raise NotADirectoryError(f"path '{data_path}' does not exist for host '{get_host()}'.")
     else:
-        logging.debug(f"set path to: {path}")
-        return path
+        logging.debug(f"set path to: {data_path}")
+        return data_path
 
 
 def set_experiment_path(name: str, path: str = None) -> str:
diff --git a/src/data_handling/__init__.py b/mlair/data_handler/__init__.py
similarity index 59%
rename from src/data_handling/__init__.py
rename to mlair/data_handler/__init__.py
index cb5aa5db0f29cf51d32ed54e810fa9b363d80cc6..451868b838ab7a0d165942e36b5ec6aa03e42721 100644
--- a/src/data_handling/__init__.py
+++ b/mlair/data_handler/__init__.py
@@ -10,6 +10,6 @@ __date__ = '2020-04-17'
 
 
 from .bootstraps import BootStraps
-from .data_preparation_join import DataPrepJoin
-from .data_generator import DataGenerator
-from .data_distributor import Distributor
+from .iterator import KerasIterator, DataCollection
+from .advanced_data_handler import DefaultDataPreparation, AbstractDataPreparation
+from .data_preparation_neighbors import DataPreparationNeighbors
diff --git a/src/data_handling/advanced_data_handling.py b/mlair/data_handler/advanced_data_handler.py
similarity index 62%
rename from src/data_handling/advanced_data_handling.py
rename to mlair/data_handler/advanced_data_handler.py
index e36e0c75fc9107431a69482d46755acbdf5334bd..57a9667f2a42575faa02d50e439252738a8dc8bb 100644
--- a/src/data_handling/advanced_data_handling.py
+++ b/mlair/data_handler/advanced_data_handler.py
@@ -3,7 +3,7 @@ __author__ = 'Lukas Leufen'
 __date__ = '2020-07-08'
 
 
-from src.helpers import to_list, remove_items
+from mlair.helpers import to_list, remove_items
 import numpy as np
 import xarray as xr
 import pickle
@@ -11,10 +11,15 @@ import os
 import pandas as pd
 import datetime as dt
 import shutil
+import inspect
+import copy
 
-from typing import Union, List, Tuple
+from typing import Union, List, Tuple, Dict
 import logging
 from functools import reduce
+from mlair.data_handler.station_preparation import StationPrep
+from mlair.helpers.join import EmptyQueryResult
+
 
 number = Union[float, int]
 num_or_list = Union[number, List[number]]
@@ -44,25 +49,79 @@ class DummyDataSingleStation:  # pragma: no cover
         return self.name
 
 
-class DataPreparation:
+class AbstractDataPreparation:
+
+    _requirements = []
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    @classmethod
+    def build(cls, *args, **kwargs):
+        """Return initialised class."""
+        return cls(*args, **kwargs)
+
+    @classmethod
+    def requirements(cls):
+        """Return requirements and own arguments without duplicates."""
+        return list(set(cls._requirements + cls.own_args()))
+
+    @classmethod
+    def own_args(cls, *args):
+        return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args))
+
+    @classmethod
+    def transformation(cls, *args, **kwargs):
+        return None
+
+    def get_X(self, upsampling=False, as_numpy=False):
+        raise NotImplementedError
+
+    def get_Y(self, upsampling=False, as_numpy=False):
+        raise NotImplementedError
+
+    def get_data(self, upsampling=False, as_numpy=False):
+        return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)
+
+    def get_coordinates(self) -> Union[None, Dict]:
+        return None
+
+
+class DefaultDataPreparation(AbstractDataPreparation):
 
-    def __init__(self, id_class, interpolate_dim: str, store_path, neighbors=None, min_length=0,
-                 extreme_values: num_or_list = None,extremes_on_right_tail_only: bool = False,):
+    _requirements = remove_items(inspect.getfullargspec(StationPrep).args, ["self", "station"])
+
+    def __init__(self, id_class, data_path, min_length=0,
+                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None):
+        super().__init__()
         self.id_class = id_class
-        self.neighbors = to_list(neighbors) if neighbors is not None else []
-        self.interpolate_dim = interpolate_dim
+        self.interpolation_dim = "datetime"
         self.min_length = min_length
         self._X = None
         self._Y = None
         self._X_extreme = None
         self._Y_extreme = None
-        self._save_file = os.path.join(store_path, f"data_preparation_{str(self.id_class)}.pickle")
-        self._collection = []
-        self._create_collection()
+        _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
+        self._save_file = os.path.join(data_path, f"data_preparation_{_name_affix}.pickle")
+        self._collection = self._create_collection()
         self.harmonise_X()
-        self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolate_dim)
+        self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolation_dim)
         self._store(fresh_store=True)
 
+    @classmethod
+    def build(cls, station, **kwargs):
+        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        sp = StationPrep(station, **sp_keys)
+        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
+        return cls(sp, **dp_args)
+
+    def _create_collection(self):
+        return [self.id_class]
+
+    @classmethod
+    def requirements(cls):
+        return remove_items(super().requirements(), "id_class")
+
     def _reset_data(self):
         self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
 
@@ -98,10 +157,6 @@ class DataPreparation:
         self._reset_data()
         return X, Y
 
-    def _create_collection(self):
-        for data_class in [self.id_class] + self.neighbors:
-            self._collection.append(data_class)
-
     def __repr__(self):
         return ";".join(list(map(lambda x: str(x), self._collection)))
 
@@ -119,23 +174,23 @@ class DataPreparation:
     def _to_numpy(d):
         return list(map(lambda x: np.copy(x), d))
 
-    def get_X(self, upsamling=False, as_numpy=True):
+    def get_X(self, upsampling=False, as_numpy=True):
         no_data = (self._X is None)
         self._load() if no_data is True else None
-        X = self._X if upsamling is False else self._X_extreme
+        X = self._X if upsampling is False else self._X_extreme
         self._reset_data() if no_data is True else None
         return self._to_numpy(X) if as_numpy is True else X
 
-    def get_Y(self, upsamling=False, as_numpy=True):
+    def get_Y(self, upsampling=False, as_numpy=True):
         no_data = (self._Y is None)
         self._load() if no_data is True else None
-        Y = self._Y if upsamling is False else self._Y_extreme
+        Y = self._Y if upsampling is False else self._Y_extreme
         self._reset_data() if no_data is True else None
         return self._to_numpy([Y]) if as_numpy is True else Y
 
     def harmonise_X(self):
         X_original, Y_original = self.get_X_original(), self.get_Y_original()
-        dim = self.interpolate_dim
+        dim = self.interpolation_dim
         intersect = reduce(np.intersect1d, map(lambda x: x.coords[dim].values, X_original))
         if len(intersect) < max(self.min_length, 1):
             X, Y = None, None
@@ -144,6 +199,12 @@ class DataPreparation:
             Y = Y_original.sel({dim: intersect})
         self._X, self._Y = X, Y
 
+    def get_observation(self):
+        return self.id_class.observation.copy().squeeze()
+
+    def get_transformation_Y(self):
+        return self.id_class.get_transformation_information()
+
     def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
                           timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"):
         """
@@ -212,52 +273,84 @@ class DataPreparation:
         for d in data:
             d.coords[dim].values += np.timedelta64(*timedelta)
 
+    @classmethod
+    def transformation(cls, set_stations, **kwargs):
+        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        transformation_dict = sp_keys.pop("transformation")
+        if transformation_dict is None:
+            return
+        scope = transformation_dict.pop("scope")
+        method = transformation_dict.pop("method")
+        if transformation_dict.pop("mean", None) is not None:
+            return
+        mean, std = None, None
+        for station in set_stations:
+            try:
+                sp = StationPrep(station, transformation={"method": method}, **sp_keys)
+                mean = sp.mean.copy(deep=True) if mean is None else mean.combine_first(sp.mean)
+                std = sp.std.copy(deep=True) if std is None else std.combine_first(sp.std)
+            except (AttributeError, EmptyQueryResult):
+                continue
+        if mean is None:
+            return None
+        mean_estimated = mean.mean("Stations")
+        std_estimated = std.mean("Stations")
+        return {"scope": scope, "method": method, "mean": mean_estimated, "std": std_estimated}
+
+    def get_coordinates(self):
+        return self.id_class.get_coordinates()
+
 
 def run_data_prep():
 
+    from .data_preparation_neighbors import DataPreparationNeighbors
     data = DummyDataSingleStation("main_class")
     data.get_X()
     data.get_Y()
 
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
-    data_prep = DataPreparation(DummyDataSingleStation("main_class"), "datetime", path,
-                                neighbors=[DummyDataSingleStation("neighbor1"), DummyDataSingleStation("neighbor2")],
-                                extreme_values=[1., 1.2])
+    data_prep = DataPreparationNeighbors(DummyDataSingleStation("main_class"),
+                                         path,
+                                         neighbors=[DummyDataSingleStation("neighbor1"),
+                                                    DummyDataSingleStation("neighbor2")],
+                                         extreme_values=[1., 1.2])
     data_prep.get_data(upsampling=False)
 
 
 def create_data_prep():
 
+    from .data_preparation_neighbors import DataPreparationNeighbors
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
     station_type = None
     network = 'UBA'
     sampling = 'daily'
     target_dim = 'variables'
     target_var = 'o3'
-    interpolate_dim = 'datetime'
+    interpolation_dim = 'datetime'
     window_history_size = 7
     window_lead_time = 3
-    central_station = StationPrep(path, "DEBW011", {'o3': 'dma8eu', 'temp': 'maximum'}, {},station_type, network, sampling, target_dim,
-                                  target_var, interpolate_dim, window_history_size, window_lead_time)
-    neighbor1 = StationPrep(path, "DEBW013", {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, {},station_type, network, sampling, target_dim,
-                                  target_var, interpolate_dim, window_history_size, window_lead_time)
-    neighbor2 = StationPrep(path, "DEBW034", {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, network, sampling, target_dim,
-                                  target_var, interpolate_dim, window_history_size, window_lead_time)
+    central_station = StationPrep("DEBW011", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {},station_type, network, sampling, target_dim,
+                                  target_var, interpolation_dim, window_history_size, window_lead_time)
+    neighbor1 = StationPrep("DEBW013", path, {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, {},station_type, network, sampling, target_dim,
+                                  target_var, interpolation_dim, window_history_size, window_lead_time)
+    neighbor2 = StationPrep("DEBW034", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, network, sampling, target_dim,
+                                  target_var, interpolation_dim, window_history_size, window_lead_time)
 
     data_prep = []
-    data_prep.append(DataPreparation(central_station, interpolate_dim, path, neighbors=[neighbor1, neighbor2]))
-    data_prep.append(DataPreparation(neighbor1, interpolate_dim, path, neighbors=[central_station, neighbor2]))
-    data_prep.append(DataPreparation(neighbor2, interpolate_dim, path, neighbors=[neighbor1, central_station]))
+    data_prep.append(DataPreparationNeighbors(central_station, path, neighbors=[neighbor1, neighbor2]))
+    data_prep.append(DataPreparationNeighbors(neighbor1, path, neighbors=[central_station, neighbor2]))
+    data_prep.append(DataPreparationNeighbors(neighbor2, path, neighbors=[neighbor1, central_station]))
     return data_prep
 
+
 if __name__ == "__main__":
-    from src.data_handling.data_preparation import StationPrep
-    from src.data_handling.iterator import KerasIterator, DataCollection
+    from mlair.data_handler.station_preparation import StationPrep
+    from mlair.data_handler.iterator import KerasIterator, DataCollection
     data_prep = create_data_prep()
     data_collection = DataCollection(data_prep)
     for data in data_collection:
         print(data)
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata", "keras")
-    keras_it = KerasIterator(data_collection, 100, path)
+    keras_it = KerasIterator(data_collection, 100, path, upsampling=True)
     keras_it[2]
 
diff --git a/mlair/data_handler/bootstraps.py b/mlair/data_handler/bootstraps.py
new file mode 100644
index 0000000000000000000000000000000000000000..91603b41822b92e28fbd077c502d84707fff746f
--- /dev/null
+++ b/mlair/data_handler/bootstraps.py
@@ -0,0 +1,130 @@
+"""
+Collections of bootstrap methods and classes.
+
+How to use
+----------
+
+test
+
+"""
+
+__author__ = 'Felix Kleinert, Lukas Leufen'
+__date__ = '2020-02-07'
+
+
+import os
+from collections import Iterator, Iterable
+from itertools import chain
+
+import numpy as np
+import xarray as xr
+
+from mlair.data_handler.advanced_data_handler import AbstractDataPreparation
+
+
+class BootstrapIterator(Iterator):
+
+    _position: int = None
+
+    def __init__(self, data: "BootStraps"):
+        assert isinstance(data, BootStraps)
+        self._data = data
+        self._dimension = data.bootstrap_dimension
+        self._collection = self._data.bootstraps()
+        self._position = 0
+
+    def __next__(self):
+        """Return next element or stop iteration."""
+        try:
+            index, dimension = self._collection[self._position]
+            nboot = self._data.number_of_bootstraps
+            _X, _Y = self._data.data.get_data(as_numpy=False)
+            _X = list(map(lambda x: x.expand_dims({'boots': range(nboot)}, axis=-1), _X))
+            _Y = _Y.expand_dims({"boots": range(nboot)}, axis=-1)
+            single_variable = _X[index].sel({self._dimension: [dimension]})
+            shuffled_variable = self.shuffle(single_variable.values)
+            shuffled_data = xr.DataArray(shuffled_variable, coords=single_variable.coords, dims=single_variable.dims)
+            _X[index] = shuffled_data.combine_first(_X[index]).reindex_like(_X[index])
+            self._position += 1
+        except IndexError:
+            raise StopIteration()
+        _X, _Y = self._to_numpy(_X), self._to_numpy(_Y)
+        return self._reshape(_X), self._reshape(_Y), (index, dimension)
+
+    @staticmethod
+    def _reshape(d):
+        if isinstance(d, list):
+            return list(map(lambda x: np.rollaxis(x, -1, 0).reshape(x.shape[0] * x.shape[-1], *x.shape[1:-1]), d))
+        else:
+            shape = d.shape
+            return np.rollaxis(d, -1, 0).reshape(shape[0] * shape[-1], *shape[1:-1])
+
+    @staticmethod
+    def _to_numpy(d):
+        if isinstance(d, list):
+            return list(map(lambda x: x.values, d))
+        else:
+            return d.values
+
+    @staticmethod
+    def shuffle(data: np.ndarray) -> np.ndarray:
+        """
+        Shuffle randomly from given data (draw elements with replacement).
+
+        :param data: data to shuffle
+        :return: shuffled data as numpy array
+        """
+        size = data.shape
+        return np.random.choice(data.reshape(-1, ), size=size)
+
+
+class BootStraps(Iterable):
+    """
+    Main class to perform bootstrap operations.
+
+    This class requires a data handler following the definition of the AbstractDataPreparation, the number of bootstraps
+    to create and the dimension along which the bootstrapping is performed (default dimension is `variables`).
+
+    When iterating over this class, it returns the bootstrapped X, Y and a tuple with (position of variable in X, name of
+    this variable). The tuple is interesting if X consists of multiple input streams X_i (e.g. two or more stations),
+    because it shows which variable of which input X_i has been bootstrapped. All bootstrap combinations can be
+    retrieved by calling the .bootstraps() method. Furthermore, by calling the .get_orig_prediction() method, this class
+    imitates the original prediction according to the set number of bootstraps.
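+
+    A minimal usage sketch (assuming `data` is an already initialised data handler instance)::
+
+        boots = BootStraps(data, number_of_bootstraps=10)
+        for X_boot, Y_boot, (index, variable) in boots:
+            pass  # X_boot is a list of arrays (one per input), Y_boot an array; boots are stacked on the first axis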
+    """
+    def __init__(self, data: AbstractDataPreparation, number_of_bootstraps: int = 10,
+                 bootstrap_dimension: str = "variables"):
+        """
+        Create iterable class to be ready to iter.
+
+        :param data: a data generator object to get data / history
+        :param number_of_bootstraps: the number of bootstrap realisations
+        """
+        self.data = data
+        self.number_of_bootstraps = number_of_bootstraps
+        self.bootstrap_dimension = bootstrap_dimension
+
+    def __iter__(self):
+        return BootstrapIterator(self)
+
+    def __len__(self):
+        return len(self.bootstraps())
+
+    def bootstraps(self):
+        l = []
+        for i, x in enumerate(self.data.get_X(as_numpy=False)):
+            l.append(list(map(lambda y: (i, y), x.indexes['variables'])))
+        return list(chain(*l))
+
+    def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN") -> np.ndarray:
+        """
+        Repeat predictions from given file(_name) in path by the number of boots.
+
+        :param path: path to file
+        :param file_name: file name
+        :param prediction_name: name of the prediction to select from loaded file (default CNN)
+        :return: repeated predictions
+        """
+        file = os.path.join(path, file_name)
+        prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
+        vals = np.tile(prediction.data, (self.number_of_bootstraps, 1))
+        return vals[~np.isnan(vals).any(axis=1), :]
diff --git a/mlair/data_handler/data_preparation_neighbors.py b/mlair/data_handler/data_preparation_neighbors.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c95b242e1046618403ebb6592407ef8b680e890
--- /dev/null
+++ b/mlair/data_handler/data_preparation_neighbors.py
@@ -0,0 +1,64 @@
+
+__author__ = 'Lukas Leufen'
+__date__ = '2020-07-17'
+
+
+from mlair.helpers import to_list
+from mlair.data_handler.station_preparation import StationPrep
+from mlair.data_handler.advanced_data_handler import DefaultDataPreparation
+import os
+
+from typing import Union, List
+
+number = Union[float, int]
+num_or_list = Union[number, List[number]]
+
+
+class DataPreparationNeighbors(DefaultDataPreparation):
+
+    def __init__(self, id_class, data_path, neighbors=None, min_length=0,
+                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False):
+        self.neighbors = to_list(neighbors) if neighbors is not None else []
+        super().__init__(id_class, data_path, min_length=min_length, extreme_values=extreme_values,
+                         extremes_on_right_tail_only=extremes_on_right_tail_only)
+
+    @classmethod
+    def build(cls, station, **kwargs):
+        sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs}
+        sp = StationPrep(station, **sp_keys)
+        n_list = []
+        for neighbor in kwargs.get("neighbors", []):
+            n_list.append(StationPrep(neighbor, **sp_keys))
+        else:
+            kwargs["neighbors"] = n_list if len(n_list) > 0 else None
+        dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs}
+        return cls(sp, **dp_args)
+
+    def _create_collection(self):
+        return [self.id_class] + self.neighbors
+
+    def get_coordinates(self, include_neighbors=False):
+        neighbors = list(map(lambda n: n.get_coordinates(), self.neighbors)) if include_neighbors is True else []
+        return [super(DataPreparationNeighbors, self).get_coordinates()] + neighbors
+
+
+if __name__ == "__main__":
+
+    a = DataPreparationNeighbors
+    requirements = a.requirements()
+
+    kwargs = {"path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata"),
+              "station_type": None,
+              "network": 'UBA',
+              "sampling": 'daily',
+              "target_dim": 'variables',
+              "target_var": 'o3',
+              "time_dim": 'datetime',
+              "window_history_size": 7,
+              "window_lead_time": 3,
+              "neighbors": ["DEBW034"],
+              "data_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata"),
+              "statistics_per_var":  {'o3': 'dma8eu', 'temp': 'maximum'},
+              "transformation": None,}
+    a_inst = a.build("DEBW011", **kwargs)
+    print(a_inst)
diff --git a/src/data_handling/iterator.py b/mlair/data_handler/iterator.py
similarity index 72%
rename from src/data_handling/iterator.py
rename to mlair/data_handler/iterator.py
index 14d71a9afc23d3a0d80bacf60bbaa928fb34407a..49569405a587920da795820d48f8d968a8142cc7 100644
--- a/src/data_handling/iterator.py
+++ b/mlair/data_handler/iterator.py
@@ -33,23 +33,51 @@ class StandardIterator(Iterator):
 
 class DataCollection(Iterable):
 
-    def __init__(self, collection: list):
+    def __init__(self, collection: list = None):
+        if collection is None:
+            collection = []
         assert isinstance(collection, list)
         self._collection = collection
+        self._mapping = {}
+        self._set_mapping()
+
+    def __len__(self):
+        return len(self._collection)
 
     def __iter__(self) -> Iterator:
         return StandardIterator(self._collection)
 
+    def __getitem__(self, index):
+        if isinstance(index, int):
+            return self._collection[index]
+        else:
+            return self._collection[self._mapping[str(index)]]
+
+    def add(self, element):
+        """Append element to the collection and register it under its string representation."""
+        self._collection.append(element)
+        self._mapping[str(element)] = len(self._collection) - 1
+
+    def _set_mapping(self):
+        for i, e in enumerate(self._collection):
+            self._mapping[str(e)] = i
+
+    def keys(self):
+        return list(self._mapping.keys())
+
 
 class KerasIterator(keras.utils.Sequence):
 
-    def __init__(self, collection: DataCollection, batch_size: int, path: str, shuffle: bool = False):
+    def __init__(self, collection: DataCollection, batch_size: int, batch_path: str, shuffle_batches: bool = False,
+                 model=None, upsampling=False, name=None):
         self._collection = collection
-        self._path = os.path.join(path, "%i.pickle")
+        batch_path = os.path.join(batch_path, str(name if name is not None else id(self)))
+        self._path = os.path.join(batch_path, "%i.pickle")
         self.batch_size = batch_size
-        self.shuffle = shuffle
+        self.model = model
+        self.shuffle = shuffle_batches
+        self.upsampling = upsampling
         self.indexes: list = []
-        self._cleanup_path(path)
+        self._cleanup_path(batch_path)
         self._prepare_batches()
 
     def __len__(self) -> int:
@@ -59,6 +87,19 @@ class KerasIterator(keras.utils.Sequence):
         """Get batch for given index."""
         return self.__data_generation(self.indexes[index])
 
+    def _get_model_rank(self):
+        if self.model is not None:
+            mod_out = self.model.output_shape
+            if isinstance(mod_out, tuple):  # only one output branch: (None, ahead)
+                mod_rank = 1
+            elif isinstance(mod_out, list):  # multiple output branches, e.g.: [(None, ahead), (None, ahead)]
+                mod_rank = len(mod_out)
+            else:  # pragma: no cover
+                raise TypeError("model output shape must either be tuple or list.")
+            return mod_rank
+        else:  # no model provided, assume to use single output
+            return 1
+
     def __data_generation(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
         """Load pickle data from disk."""
         file = self._path % index
@@ -75,6 +116,12 @@ class KerasIterator(keras.utils.Sequence):
         """Get batch according to batch size from data list."""
         return list(map(lambda data: data[b * self.batch_size:(b+1) * self.batch_size, ...], data_list))
 
+    def _permute_data(self, X, Y):
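+        """Permute inputs X and labels Y jointly along the sample (first) dimension."""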
+        p = np.random.permutation(len(X[0]))  # equiv to .shape[0]
+        X = list(map(lambda x: x[p], X))
+        Y = list(map(lambda x: x[p], Y))
+        return X, Y
+
     def _prepare_batches(self) -> None:
         """
         Prepare all batches as locally stored files.
@@ -86,8 +133,12 @@ class KerasIterator(keras.utils.Sequence):
         """
         index = 0
         remaining = None
+        mod_rank = self._get_model_rank()
         for data in self._collection:
-            X, Y = data.get_X(), data.get_Y()
+            X = data.get_X(upsampling=self.upsampling)
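+            # repeat the target once per model output branch, so multi-branch models receive identical labels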
+            Y = [data.get_Y(upsampling=self.upsampling)[0] for _ in range(mod_rank)]
+            if self.upsampling:
+                X, Y = self._permute_data(X, Y)
             if remaining is not None:
                 X, Y = self._concatenate(X, remaining[0]), self._concatenate(Y, remaining[1])
             length = X[0].shape[0]
diff --git a/mlair/data_handler/station_preparation.py b/mlair/data_handler/station_preparation.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff8496ab30a3b6392ea2314ef2526c80e0f57591
--- /dev/null
+++ b/mlair/data_handler/station_preparation.py
@@ -0,0 +1,701 @@
+"""Data Preparation class to handle data processing for machine learning."""
+
+__author__ = 'Lukas Leufen, Felix Kleinert'
+__date__ = '2020-07-20'
+
+import datetime as dt
+import logging
+import os
+from functools import reduce
+from typing import Union, List, Iterable, Tuple, Dict
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+from mlair.configuration import check_path_and_create
+from mlair import helpers
+from mlair.helpers import join, statistics
+
+# define a more general date type for type hinting
+date = Union[dt.date, dt.datetime]
+str_or_list = Union[str, List[str]]
+number = Union[float, int]
+num_or_list = Union[number, List[number]]
+data_or_none = Union[xr.DataArray, None]
+
+# defaults
+DEFAULT_STATION_TYPE = "background"
+DEFAULT_NETWORK = "AIRBASE"
+DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
+                        'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
+                        'pblheight': 'maximum'}
+DEFAULT_WINDOW_LEAD_TIME = 3
+DEFAULT_WINDOW_HISTORY_SIZE = 13
+DEFAULT_TIME_DIM = "datetime"
+DEFAULT_TARGET_VAR = "o3"
+DEFAULT_TARGET_DIM = "variables"
+DEFAULT_SAMPLING = "daily"
+DEFAULT_INTERPOLATION_METHOD = "linear"
+
+
+class AbstractStationPrep(object):
+    def __init__(self): #, path, station, statistics_per_var, transformation, **kwargs):
+        pass
+
+    def get_X(self):
+        raise NotImplementedError
+
+    def get_Y(self):
+        raise NotImplementedError
+
+
+class StationPrep(AbstractStationPrep):
+
+    def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE,
+                 network=DEFAULT_NETWORK, sampling=DEFAULT_SAMPLING, target_dim=DEFAULT_TARGET_DIM,
+                 target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM,
+                 window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_lead_time=DEFAULT_WINDOW_LEAD_TIME,
+                 interpolation_limit: int = 0, interpolation_method: str = DEFAULT_INTERPOLATION_METHOD,
+                 overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
+                 min_length: int = 0, start=None, end=None, **kwargs):
+        super().__init__()  # path, station, statistics_per_var, transformation, **kwargs)
+        self.station = helpers.to_list(station)
+        self.path = os.path.abspath(data_path)
+        self.statistics_per_var = statistics_per_var
+        self.transformation = self.setup_transformation(transformation)
+
+        self.station_type = station_type
+        self.network = network
+        self.sampling = sampling
+        self.target_dim = target_dim
+        self.target_var = target_var
+        self.time_dim = time_dim
+        self.window_history_size = window_history_size
+        self.window_lead_time = window_lead_time
+
+        self.interpolation_limit = interpolation_limit
+        self.interpolation_method = interpolation_method
+
+        self.overwrite_local_data = overwrite_local_data
+        self.store_data_locally = store_data_locally
+        self.min_length = min_length
+        self.start = start
+        self.end = end
+
+        # internal
+        self.data = None
+        self.meta = None
+        self.variables = kwargs.get('variables', list(statistics_per_var.keys()))
+        self.history = None
+        self.label = None
+        self.observation = None
+
+        # internal for transformation
+        self.mean = None
+        self.std = None
+        self.max = None
+        self.min = None
+        self._transform_method = None
+
+        self.kwargs = kwargs
+        # self.kwargs["overwrite_local_data"] = overwrite_local_data
+
+        # self.make_samples()
+        self.setup_samples()
+
+    def __str__(self):
+        return self.station[0]
+
+    def __len__(self):
+        assert len(self.get_X()) == len(self.get_Y())
+        return len(self.get_X())
+
+    @property
+    def shape(self):
+        return self.data.shape, self.get_X().shape, self.get_Y().shape
+
+    def __repr__(self):
+        return f"StationPrep(station={self.station}, data_path='{self.path}', " \
+               f"statistics_per_var={self.statistics_per_var}, " \
+               f"station_type='{self.station_type}', network='{self.network}', " \
+               f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \
+               f"time_dim='{self.time_dim}', window_history_size={self.window_history_size}, " \
+               f"window_lead_time={self.window_lead_time}, interpolation_limit={self.interpolation_limit}, " \
+               f"interpolation_method='{self.interpolation_method}', overwrite_local_data={self.overwrite_local_data}, " \
+               f"transformation={self._print_transformation_as_string}, **{self.kwargs})"
+
+    @property
+    def _print_transformation_as_string(self):
+        str_name = ''
+        if self.transformation is None:
+            str_name = f'{None}'
+        else:
+            for k, v in self.transformation.items():
+                if v is not None:
+                    try:
+                        v_pr = f"xr.DataArray.from_dict({v.to_dict()})"
+                    except AttributeError:
+                        v_pr = f"'{v}'"
+                    str_name += f"'{k}':{v_pr}, "
+            str_name = f"{{{str_name}}}"
+        return str_name
+
+    def get_transposed_history(self) -> xr.DataArray:
+        """Return history.
+
+        :return: history with dimensions datetime, window, Stations, variables.
+        """
+        return self.history.transpose("datetime", "window", "Stations", "variables").copy()
+
+    def get_transposed_label(self) -> xr.DataArray:
+        """Return label.
+
+        :return: label with dimensions datetime, window (the Stations dimension is squeezed out).
+        """
+        return self.label.squeeze("Stations").transpose("datetime", "window").copy()
+
+    def get_X(self):
+        return self.get_transposed_history()
+
+    def get_Y(self):
+        return self.get_transposed_label()
+
+    def get_coordinates(self):
+        coords = self.meta.loc[["station_lon", "station_lat"]].astype(float)
+        return coords.rename(index={"station_lon": "lon", "station_lat": "lat"}).to_dict()[str(self)]
+
+    def call_transform(self, inverse=False):
+        self.transform(dim=self.time_dim, method=self.transformation["method"],
+                       mean=self.transformation['mean'], std=self.transformation["std"],
+                       min_val=self.transformation["min"], max_val=self.transformation["max"],
+                       inverse=inverse
+                       )
+
+    def set_transformation(self, transformation: dict):
+        if self._transform_method is not None:
+            self.call_transform(inverse=True)
+        self.transformation = self.setup_transformation(transformation)
+        self.call_transform()
+        self.make_samples()
+
+    def setup_samples(self):
+        """
+        Set up samples. This method prepares and creates the samples X and labels Y.
+        """
+        self.load_data()
+        self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit)
+        if self.transformation is not None:
+            self.call_transform()
+        self.make_samples()
+
+    def make_samples(self):
+        self.make_history_window(self.target_dim, self.window_history_size, self.time_dim)
+        self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time)
+        self.make_observation(self.target_dim, self.target_var, self.time_dim)
+        self.remove_nan(self.time_dim)
+
+    def read_data_from_disk(self, source_name=""):
+        """
+        Load data and meta data either from local disk (preferred) or download new data using a custom download method.
+
+        Data is downloaded if no local data is available or if the parameter overwrite_local_data is true. In both
+        cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
+        set, it is assumed that data should be saved locally.
+        """
+        source_name = source_name if len(source_name) == 0 else f" from {source_name}"
+        check_path_and_create(self.path)
+        file_name = self._set_file_name()
+        meta_file = self._set_meta_file_name()
+        if self.overwrite_local_data is True:
+            logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}")
+            if os.path.exists(file_name):
+                os.remove(file_name)
+            if os.path.exists(meta_file):
+                os.remove(meta_file)
+            data, self.meta = self.download_data(file_name, meta_file)
+            logging.debug(f"loaded new data{source_name}")
+        else:
+            try:
+                logging.debug(f"try to load local data from: {file_name}")
+                data = xr.open_dataarray(file_name)
+                self.meta = pd.read_csv(meta_file, index_col=0)
+                self.check_station_meta()
+                logging.debug("loading finished")
+            except FileNotFoundError as e:
+                logging.debug(e)
+                logging.debug(f"load new data{source_name}")
+                data, self.meta = self.download_data(file_name, meta_file)
+                logging.debug("loading finished")
+        # create slices and check for negative concentration.
+        data = self._slice_prep(data)
+        self.data = self.check_for_negative_concentrations(data)
+
+    def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]:
+        """
+        Download data from TOAR database using the JOIN interface.
+
+        Data is transformed to a xarray dataset. If class attribute store_data_locally is true, data is additionally
+        stored locally using given names for file and meta file.
+
+        :param file_name: name of file to save data to (containing full path)
+        :param meta_file: name of the meta data file (also containing full path)
+
+        :return: downloaded data and its meta data
+        """
+        df_all = {}
+        df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var,
+                                      station_type=self.station_type, network_name=self.network, sampling=self.sampling)
+        df_all[self.station[0]] = df
+        # convert df_all to xarray
+        xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
+        xarr = xr.Dataset(xarr).to_array(dim='Stations')
+        if self.store_data_locally is True:
+            # save locally as nc/csv file
+            xarr.to_netcdf(path=file_name)
+            meta.to_csv(meta_file)
+        return xarr, meta
+
+    def download_data(self, file_name, meta_file):
+        data, meta = self.download_data_from_join(file_name, meta_file)
+        return data, meta
+
+    def check_station_meta(self):
+        """
+        Search for the entries in the meta data and compare them with the requested values.
+
+        Raises a FileNotFoundError if the values do not match.
+        """
+        if self.station_type is not None:
+            check_dict = {"station_type": self.station_type, "network_name": self.network}
+            for (k, v) in check_dict.items():
+                if v is None:
+                    continue
+                if self.meta.at[k, self.station[0]] != v:
+                    logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
+                                  f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
+                                  f"grapping from web.")
+                    raise FileNotFoundError
+
+    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
+        """
+        Set all negative concentrations to zero.
+
+        Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
+        #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
+        "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".
+
+        :param data: data array containing variables to check
+        :param minimum: minimum value, by default this should be 0
+
+        :return: corrected data
+        """
+        chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
+                     "propane", "so2", "toluene"]
+        used_chem_vars = list(set(chem_vars) & set(self.variables))
+        data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
+        return data
+
+    def shift(self, dim: str, window: int) -> xr.DataArray:
+        """
+        Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).
+
+        :param dim: dimension along which the shift is applied
+        :param window: number of steps to shift (corresponds to the window length)
+
+        :return: shifted data
+        """
+        start = 1
+        end = 1
+        if window <= 0:
+            start = window
+        else:
+            end = window + 1
+        res = []
+        for w in range(start, end):
+            res.append(self.data.shift({dim: -w}))
+        window_array = self.create_index_array('window', range(start, end), squeeze_dim=self.target_dim)
+        res = xr.concat(res, dim=window_array)
+        return res
+
+    @staticmethod
+    def create_index_array(index_name: str, index_value: Iterable[int], squeeze_dim: str) -> xr.DataArray:
+        """
+        Create a 1D xr.DataArray with the given index name and values.
+
+        :param index_name: name of the new dimension
+        :param index_value: values of this dimension
+        :param squeeze_dim: name of the dimension to squeeze out of the resulting array
+
+        :return: the created index array
+        """
+        ind = pd.DataFrame({'val': index_value}, index=index_value)
+        res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze(
+            dim=squeeze_dim,
+            drop=True
+        )
+        res.name = index_name
+        return res
+
+    def _set_file_name(self):
+        all_vars = sorted(self.statistics_per_var.keys())
+        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}.nc")
+
+    def _set_meta_file_name(self):
+        all_vars = sorted(self.statistics_per_var.keys())
+        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv")
+
+    def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
+                    **kwargs):
+        """
+        Interpolate values according to different methods.
+
+        (Description adapted from xarray.DataArray.interpolate_na.)
+
+        :param dim:
+                Specifies the dimension along which to interpolate.
+        :param method:
+                {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
+                          'polynomial', 'barycentric', 'krog', 'pchip',
+                          'spline', 'akima'}, optional
+                    String indicating which method to use for interpolation:
+
+                    - 'linear': linear interpolation (Default). Additional keyword
+                      arguments are passed to ``numpy.interp``
+                    - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
+                      'polynomial': are passed to ``scipy.interpolate.interp1d``. If
+                      method=='polynomial', the ``order`` keyword argument must also be
+                      provided.
+                    - 'barycentric', 'krog', 'pchip', 'spline', and `akima`: use their
+                      respective``scipy.interpolate`` classes.
+        :param limit:
+                    default None
+                    Maximum number of consecutive NaNs to fill. Must be greater than 0
+                    or None for no limit.
+        :param use_coordinate:
+                default True
+                    Specifies which index to use as the x values in the interpolation
+                    formulated as `y = f(x)`. If False, values are treated as if
+                    equally-spaced along `dim`. If True, the IndexVariable `dim` is
+                    used. If use_coordinate is a string, it specifies the name of a
+                    coordinate variable to use as the index.
+        :param kwargs:
+
+        :return: xarray.DataArray
+        """
+        self.data = self.data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate,
+                                             **kwargs)
+
+    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
+        """
+        Create a xr.DataArray containing history data.
+
+        Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted
+        data. This is used to represent history in the data. Results are stored in history attribute.
+
+        :param dim_name_of_inputs: Name of dimension which contains the input variables
+        :param window: number of time steps to look back in history
+                Note: window is treated as a negative value to reflect looking back along the time line. Positive
+                values are allowed, but they are converted to their negative counterpart.
+        :param dim_name_of_shift: Dimension along which the shift will be applied
+        """
+        window = -abs(window)
+        self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables})
+
+    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
+                    window: int) -> None:
+        """
+        Create a xr.DataArray containing labels.
+
+        Labels are defined as the consecutive target values (t+1, ...t+n) following the current time step t. Set label
+        attribute.
+
+        :param dim_name_of_target: Name of dimension which contains the target variable
+        :param target_var: Name of target variable in 'dimension'
+        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
+        :param window: lead time of label
+        """
+        window = abs(window)
+        self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var})
+
+    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
+        """
+        Create a xr.DataArray containing observations.
+
+        Observations are defined as value of the current time step t. Set observation attribute.
+
+        :param dim_name_of_target: Name of dimension which contains the observation variable
+        :param target_var: Name of observation variable(s) in 'dimension'
+        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
+        """
+        self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var})
+
+    def remove_nan(self, dim: str) -> None:
+        """
+        Remove all slices along dim that contain NaNs in history, label, or observation.
+
+        This is done to present only a full matrix to keras.fit. Updates the history, label, and observation attributes.
+
+        :param dim: dimension along which the removal is performed.
+        """
+        intersect = []
+        if (self.history is not None) and (self.label is not None):
+            non_nan_history = self.history.dropna(dim=dim)
+            non_nan_label = self.label.dropna(dim=dim)
+            non_nan_observation = self.observation.dropna(dim=dim)
+            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
+                                                non_nan_observation.coords[dim].values))
+
+        if len(intersect) < max(self.min_length, 1):
+            self.history = None
+            self.label = None
+            self.observation = None
+        else:
+            self.history = self.history.sel({dim: intersect})
+            self.label = self.label.sel({dim: intersect})
+            self.observation = self.observation.sel({dim: intersect})
+
+    def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray:
+        """
+        Set start and end date for slicing and execute self._slice().
+
+        :param data: data to slice
+        :param coord: name of axis to slice
+
+        :return: sliced data
+        """
+        start = self.start if self.start is not None else data.coords[coord][0].values
+        end = self.end if self.end is not None else data.coords[coord][-1].values
+        return self._slice(data, start, end, coord)
+
+    @staticmethod
+    def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
+        """
+        Slice through a given data_item (for example select only values of 2011).
+
+        :param data: data to slice
+        :param start: start date of slice
+        :param end: end date of slice
+        :param coord: name of axis to slice
+
+        :return: sliced data
+        """
+        return data.loc[{coord: slice(str(start), str(end))}]
+
+    @staticmethod
+    def setup_transformation(transformation: Dict):
+        """
+        Set up transformation by extracting all relevant information.
+
+        Extract all information from transformation dictionary. Possible keys are method, mean, std, min, max.
+        * If a transformation should be applied on the basis of existing values, these need to be provided in the
+          respective keys "mean" and "std" (again only if required for the given method).
+
+        :param transformation: the transformation dictionary as described above.
+
+        :return: updated transformation dictionary
+        """
+        if transformation is None:
+            return
+        elif not isinstance(transformation, dict):
+            raise TypeError(f"`transformation' must be either `None' or dict like e.g. `{{'method': 'standardise'}},"
+                            f" but transformation is of type {type(transformation)}.")
+        transformation = transformation.copy()
+        method = transformation.get("method", None)
+        mean = transformation.get("mean", None)
+        std = transformation.get("std", None)
+        max_val = transformation.get("max", None)
+        min_val = transformation.get("min", None)
+
+        transformation["method"] = method
+        transformation["mean"] = mean
+        transformation["std"] = std
+        transformation["max"] = max_val
+        transformation["min"] = min_val
+        return transformation
+
+    def load_data(self):
+        try:
+            self.read_data_from_disk()
+        except FileNotFoundError:
+            # download_data requires the target file names; afterwards retry loading from disk
+            self.download_data(self._set_file_name(), self._set_meta_file_name())
+            self.load_data()
+
+    def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean=None,
+                  std=None, min_val=None, max_val=None) -> None:
+        """
+        Transform data according to given transformation settings.
+
+        This function transforms an xarray.DataArray (along dim) or pandas.DataFrame (along axis) either with mean=0
+        and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale
+        (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation. This
+        method will raise an AssertionError if an internal transform method was already set ('inverse=False') or if the
+        internal transform method, internal mean and internal standard deviation weren't set ('inverse=True').
+
+        :param string/int dim: This param is not used for inverse transformation.
+                | for xarray.DataArray as string: name of dimension which should be standardised
+                | for pandas.DataFrame as int: axis of dimension which should be standardised
+        :param method: Choose the transformation method from 'standardise' and 'centre'. 'normalise' is not implemented
+                    yet. This param is not used for inverse transformation.
+        :param inverse: Switch between transformation and inverse transformation.
+        :param mean: Used for transformation (if required by 'method') based on external data. If 'None' the mean is
+                    calculated over the data in this class instance.
+        :param std: Used for transformation (if required by 'method') based on external data. If 'None' the std is
+                    calculated over the data in this class instance.
+        :param min_val: Used for transformation (if required by 'method') based on external data. If 'None' min_val is
+                    extracted from the data in this class instance.
+        :param max_val: Used for transformation (if required by 'method') based on external data. If 'None' max_val is
+                    extracted from the data in this class instance.
+
+        :return: xarray.DataArrays or pandas.DataFrames:
+                #. mean: Mean of data
+                #. std: Standard deviation of data
+                #. data: Standardised data
+        """
+
+        def f(data):
+            if method == 'standardise':
+                return statistics.standardise(data, dim)
+            elif method == 'centre':
+                return statistics.centre(data, dim)
+            elif method == 'normalise':
+                # use min/max of data or given min/max
+                raise NotImplementedError
+            else:
+                raise NotImplementedError
+
+        def f_apply(data):
+            if method == "standardise":
+                return mean, std, statistics.standardise_apply(data, mean, std)
+            elif method == "centre":
+                return mean, None, statistics.centre_apply(data, mean)
+            else:
+                raise NotImplementedError
+
+        if not inverse:
+            if self._transform_method is not None:
+                raise AssertionError(f"Transform method is already set. Therefore, data was already transformed with "
+                                     f"{self._transform_method}. Please perform inverse transformation of data first.")
+            # apply transformation on local data instance (f) if mean is None, else apply by using mean (and std) from
+            # external data.
+            self.mean, self.std, self.data = locals()["f" if mean is None else "f_apply"](self.data)
+
+            # set transform method to find correct method for inverse transformation.
+            self._transform_method = method
+        else:
+            self.inverse_transform()
+
+    @staticmethod
+    def check_inverse_transform_params(mean: data_or_none, std: data_or_none, method: str) -> None:
+        """
+        Support inverse_transformation method.
+
+        Validate if all required statistics are available for the given method. E.g. centring requires the mean only,
+        whereas standardisation requires both mean and standard deviation. Raises an AttributeError on missing
+        requirements.
+
+        :param mean: data with all mean values
+        :param std: data with all standard deviation values
+        :param method: name of transformation method
+        """
+        msg = ""
+        if method in ['standardise', 'centre'] and mean is None:
+            msg += "mean, "
+        if method == 'standardise' and std is None:
+            msg += "std, "
+        if len(msg) > 0:
+            raise AttributeError(f"Inverse transform {method} can not be executed because following is None: {msg}")
+
+    def inverse_transform(self) -> None:
+        """
+        Perform inverse transformation.
+
+        Will raise an AssertionError, if no transformation was performed before. Checks first, if all required
+        statistics are available for inverse transformation. Class attributes data, mean and std are overwritten by
+        new data afterwards. Thereby, mean, std, and the private transform method are set to None to indicate that the
+        current data is not transformed.
+        """
+
+        def f_inverse(data, mean, std, method_inverse):
+            if method_inverse == 'standardise':
+                return statistics.standardise_inverse(data, mean, std), None, None
+            elif method_inverse == 'centre':
+                return statistics.centre_inverse(data, mean), None, None
+            elif method_inverse == 'normalise':
+                raise NotImplementedError
+            else:
+                raise NotImplementedError
+
+        if self._transform_method is None:
+            raise AssertionError("Inverse transformation method is not set. Data cannot be inverse transformed.")
+        self.check_inverse_transform_params(self.mean, self.std, self._transform_method)
+        self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method)
+        self._transform_method = None
+        # update X and Y
+        self.make_samples()
+
+    def get_transformation_information(self, variable: str = None) -> Tuple[data_or_none, data_or_none, str]:
+        """
+        Extract transformation statistics and method.
+
+        Get mean and standard deviation for given variable and the transformation method if set. If a transformation
+        depends only on particular statistics (e.g. only mean is required for centering), the remaining statistics are
+        returned with None as fill value.
+
+        :param variable: Variable for which the information on transformation is requested.
+
+        :return: mean, standard deviation and transformation method
+        """
+        variable = self.target_var if variable is None else variable
+        try:
+            mean = self.mean.sel({'variables': variable}).values
+        except AttributeError:
+            mean = None
+        try:
+            std = self.std.sel({'variables': variable}).values
+        except AttributeError:
+            std = None
+        return mean, std, self._transform_method
+
+
+if __name__ == "__main__":
+    # dp = AbstractDataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+    # print(dp)
+    statistics_per_var = {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}
+    sp = StationPrep(data_path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122',
+                     statistics_per_var=statistics_per_var, station_type='background',
+                     network='UBA', sampling='daily', target_dim='variables', target_var='o3',
+                     time_dim='datetime', window_history_size=7, window_lead_time=3,
+                     interpolation_limit=0
+                     )  # transformation={'method': 'standardise'})
+    # sp.set_transformation({'method': 'standardise', 'mean': sp.mean+2, 'std': sp.std+1})
+    sp2 = StationPrep(data_path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122',
+                      statistics_per_var=statistics_per_var, station_type='background',
+                      network='UBA', sampling='daily', target_dim='variables', target_var='o3',
+                      time_dim='datetime', window_history_size=7, window_lead_time=3,
+                      transformation={'method': 'standardise'})
+    sp2.transform(inverse=True)
+    sp.get_X()
+    sp.get_Y()
+    print(len(sp))
+    print(sp.shape)
+    print(sp)
diff --git a/src/helpers/__init__.py b/mlair/helpers/__init__.py
similarity index 92%
rename from src/helpers/__init__.py
rename to mlair/helpers/__init__.py
index 546713b3f18f2cb64c1527b57d1e9e2138e927aa..9e2f612c86dc0477693567210493fbdcf3002954 100644
--- a/src/helpers/__init__.py
+++ b/mlair/helpers/__init__.py
@@ -3,4 +3,4 @@
 from .testing import PyTestRegex, PyTestAllEqual
 from .time_tracking import TimeTracking, TimeTrackingWrapper
 from .logger import Logger
-from .helpers import remove_items, float_round, dict_to_xarray, to_list
+from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value
diff --git a/src/helpers/datastore.py b/mlair/helpers/datastore.py
similarity index 100%
rename from src/helpers/datastore.py
rename to mlair/helpers/datastore.py
diff --git a/src/helpers/helpers.py b/mlair/helpers/helpers.py
similarity index 94%
rename from src/helpers/helpers.py
rename to mlair/helpers/helpers.py
index 968ee5385f5a44cdbbce5653a864875011874150..b12d9028747aa677802c4a99e35852b514128e4c 100644
--- a/src/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -92,3 +92,10 @@ def remove_items(obj: Union[List, Dict], items: Any):
         return remove_from_dict(obj, items)
     else:
         raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.")
+
+
+def extract_value(encapsulated_value):
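+    """Return the innermost value of an arbitrarily nested sequence, e.g. [[5]] -> 5."""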
+    try:
+        return extract_value(encapsulated_value[0])
+    except TypeError:
+        return encapsulated_value
diff --git a/src/helpers/join.py b/mlair/helpers/join.py
similarity index 98%
rename from src/helpers/join.py
rename to mlair/helpers/join.py
index 7d9c3aad23c402ae63f26bdf998074a86e35ffbf..a3c6876e3ea43ff4d03243430cf6cd791d62dec2 100644
--- a/src/helpers/join.py
+++ b/mlair/helpers/join.py
@@ -9,8 +9,8 @@ from typing import Iterator, Union, List, Dict
 import pandas as pd
 import requests
 
-from src import helpers
-from src.configuration.join_settings import join_settings
+from mlair import helpers
+from mlair.configuration.join_settings import join_settings
 
 # join_url_base = 'https://join.fz-juelich.de/services/rest/surfacedata/'
 str_or_none = Union[str, None]
diff --git a/src/helpers/logger.py b/mlair/helpers/logger.py
similarity index 100%
rename from src/helpers/logger.py
rename to mlair/helpers/logger.py
diff --git a/src/helpers/statistics.py b/mlair/helpers/statistics.py
similarity index 100%
rename from src/helpers/statistics.py
rename to mlair/helpers/statistics.py
diff --git a/src/helpers/testing.py b/mlair/helpers/testing.py
similarity index 100%
rename from src/helpers/testing.py
rename to mlair/helpers/testing.py
diff --git a/src/helpers/time_tracking.py b/mlair/helpers/time_tracking.py
similarity index 100%
rename from src/helpers/time_tracking.py
rename to mlair/helpers/time_tracking.py
diff --git a/mlair/model_modules/GUIDE.md b/mlair/model_modules/GUIDE.md
new file mode 100644
index 0000000000000000000000000000000000000000..3cda63538b06a83afe9c0c20d9c6ef46d00633fe
--- /dev/null
+++ b/mlair/model_modules/GUIDE.md
@@ -0,0 +1,49 @@
+
+## Model Extensions
+
+### Inception Blocks
+
+MLAir provides an easy interface to add extensions. Specifically, the code comes with an extension for inception blocks 
+as proposed by Szegedy et al. (2014). Those inception blocks are a collection of multiple network towers. A tower is a 
+collection of successive (standard) layers and generally contains at least a padding layer and one convolution or 
+pooling layer. Such towers can also contain an additional convolutional layer of kernel size 1x1 for information 
+compression (reduction of the filter dimension), or batch normalisation layers. 
+
+After initialising the inception blocks by using *InceptionModelBase*, one can add an arbitrary number of 
+individual inception blocks. The initialisation sets all counters used for the internal naming conventions.
+
+The inception model requires two dictionaries as inputs specifying the convolutional and the pooling towers, 
+respectively. The convolutional dictionary contains one sub-dictionary per tower, allowing different reduction 
+filters, kernel and filter sizes of the main convolution, and activation functions per tower. 
+
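+The following sketch outlines this usage. It is a minimal, illustrative example: the dictionary keys and the
+`inception_block` call follow the model examples shipped with MLAir, but the exact names and the input shape are
+assumptions and should be checked against the installed version.
+
+```python
+import keras
+from mlair.model_modules.inception_model import InceptionModelBase
+
+inception = InceptionModelBase()
+
+# one sub-dictionary per convolutional tower: 1x1 reduction filter, main kernel/filter size, activation
+conv_settings = {
+    'tower_1': {'reduction_filter': 8, 'tower_filter': 16, 'tower_kernel': (3, 1), 'activation': keras.layers.PReLU},
+    'tower_2': {'reduction_filter': 8, 'tower_filter': 16, 'tower_kernel': (5, 1), 'activation': keras.layers.PReLU},
+}
+# settings of the pooling tower
+pool_settings = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': keras.layers.PReLU}
+
+x_input = keras.layers.Input(shape=(14, 1, 9))  # (window history + 1, stations, variables); shape is an assumption
+x = inception.inception_block(x_input, conv_settings, pool_settings)
+```
+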
+See a description [here](https://towardsdatascience.com/a-simple-guide-to-the-versions-of-the-inception-network-7fc52b863202)
+or take a look at the papers [Going Deeper with Convolutions (Szegedy et al., 2014)](https://arxiv.org/abs/1409.4842)
+and [Network In Network (Lin et al., 2014)](https://arxiv.org/abs/1312.4400).
+
+
+### Paddings
+
+For some network layers like convolutions, it is common to pad the input data to prevent shrinking of dimensions. In 
+classical image recognition tasks, zero padding is used most often. In the context of meteorology, a zero padding might 
+create artificial effects on the boundaries. We therefore adopted the symmetric and reflection padding layers from 
+*TensorFlow* to be used as *Keras* layers. The layers are named *SymmetricPadding2D* and *ReflectionPadding2D*. Both 
+layers need the *padding* size as input. We provide a helper function to calculate the padding size given a 
+convolutional kernel size. 
+
+![pad1](./../../docs/_source/_plots/padding_example1.png)
+
+Additionally, we provide the wrapper class *Padding2D*, which combines symmetric, reflection and zero padding. This 
+class allows switching between different types of padding while keeping the overall model structure untouched. 
+
+![pad2](./../../docs/_source/_plots/padding_example2.png)
+
+This figure shows an example of how to easily apply the wrapper Padding2D and specify the *padding_type* (e.g. 
+"SymmetricPadding2D" or "ReflectionPadding2D"); a short code sketch follows the table below. The following table lists 
+all padding types that are currently supported. The padding wrapper can also handle other user-specific padding types.
+
+| padding layer (long name) | short name |
+|---------------------------|------------|
+| ReflectionPadding2D*      | RefPad2D   |
+| SymmetricPadding2D*       | SymPad2D   |
+| ZeroPadding2D**           | ZeroPad2D  |
+\*  implemented in MLAir    \** implemented in keras
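+
+The sketch below shows how such a padding might be wired into a model. It is illustrative only; the helper
+`PadUtils.get_padding_for_same` and the call pattern of *Padding2D* are taken from the MLAir model examples and should
+be verified against the installed version.
+
+```python
+import keras
+from mlair.model_modules.advanced_paddings import PadUtils, Padding2D
+
+kernel_size = (5, 1)
+pad_size = PadUtils.get_padding_for_same(kernel_size)  # padding that preserves the spatial dimensions
+
+x_input = keras.layers.Input(shape=(14, 1, 9))  # shape is an assumption for illustration
+x = Padding2D("SymmetricPadding2D")(padding=pad_size, name="SymPad")(x_input)
+x = keras.layers.Conv2D(16, kernel_size, activation="relu")(x)
+```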
diff --git a/src/model_modules/__init__.py b/mlair/model_modules/__init__.py
similarity index 57%
rename from src/model_modules/__init__.py
rename to mlair/model_modules/__init__.py
index 35f4060886036d3f51c24b4480738566ff80a445..ea2067bdfdaacb6290157be681786212b0422812 100644
--- a/src/model_modules/__init__.py
+++ b/mlair/model_modules/__init__.py
@@ -1 +1,3 @@
 """Collection of all modules that are related to a model."""
+
+from .model_class import AbstractModelClass
diff --git a/src/model_modules/advanced_paddings.py b/mlair/model_modules/advanced_paddings.py
similarity index 100%
rename from src/model_modules/advanced_paddings.py
rename to mlair/model_modules/advanced_paddings.py
diff --git a/src/model_modules/flatten.py b/mlair/model_modules/flatten.py
similarity index 100%
rename from src/model_modules/flatten.py
rename to mlair/model_modules/flatten.py
diff --git a/src/model_modules/inception_model.py b/mlair/model_modules/inception_model.py
similarity index 99%
rename from src/model_modules/inception_model.py
rename to mlair/model_modules/inception_model.py
index 74cd4d806f706a70d554adae468e7fa8c5de153e..d7354c37899bbb7d8f80bc76b4cd9237c7df96dc 100644
--- a/src/model_modules/inception_model.py
+++ b/mlair/model_modules/inception_model.py
@@ -6,7 +6,7 @@ import logging
 import keras
 import keras.layers as layers
 
-from src.model_modules.advanced_paddings import PadUtils, ReflectionPadding2D, Padding2D
+from mlair.model_modules.advanced_paddings import PadUtils, ReflectionPadding2D, Padding2D
 
 
 class InceptionModelBase:
diff --git a/src/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
similarity index 99%
rename from src/model_modules/keras_extensions.py
rename to mlair/model_modules/keras_extensions.py
index 479913811a668d8330a389b2876360f096f57dbf..33358e566ef80f28ee7740531b71d1a83abde115 100644
--- a/src/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -13,7 +13,7 @@ import numpy as np
 from keras import backend as K
 from keras.callbacks import History, ModelCheckpoint, Callback
 
-from src import helpers
+from mlair import helpers
 
 
 class HistoryAdvanced(History):
diff --git a/src/model_modules/linear_model.py b/mlair/model_modules/linear_model.py
similarity index 73%
rename from src/model_modules/linear_model.py
rename to mlair/model_modules/linear_model.py
index e556f0358a2a5e5247f7b6cc7d416af25a8a664d..341c787e3060fd7e7cc3ff468ba40add9b9936d2 100644
--- a/src/model_modules/linear_model.py
+++ b/mlair/model_modules/linear_model.py
@@ -42,21 +42,27 @@ class OrdinaryLeastSquaredModel:
         return self.ordinary_least_squared_model(self.x, self.y)
 
     def _set_x_y_from_generator(self):
-        data_x = None
-        data_y = None
+        data_x, data_y = None, None
         for item in self.generator:
-            x = self.reshape_xarray_to_numpy(item[0])
-            y = item[1].values
-            data_x = np.concatenate((data_x, x), axis=0) if data_x is not None else x
-            data_y = np.concatenate((data_y, y), axis=0) if data_y is not None else y
-        self.x = data_x
-        self.y = data_y
+            x, y = item.get_data(as_numpy=True)
+            x = self.flatten(x)
+            data_x = self._concatenate(x, data_x)
+            data_y = self._concatenate(y, data_y)
+        self.x, self.y = np.concatenate(data_x, axis=1), data_y[0]
+
+    def _concatenate(self, new, old):
+        return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new)) if old is not None else new
 
     def predict(self, data):
         """Apply OLS model on data."""
-        data = sm.add_constant(self.reshape_xarray_to_numpy(data), has_constant="add")
+        data = sm.add_constant(np.concatenate(self.flatten(data), axis=1), has_constant="add")
         return np.atleast_2d(self.model.predict(data))
 
+    @staticmethod
+    def flatten(data):
+        shapes = list(map(lambda x: x.shape, data))
+        return list(map(lambda x, shape: x.reshape(shape[0], -1), data, shapes))
+
     @staticmethod
     def reshape_xarray_to_numpy(data):
         """Reshape xarray data to numpy data and flatten."""
diff --git a/src/model_modules/loss.py b/mlair/model_modules/loss.py
similarity index 100%
rename from src/model_modules/loss.py
rename to mlair/model_modules/loss.py
diff --git a/src/model_modules/model_class.py b/mlair/model_modules/model_class.py
similarity index 81%
rename from src/model_modules/model_class.py
rename to mlair/model_modules/model_class.py
index dab2e168c5a9f87d4aee42fc94489fd0fa67772a..56e7b4c347a69781854a9cf8ad9a719f7d6ac8b9 100644
--- a/src/model_modules/model_class.py
+++ b/mlair/model_modules/model_class.py
@@ -2,7 +2,7 @@
 Module for neural models to use during experiment.
 
 To work properly, each customised model needs to inherit from AbstractModelClass and needs an implementation of the
-set_model and set_loss method.
+set_model method.
 
 In this module, you can find some exemplary model classes that have been built and were running in an experiment.
 
@@ -33,10 +33,11 @@ How to create a customised model?
 
                 # apply to model
                 self.set_model()
-                self.set_loss()
-                self.set_custom_objects(loss=self.loss)
+                self.set_compile_options()
+                self.set_custom_objects(loss=self.compile_options['loss'])
 
-* Make sure to add the `super().__init__()` and at least `set_model()` and  `set_loss()` to your custom init method.
+* Make sure to add the `super().__init__()` and at least `set_model()` and `set_compile_options()` to your custom init
+  method.
 * If you have custom objects in your model, that are not part of keras, you need to add them to custom objects. To do
   this, call `set_custom_objects` with arbitrary kwargs. In the shown example, the loss has been added because it
   wasn't a standard loss. Apart from this, we always encourage you to add the loss as custom object, to prevent
@@ -60,14 +61,20 @@ How to create a customised model?
                 self.model = keras.Model(inputs=x_input, outputs=[out_main])
 
 * You are free in how you design your model. Just make sure to save it in the class attribute model.
-* Finally, set your custom loss.
+* Additionally, set your custom compile options including the loss.
 
     .. code-block:: python
 
         class MyCustomisedModel(AbstractModelClass):
 
-            def set_loss(self):
+            def set_compile_options(self):
+                self.initial_lr = 1e-2
+                self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
+                self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr,
+                                                                                       drop=.94,
+                                                                                       epochs_drop=10)
                 self.loss = keras.losses.mean_squared_error
+                self.compile_options = {"metrics": ["mse", "mae"]}
 
 * If you have a branched model with multiple outputs, you need either set only a single loss for all branch outputs or
   to provide the same number of loss functions considering the right order. E.g.
@@ -80,7 +87,7 @@ How to create a customised model?
                 ...
                 self.model = keras.Model(inputs=x_input, outputs=[out_minor_1, out_minor_2, out_main])
 
-            def set_loss(self):
+            def set_compile_options(self):
                 self.loss = [keras.losses.mean_absolute_error] +  # for out_minor_1
                             [keras.losses.mean_squared_error] +   # for out_minor_2
                             [keras.losses.mean_squared_error]     # for out_main
@@ -108,10 +115,9 @@ True
 
 """
 
-import src.model_modules.keras_extensions
+import mlair.model_modules.keras_extensions
 
 __author__ = "Lukas Leufen, Felix Kleinert"
-# __date__ = '2019-12-12'
 __date__ = '2020-05-12'
 
 from abc import ABC
@@ -119,9 +125,9 @@ from typing import Any, Callable, Dict
 
 import keras
 import tensorflow as tf
-from src.model_modules.inception_model import InceptionModelBase
-from src.model_modules.flatten import flatten_tail
-from src.model_modules.advanced_paddings import PadUtils, Padding2D
+from mlair.model_modules.inception_model import InceptionModelBase
+from mlair.model_modules.flatten import flatten_tail
+from mlair.model_modules.advanced_paddings import PadUtils, Padding2D
 
 
 class AbstractModelClass(ABC):
@@ -133,7 +139,7 @@ class AbstractModelClass(ABC):
     the corresponding loss function.
     """
 
-    def __init__(self) -> None:
+    def __init__(self, shape_inputs, shape_outputs) -> None:
         """Predefine internal attributes for model and loss."""
         self.__model = None
         self.model_name = self.__class__.__name__
@@ -147,6 +153,8 @@ class AbstractModelClass(ABC):
                                           'target_tensors': None
                                           }
         self.__compile_options = self.__allowed_compile_options
+        self.shape_inputs = shape_inputs
+        self.shape_outputs = self.__extract_from_tuple(shape_outputs)
 
     def __getattr__(self, name: str) -> Any:
         """
@@ -267,6 +275,11 @@ class AbstractModelClass(ABC):
                 raise ValueError(
                     f"Got different values or arguments for same argument: self.{allow_k}={new_v_attr.__class__} and '{allow_k}': {new_v_dic.__class__}")
 
+    @staticmethod
+    def __extract_from_tuple(tup):
+        """Return element of tuple if it contains only a single element."""
+        return tup[0] if isinstance(tup, tuple) and len(tup) == 1 else tup
+
     @staticmethod
     def __compare_keras_optimizers(first, second):
         if first.__class__ == second.__class__ and first.__module__ == 'keras.optimizers':
@@ -334,24 +347,19 @@ class MyLittleModel(AbstractModelClass):
     Dense layer.
     """
 
-    def __init__(self, window_history_size, window_lead_time, channels):
+    def __init__(self, shape_inputs: list, shape_outputs: list):
         """
         Sets model and loss depending on the given arguments.
 
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param regularizer: <not used here>
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
+        :param shape_inputs: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
+        :param shape_outputs: list of output shapes (expect len=1 with shape=(window_forecast))
         """
 
-        super().__init__()
+        assert len(shape_inputs) == 1
+        assert len(shape_outputs) == 1
+        super().__init__(shape_inputs[0], shape_outputs[0])
 
         # settings
-        self.window_history_size = window_history_size
-        self.window_lead_time = window_lead_time
-        self.channels = channels
         self.dropout_rate = 0.1
         self.regularizer = keras.regularizers.l2(0.1)
         self.activation = keras.layers.PReLU
@@ -364,17 +372,10 @@ class MyLittleModel(AbstractModelClass):
     def set_model(self):
         """
         Build the model.
-
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
-        :return: built keras model
         """
 
         # add 1 to window_size to include current time step t0
-        x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels))
+        x_input = keras.layers.Input(shape=self.shape_inputs)
         x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input)
         x_in = self.activation(name='{}_conv_act'.format("major"))(x_in)
         x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in)
@@ -385,16 +386,16 @@ class MyLittleModel(AbstractModelClass):
         x_in = self.activation()(x_in)
         x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in)
         x_in = self.activation()(x_in)
-        x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in)
+        x_in = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("major"))(x_in)
         out_main = self.activation()(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out_main])
 
     def set_compile_options(self):
         self.initial_lr = 1e-2
-        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
-        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
-                                                                             epochs_drop=10)
-        self.compile_options = {"loss": keras.losses.mean_squared_error, "metrics": ["mse", "mae"]}
+        self.optimizer = keras.optimizers.adam(lr=self.initial_lr)
+        self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
+                                                                               epochs_drop=10)
+        self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]}
 
 
 class MyBranchedModel(AbstractModelClass):
@@ -406,24 +407,19 @@ class MyBranchedModel(AbstractModelClass):
     Dense layer.
     """
 
-    def __init__(self, window_history_size, window_lead_time, channels):
+    def __init__(self, shape_inputs: list, shape_outputs: list):
         """
         Sets model and loss depending on the given arguments.
 
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param regularizer: <not used here>
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
+        :param shape_inputs: list of input shapes (expects len=1 with shape=(window_hist, station, variables))
+        :param shape_outputs: list of output shapes (expects len=1 with shape=(window_forecast,))
         """
 
-        super().__init__()
+        assert len(shape_inputs) == 1
+        assert len(shape_outputs) == 1
+        super().__init__(shape_inputs[0], shape_outputs[0])
 
         # settings
-        self.window_history_size = window_history_size
-        self.window_lead_time = window_lead_time
-        self.channels = channels
         self.dropout_rate = 0.1
         self.regularizer = keras.regularizers.l2(0.1)
         self.activation = keras.layers.PReLU
@@ -436,69 +432,57 @@ class MyBranchedModel(AbstractModelClass):
     def set_model(self):
         """
         Build the model.
-
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
-        :return: built keras model
         """
 
         # add 1 to window_size to include current time step t0
-        x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels))
+        x_input = keras.layers.Input(shape=self.shape_inputs)
         x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input)
         x_in = self.activation(name='{}_conv_act'.format("major"))(x_in)
         x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in)
         x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in)
         x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in)
         x_in = self.activation()(x_in)
-        out_minor_1 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_1"))(x_in)
+        out_minor_1 = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("minor_1"))(x_in)
         out_minor_1 = self.activation(name="minor_1")(out_minor_1)
         x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in)
         x_in = self.activation()(x_in)
-        out_minor_2 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_2"))(x_in)
+        out_minor_2 = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("minor_2"))(x_in)
         out_minor_2 = self.activation(name="minor_2")(out_minor_2)
         x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in)
         x_in = self.activation()(x_in)
-        x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in)
+        x_in = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("major"))(x_in)
         out_main = self.activation(name="main")(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out_minor_1, out_minor_2, out_main])
 
     def set_compile_options(self):
         self.initial_lr = 1e-2
         self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
-        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
-                                                                             epochs_drop=10)
+        self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
+                                                                               epochs_drop=10)
         self.compile_options = {"loss": [keras.losses.mean_absolute_error] + [keras.losses.mean_squared_error] + [
             keras.losses.mean_squared_error], "metrics": ["mse", "mae"]}
 
 
 class MyTowerModel(AbstractModelClass):
 
-    def __init__(self, window_history_size, window_lead_time, channels):
+    def __init__(self, shape_inputs: list, shape_outputs: list):
         """
         Sets model and loss depending on the given arguments.
 
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param regularizer: <not used here>
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
+        :param shape_inputs: list of input shapes (expects len=1 with shape=(window_hist, station, variables))
+        :param shape_outputs: list of output shapes (expects len=1 with shape=(window_forecast,))
         """
 
-        super().__init__()
+        assert len(shape_inputs) == 1
+        assert len(shape_outputs) == 1
+        super().__init__(shape_inputs[0], shape_outputs[0])
 
         # settings
-        self.window_history_size = window_history_size
-        self.window_lead_time = window_lead_time
-        self.channels = channels
         self.dropout_rate = 1e-2
         self.regularizer = keras.regularizers.l2(0.1)
         self.initial_lr = 1e-2
-        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
-                                                                             epochs_drop=10)
+        self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
+                                                                               epochs_drop=10)
         self.activation = keras.layers.PReLU
 
         # apply to model
@@ -509,13 +493,6 @@ class MyTowerModel(AbstractModelClass):
     def set_model(self):
         """
         Build the model.
-
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
-        :return: built keras model
         """
         activation = self.activation
         conv_settings_dict1 = {
@@ -549,9 +526,7 @@ class MyTowerModel(AbstractModelClass):
         ##########################################
         inception_model = InceptionModelBase()
 
-        X_input = keras.layers.Input(
-            shape=(
-            self.window_history_size + 1, 1, self.channels))  # add 1 to window_size to include current time step t0
+        X_input = keras.layers.Input(shape=self.shape_inputs)
 
         X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1,
                                                regularizer=self.regularizer,
@@ -573,7 +548,7 @@ class MyTowerModel(AbstractModelClass):
         # out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=True, dropout_rate=self.dropout_rate,
         #                         reduction_filter=64, inner_neurons=64, output_neurons=self.window_lead_time)
 
-        out_main = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.window_lead_time,
+        out_main = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.shape_outputs,
                                 output_activation='linear', reduction_filter=64,
                                 name='Main', bound_weight=True, dropout_rate=self.dropout_rate,
                                 kernel_regularizer=self.regularizer
@@ -588,29 +563,24 @@ class MyTowerModel(AbstractModelClass):
 
 class MyPaperModel(AbstractModelClass):
 
-    def __init__(self, window_history_size, window_lead_time, channels):
+    def __init__(self, shape_inputs: list, shape_outputs: list):
         """
         Sets model and loss depending on the given arguments.
 
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param regularizer: <not used here>
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
+        :param shape_inputs: list of input shapes (expects len=1 with shape=(window_hist, station, variables))
+        :param shape_outputs: list of output shapes (expects len=1 with shape=(window_forecast,))
         """
 
-        super().__init__()
+        assert len(shape_inputs) == 1
+        assert len(shape_outputs) == 1
+        super().__init__(shape_inputs[0], shape_outputs[0])
 
         # settings
-        self.window_history_size = window_history_size
-        self.window_lead_time = window_lead_time
-        self.channels = channels
         self.dropout_rate = .3
         self.regularizer = keras.regularizers.l2(0.001)
         self.initial_lr = 1e-3
-        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
-                                                                             epochs_drop=10)
+        self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
+                                                                               epochs_drop=10)
         self.activation = keras.layers.ELU
         self.padding = "SymPad2D"
 
@@ -670,9 +640,7 @@ class MyPaperModel(AbstractModelClass):
         ##########################################
         inception_model = InceptionModelBase()
 
-        X_input = keras.layers.Input(
-            shape=(
-            self.window_history_size + 1, 1, self.channels))  # add 1 to window_size to include current time step t0
+        X_input = keras.layers.Input(shape=self.shape_inputs)
 
         pad_size = PadUtils.get_padding_for_same(first_kernel)
         # X_in = adv_pad.SymmetricPadding2D(padding=pad_size)(X_input)
@@ -690,7 +658,7 @@ class MyPaperModel(AbstractModelClass):
                                                padding=self.padding)
         # out_minor1 = flatten_tail(X_in, 'minor_1', False, self.dropout_rate, self.window_lead_time,
         #                           self.activation, 32, 64)
-        out_minor1 = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.window_lead_time,
+        out_minor1 = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.shape_outputs,
                                   output_activation='linear', reduction_filter=32,
                                   name='minor_1', bound_weight=False, dropout_rate=self.dropout_rate,
                                   kernel_regularizer=self.regularizer
@@ -708,7 +676,7 @@ class MyPaperModel(AbstractModelClass):
         #                                        batch_normalisation=True)
         #############################################
 
-        out_main = flatten_tail(X_in, inner_neurons=64 * 2, activation=activation, output_neurons=self.window_lead_time,
+        out_main = flatten_tail(X_in, inner_neurons=64 * 2, activation=activation, output_neurons=self.shape_outputs,
                                 output_activation='linear',  reduction_filter=64 * 2,
                                 name='Main', bound_weight=False, dropout_rate=self.dropout_rate,
                                 kernel_regularizer=self.regularizer
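
The hunks above replace the `window_history_size`/`window_lead_time`/`channels` constructor arguments of all example model classes with `shape_inputs`/`shape_outputs` lists. Below is a minimal sketch, not part of the commit, of building a Keras graph directly from such shape lists; the concrete shapes are illustrative assumptions, in MLAir they come from the train data collection (see `ModelSetup._set_shapes` further down). Note that the diff passes `self.shape_outputs` straight to `Dense`, which presumes the base class resolves a single-element shape tuple to an integer; the sketch indexes the tuple explicitly instead.

```python
# Minimal sketch (not from the commit): build a Keras graph from shape lists.
import keras

shape_inputs = [(8, 1, 9)]   # e.g. (window_history + 1, station, variables), assumed values
shape_outputs = [(4,)]       # e.g. (window_lead_time,), assumed value

x_input = keras.layers.Input(shape=shape_inputs[0])
x = keras.layers.Conv2D(32, (1, 1), padding="same")(x_input)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(16, activation="relu")(x)
out = keras.layers.Dense(shape_outputs[0][0])(x)   # one neuron per forecast step
model = keras.Model(inputs=x_input, outputs=[out])
model.compile(optimizer=keras.optimizers.Adam(lr=1e-2), loss="mse", metrics=["mse", "mae"])
```
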
diff --git a/src/plotting/__init__.py b/mlair/plotting/__init__.py
similarity index 100%
rename from src/plotting/__init__.py
rename to mlair/plotting/__init__.py
diff --git a/src/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
similarity index 94%
rename from src/plotting/postprocessing_plotting.py
rename to mlair/plotting/postprocessing_plotting.py
index 4b7f15219ee5506f34e4fc2d76c15fb0e569394d..5cc449aac88ebab58689656820769fe7751f6098 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -18,9 +18,9 @@ import seaborn as sns
 import xarray as xr
 from matplotlib.backends.backend_pdf import PdfPages
 
-from src import helpers
-from src.data_handling import DataGenerator
-from src.helpers import TimeTrackingWrapper
+from mlair import helpers
+from mlair.data_handler.iterator import DataCollection
+from mlair.helpers import TimeTrackingWrapper
 
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 
@@ -236,12 +236,10 @@ class PlotStationMap(AbstractPlotClass):
 
         import cartopy.crs as ccrs
         if generators is not None:
-            for color, gen in generators.items():
-                for k, v in enumerate(gen):
-                    station_coords = gen.get_data_generator(k).meta.loc[['station_lon', 'station_lat']]
-                    # station_names = gen.get_data_generator(k).meta.loc[['station_id']]
-                    IDx, IDy = float(station_coords.loc['station_lon'].values), float(
-                        station_coords.loc['station_lat'].values)
+            for color, data_collection in generators.items():
+                for station in data_collection:
+                    coords = station.get_coordinates()
+                    IDx, IDy = coords["lon"], coords["lat"]
                     self._ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree())
 
     def _plot(self, generators: Dict):
@@ -258,14 +256,28 @@ class PlotStationMap(AbstractPlotClass):
         from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
         fig = plt.figure(figsize=(10, 5))
         self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
-        self._ax.set_extent([4, 17, 44, 58], crs=ccrs.PlateCarree())
         self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
         self._gl.xformatter = LONGITUDE_FORMATTER
         self._gl.yformatter = LATITUDE_FORMATTER
         self._draw_background()
         self._plot_stations(generators)
+        self._adjust_extent()
         plt.tight_layout()
 
+    def _adjust_extent(self):
+        import cartopy.crs as ccrs
+
+        def diff(arr):
+            return arr[1] - arr[0], arr[3] - arr[2]
+
+        def find_ratio(delta, reference=5):
+            return max(abs(reference / delta[0]), abs(reference / delta[1]))
+
+        extent = self._ax.get_extent(crs=ccrs.PlateCarree())
+        ratio = find_ratio(diff(extent))
+        new_extent = extent + np.array([-1, 1, -1, 1]) * ratio
+        self._ax.set_extent(new_extent, crs=ccrs.PlateCarree())
+
 
 @TimeTrackingWrapper
 class PlotConditionalQuantiles(AbstractPlotClass):
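
The new `_adjust_extent` helper above pads the automatically determined map extent so that small station clusters are not plotted edge to edge. A standalone numeric sketch of the padding arithmetic (values are illustrative):

```python
# Standalone sketch of the extent padding in _adjust_extent (illustrative values).
import numpy as np

extent = np.array([6.0, 15.0, 47.0, 54.0])           # lon_min, lon_max, lat_min, lat_max
delta = (extent[1] - extent[0], extent[3] - extent[2])
ratio = max(abs(5 / delta[0]), abs(5 / delta[1]))     # 5 degrees acts as reference span
new_extent = extent + np.array([-1, 1, -1, 1]) * ratio
print(new_extent)                                     # extent widened symmetrically
```
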
@@ -699,6 +711,8 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         """
         data = helpers.dict_to_xarray(data, "station").sortby(self._x_name)
         self._labels = [str(i) + "d" for i in data.coords["ahead"].values]
+        if "station" not in data.dims:
+            data = data.expand_dims("station")
         return data.to_dataframe("data").reset_index(level=[0, 1, 2])
 
     def _label_add(self, score_only: bool):
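
The `expand_dims` guard added above protects single-station runs where the `station` dimension has been squeezed away before the DataFrame conversion. A small xarray sketch of the same guard:

```python
# Sketch of the station-dimension guard used in PlotBootstrapSkillScore._prepare_data.
import numpy as np
import xarray as xr

data = xr.DataArray(np.random.rand(3), dims=["ahead"], coords={"ahead": [1, 2, 3]})
if "station" not in data.dims:
    data = data.expand_dims("station")
print(data.dims)   # ('station', 'ahead')
```
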
@@ -771,8 +785,8 @@ class PlotTimeSeries:
 
     def _plot(self, plot_folder):
         pdf_pages = self._create_pdf_pages(plot_folder)
-        start, end = self._get_time_range(self._load_data(self._stations[0]))
         for pos, station in enumerate(self._stations):
+            start, end = self._get_time_range(self._load_data(self._stations[0]))
             data = self._load_data(station)
             fig, axes, factor = self._create_subplots(start, end)
             nan_list = []
@@ -882,11 +896,12 @@ class PlotAvailability(AbstractPlotClass):
 
     """
 
-    def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily",
-                 summary_name="data availability"):
+    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
+                 summary_name="data availability", time_dimension="datetime"):
         """Initialise."""
         # create standard Gantt plot for all stations (currently in single pdf file with single page)
         super().__init__(plot_folder, "data_availability")
+        self.dim = time_dimension
         self.sampling = self._get_sampling(sampling)
         plot_dict = self._prepare_data(generators)
         lgd = self._plot(plot_dict)
@@ -909,34 +924,30 @@ class PlotAvailability(AbstractPlotClass):
         elif sampling == "hourly":
             return "h"
 
-    def _prepare_data(self, generators: Dict[str, DataGenerator]):
+    def _prepare_data(self, generators: Dict[str, DataCollection]):
         plt_dict = {}
-        for subset, generator in generators.items():
-            stations = generator.stations
-            for station in stations:
-                station_data = generator.get_data_generator(station)
-                labels = station_data.get_transposed_label().resample(datetime=self.sampling, skipna=True).mean()
+        for subset, data_collection in generators.items():
+            for station in data_collection:
+                labels = station.get_Y(as_numpy=False).resample({self.dim: self.sampling}, skipna=True).mean()
                 labels_bool = labels.sel(window=1).notnull()
-                group = (labels_bool != labels_bool.shift(datetime=1)).cumsum()
+                group = (labels_bool != labels_bool.shift({self.dim: 1})).cumsum()
                 plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values},
-                                         index=labels.datetime.values)
+                                         index=labels.coords[self.dim].values)
                 t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0]))
                 t2 = [i[1:] for i in t if i[0]]
 
-                if plt_dict.get(station) is None:
-                    plt_dict[station] = {subset: t2}
+                if plt_dict.get(str(station)) is None:
+                    plt_dict[str(station)] = {subset: t2}
                 else:
-                    plt_dict[station].update({subset: t2})
+                    plt_dict[str(station)].update({subset: t2})
         return plt_dict
 
-    def _summarise_data(self, generators: Dict[str, DataGenerator], summary_name: str):
+    def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str):
         plt_dict = {}
-        for subset, generator in generators.items():
+        for subset, data_collection in generators.items():
             all_data = None
-            stations = generator.stations
-            for station in stations:
-                station_data = generator.get_data_generator(station)
-                labels = station_data.get_transposed_label().resample(datetime=self.sampling, skipna=True).mean()
+            for station in data_collection:
+                labels = station.get_Y(as_numpy=False).resample({self.dim: self.sampling}, skipna=True).mean()
                 labels_bool = labels.sel(window=1).notnull()
                 if all_data is None:
                     all_data = labels_bool
@@ -945,8 +956,9 @@ class PlotAvailability(AbstractPlotClass):
                     all_data = np.logical_or(tmp, labels_bool).combine_first(
                         all_data)  # apply logical on merge and fill missing with all_data
 
-            group = (all_data != all_data.shift(datetime=1)).cumsum()
-            plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values}, index=all_data.datetime.values)
+            group = (all_data != all_data.shift({self.dim: 1})).cumsum()
+            plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values},
+                                     index=all_data.coords[self.dim].values)
             t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0]))
             t2 = [i[1:] for i in t if i[0]]
             if plt_dict.get(summary_name) is None:
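
`PlotAvailability` now iterates over `DataCollection` objects and a configurable time dimension, but the grouping trick stays the same: runs of consecutive available/unavailable samples are labelled by a cumulative sum over change points. A standalone pandas sketch of that grouping (station and subset handling omitted, toy series assumed, `.iloc[0]` used instead of `head(1)[0]` for clarity):

```python
# Standalone sketch of the availability grouping mirrored from _prepare_data.
import pandas as pd

avail = pd.Series([True, True, False, True, True], index=pd.date_range("2020-01-01", periods=5))
group = (avail != avail.shift(1)).cumsum()            # new label whenever availability flips
plot_data = pd.DataFrame({"avail": avail.values, "group": group.values}, index=avail.index)
t = plot_data.groupby("group").apply(lambda x: (x["avail"].iloc[0], x.index[0], x.shape[0]))
print([i[1:] for i in t if i[0]])                     # [(block start, block length), ...]
```
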
diff --git a/src/plotting/tracker_plot.py b/mlair/plotting/tracker_plot.py
similarity index 99%
rename from src/plotting/tracker_plot.py
rename to mlair/plotting/tracker_plot.py
index 20db5d9d9f22df548b1d499c4e8e0faa3fbfa1ee..406c32feb1ebda2d32d886051e32778d6c17f5db 100644
--- a/src/plotting/tracker_plot.py
+++ b/mlair/plotting/tracker_plot.py
@@ -4,7 +4,7 @@ import numpy as np
 import os
 from typing import Union, List, Optional, Dict
 
-from src.helpers import to_list
+from mlair.helpers import to_list
 
 from matplotlib import pyplot as plt, lines as mlines, ticker as ticker
 from matplotlib.patches import Rectangle
diff --git a/src/plotting/training_monitoring.py b/mlair/plotting/training_monitoring.py
similarity index 98%
rename from src/plotting/training_monitoring.py
rename to mlair/plotting/training_monitoring.py
index 473b966ce52ee7e2885bc14beef2e68b8835b15e..913c11dd8a4e0d23c2bde6864c12f17c65922644 100644
--- a/src/plotting/training_monitoring.py
+++ b/mlair/plotting/training_monitoring.py
@@ -10,7 +10,7 @@ import matplotlib
 import matplotlib.pyplot as plt
 import pandas as pd
 
-from src.model_modules.keras_extensions import LearningRateDecay
+from mlair.model_modules.keras_extensions import LearningRateDecay
 
 matplotlib.use('Agg')
 history_object = Union[Dict, keras.callbacks.History]
diff --git a/src/run_modules/README.md b/mlair/run_modules/README.md
similarity index 100%
rename from src/run_modules/README.md
rename to mlair/run_modules/README.md
diff --git a/mlair/run_modules/__init__.py b/mlair/run_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba38d3e90fb5d66c4129f6645ef34b8137e48375
--- /dev/null
+++ b/mlair/run_modules/__init__.py
@@ -0,0 +1,7 @@
+from mlair.run_modules.experiment_setup import ExperimentSetup
+from mlair.run_modules.model_setup import ModelSetup
+from mlair.run_modules.partition_check import PartitionCheck
+from mlair.run_modules.post_processing import PostProcessing
+from mlair.run_modules.pre_processing import PreProcessing
+from mlair.run_modules.run_environment import RunEnvironment
+from mlair.run_modules.training import Training
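
With the new package `__init__`, the run modules can be imported directly from `mlair.run_modules`, for example:

```python
from mlair.run_modules import ExperimentSetup, PreProcessing, ModelSetup, Training, PostProcessing
```
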
diff --git a/src/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
similarity index 91%
rename from src/run_modules/experiment_setup.py
rename to mlair/run_modules/experiment_setup.py
index 1d375c32be06b583abbfb06a20ea482e6775b232..407465ad4cd99b85c3c5b37eb2aef6e9e71c6424 100644
--- a/src/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -6,21 +6,21 @@ import logging
 import os
 from typing import Union, Dict, Any, List
 
-from src.configuration import path_config
-from src import helpers
-from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_NETWORK, DEFAULT_STATION_TYPE, \
+from mlair.configuration import path_config
+from mlair import helpers
+from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_NETWORK, DEFAULT_STATION_TYPE, \
     DEFAULT_START, DEFAULT_END, DEFAULT_WINDOW_HISTORY_SIZE, DEFAULT_OVERWRITE_LOCAL_DATA, DEFAULT_TRANSFORMATION, \
     DEFAULT_HPC_LOGIN_LIST, DEFAULT_HPC_HOST_LIST, DEFAULT_CREATE_NEW_MODEL, DEFAULT_TRAINABLE, \
     DEFAULT_FRACTION_OF_TRAINING, DEFAULT_EXTREME_VALUES, DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY, DEFAULT_PERMUTE_DATA, \
     DEFAULT_BATCH_SIZE, DEFAULT_EPOCHS, DEFAULT_TARGET_VAR, DEFAULT_TARGET_DIM, DEFAULT_WINDOW_LEAD_TIME, \
-    DEFAULT_DIMENSIONS, DEFAULT_INTERPOLATE_DIM, DEFAULT_INTERPOLATE_METHOD, DEFAULT_LIMIT_NAN_FILL, \
+    DEFAULT_DIMENSIONS, DEFAULT_TIME_DIM, DEFAULT_INTERPOLATION_METHOD, DEFAULT_INTERPOLATION_LIMIT, \
     DEFAULT_TRAIN_START, DEFAULT_TRAIN_END, DEFAULT_TRAIN_MIN_LENGTH, DEFAULT_VAL_START, DEFAULT_VAL_END, \
     DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
     DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
     DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST
-from src.data_handling import DataPrepJoin
-from src.run_modules.run_environment import RunEnvironment
-from src.model_modules.model_class import MyLittleModel as VanillaModel
+from mlair.data_handler.advanced_data_handler import DefaultDataPreparation
+from mlair.run_modules.run_environment import RunEnvironment
+from mlair.model_modules.model_class import MyLittleModel as VanillaModel
 
 
 class ExperimentSetup(RunEnvironment):
@@ -50,8 +50,6 @@ class ExperimentSetup(RunEnvironment):
         * `plot_path` [.]
         * `forecast_path` [.]
         * `stations` [.]
-        * `network` [.]
-        * `station_type` [.]
         * `statistics_per_var` [.]
         * `variables` [.]
         * `start` [.]
@@ -66,8 +64,8 @@ class ExperimentSetup(RunEnvironment):
 
         # interpolation
         self._set_param("dimensions", dimensions, default={'new_index': ['datetime', 'Stations']})
-        self._set_param("interpolate_dim", interpolate_dim, default='datetime')
-        self._set_param("interpolate_method", interpolate_method, default='linear')
+        self._set_param("time_dim", time_dim, default='datetime')
+        self._set_param("interpolation_method", interpolation_method, default='linear')
         self._set_param("limit_nan_fill", limit_nan_fill, default=1)
 
         # train set parameters
@@ -116,10 +114,6 @@ class ExperimentSetup(RunEnvironment):
         investigations are stored outside this structure.
     :param stations: list of stations or single station to use in experiment. If not provided, stations are set to
         :py:const:`default stations <DEFAULT_STATIONS>`.
-    :param network: name of network to restrict to use only stations from this measurement network. Default is
-        `AIRBASE` .
-    :param station_type: restrict network type to one of TOAR's categories (background, traffic, industrial). Default is
-        `None` to use all categories.
     :param variables: list of all variables to use. Valid names can be found in
         `Section 2.1 Parameters <https://join.fz-juelich.de/services/rest/surfacedata/>`_. If not provided, this
         parameter is filled with keys from ``statistics_per_var``.
@@ -140,8 +134,8 @@ class ExperimentSetup(RunEnvironment):
     :param window_lead_time: number of time steps to predict by model (default 3). Time steps `t_0+1` to `t_0+w` are
         predicted.
     :param dimensions:
-    :param interpolate_dim:
-    :param interpolate_method:
+    :param time_dim:
+    :param interpolation_method:
     :param limit_nan_fill:
     :param train_start:
     :param train_end:
@@ -209,8 +203,6 @@ class ExperimentSetup(RunEnvironment):
     def __init__(self,
                  experiment_date=None,
                  stations: Union[str, List[str]] = None,
-                 network: str = None,
-                 station_type: str = None,
                  variables: Union[str, List[str]] = None,
                  statistics_per_var: Dict = None,
                  start: str = None,
@@ -220,16 +212,16 @@ class ExperimentSetup(RunEnvironment):
                  target_dim=None,
                  window_lead_time: int = None,
                  dimensions=None,
-                 interpolate_dim=None,
-                 interpolate_method=None,
-                 limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
+                 time_dim=None,
+                 interpolation_method=None,
+                 interpolation_limit=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
                  test_end=None, use_all_stations_on_all_data_sets=None, trainable: bool = None, fraction_of_train: float = None,
                  experiment_path=None, plot_path: str = None, forecast_path: str = None, overwrite_local_data = None, sampling: str = "daily",
                  create_new_model = None, bootstrap_path=None, permute_data_on_training = None, transformation=None,
                  train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
                  extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None,
-                 create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None,
-                 batch_size=None, epochs=None, data_preparation=None):
+                 create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
+                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_preparation=None, **kwargs):
 
         # create run framework
         super().__init__()
@@ -265,6 +257,9 @@ class ExperimentSetup(RunEnvironment):
         logging.info(f"Experiment path is: {experiment_path}")
         path_config.check_path_and_create(self.data_store.get("experiment_path"))
 
+        # batch path (temporary)
+        self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data"))
+
         # set model path
         self._set_param("model_path", None, os.path.join(experiment_path, "model"))
         path_config.check_path_and_create(self.data_store.get("model_path"))
@@ -285,8 +280,6 @@ class ExperimentSetup(RunEnvironment):
 
         # setup for data
         self._set_param("stations", stations, default=DEFAULT_STATIONS)
-        self._set_param("network", network, default=DEFAULT_NETWORK)
-        self._set_param("station_type", station_type, default=DEFAULT_STATION_TYPE)
         self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
         self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys()))
         self._set_param("start", start, default=DEFAULT_START)
@@ -297,7 +290,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("sampling", sampling)
         self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
         self._set_param("transformation", None, scope="preprocessing")
-        self._set_param("data_preparation", data_preparation, default=DataPrepJoin)
+        self._set_param("data_preparation", data_preparation, default=DefaultDataPreparation)
 
         # target
         self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR)
@@ -306,9 +299,9 @@ class ExperimentSetup(RunEnvironment):
 
         # interpolation
         self._set_param("dimensions", dimensions, default=DEFAULT_DIMENSIONS)
-        self._set_param("interpolate_dim", interpolate_dim, default=DEFAULT_INTERPOLATE_DIM)
-        self._set_param("interpolate_method", interpolate_method, default=DEFAULT_INTERPOLATE_METHOD)
-        self._set_param("limit_nan_fill", limit_nan_fill, default=DEFAULT_LIMIT_NAN_FILL)
+        self._set_param("time_dim", time_dim, default=DEFAULT_TIME_DIM)
+        self._set_param("interpolation_method", interpolation_method, default=DEFAULT_INTERPOLATION_METHOD)
+        self._set_param("interpolation_limit", interpolation_limit, default=DEFAULT_INTERPOLATION_LIMIT)
 
         # train set parameters
         self._set_param("start", train_start, default=DEFAULT_TRAIN_START, scope="train")
@@ -344,6 +337,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("number_of_bootstraps", number_of_bootstraps, default=DEFAULT_NUMBER_OF_BOOTSTRAPS,
                         scope="general.postprocessing")
         self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")
+        self._set_param("neighbors", ["DEBW030"])  # TODO: just for testing
 
         # check variables, statistics and target variable
         self._check_target_var()
@@ -352,6 +346,15 @@ class ExperimentSetup(RunEnvironment):
         # set model architecture class
         self._set_param("model_class", model, VanillaModel)
 
+        # set remaining kwargs
+        if len(kwargs) > 0:
+            for k, v in kwargs.items():
+                if len(self.data_store.search_name(k)) == 0:
+                    self._set_param(k, v)
+                else:
+                    raise KeyError(f"Given argument {k} with value {v} cannot be set for this experiment due to a "
+                                   f"conflict with an existing entry with same naming: {k}={self.data_store.get(k)}")
+
     def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None:
         """Set given parameter and log in debug."""
         if value is None and default is not None:
@@ -391,6 +394,7 @@ class ExperimentSetup(RunEnvironment):
         if not set(target_var).issubset(stat.keys()):
             raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.")
 
+
 if __name__ == "__main__":
     formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
     logging.basicConfig(format=formatter, level=logging.DEBUG)
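
A usage sketch of the renamed interpolation parameters and the new keyword pass-through follows; it is illustrative only, assumes a working MLAir installation with data access, and `my_custom_setting` is a made-up key:

```python
# Illustrative only: renamed interpolation arguments plus a free keyword argument.
from mlair.run_modules.experiment_setup import ExperimentSetup

ExperimentSetup(stations=["DEBW030"],
                time_dim="datetime",
                interpolation_method="linear",
                interpolation_limit=1,
                my_custom_setting=42)   # stored via the new **kwargs branch
# A keyword whose name already exists in the data store raises the KeyError shown above.
```
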
diff --git a/src/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
similarity index 69%
rename from src/run_modules/model_setup.py
rename to mlair/run_modules/model_setup.py
index f9683b953d85bacf6e452e0a1922e85dfe946cd1..3dc56f01c4f37ce9fc53086d837386af81e5f53d 100644
--- a/src/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -5,12 +5,15 @@ __date__ = '2019-12-02'
 
 import logging
 import os
+import re
 
 import keras
+import pandas as pd
 import tensorflow as tf
 
-from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
-from src.run_modules.run_environment import RunEnvironment
+from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
+from mlair.run_modules.run_environment import RunEnvironment
+from mlair.configuration import path_config
 
 
 class ModelSetup(RunEnvironment):
@@ -31,8 +34,6 @@ class ModelSetup(RunEnvironment):
         * `trainable` [.]
         * `create_new_model` [.]
         * `generator` [train]
-        * `window_lead_time` [.]
-        * `window_history_size` [.]
         * `model_class` [.]
 
     Optional objects
@@ -70,7 +71,7 @@ class ModelSetup(RunEnvironment):
     def _run(self):
 
         # set channels depending on inputs
-        self._set_channels()
+        self._set_shapes()
 
         # build model graph using settings from my_model_settings()
         self.build_model()
@@ -88,10 +89,15 @@ class ModelSetup(RunEnvironment):
         # compile model
         self.compile_model()
 
-    def _set_channels(self):
-        """Set channels as number of variables of train generator."""
-        channels = self.data_store.get("generator", "train")[0][0].shape[-1]
-        self.data_store.set("channels", channels, self.scope)
+        # report settings
+        self.report_model()
+
+    def _set_shapes(self):
+        """Set input and output shapes from train collection."""
+        shape = list(map(lambda x: x.shape[1:], self.data_store.get("data_collection", "train")[0].get_X()))
+        self.data_store.set("shape_inputs", shape, self.scope)
+        shape = list(map(lambda y: y.shape[1:], self.data_store.get("data_collection", "train")[0].get_Y()))
+        self.data_store.set("shape_outputs", shape, self.scope)
 
     def compile_model(self):
         """
@@ -128,8 +134,8 @@ class ModelSetup(RunEnvironment):
             logging.info('no weights to reload...')
 
     def build_model(self):
-        """Build model using window_history_size, window_lead_time and channels from data store."""
-        args_list = ["window_history_size", "window_lead_time", "channels"]
+        """Build model using input and output shapes from data store."""
+        args_list = ["shape_inputs", "shape_outputs"]
         args = self.data_store.create_args_dict(args_list, self.scope)
         model = self.data_store.get("model_class")
         self.model = model(**args)
@@ -147,3 +153,30 @@ class ModelSetup(RunEnvironment):
         with tf.device("/cpu:0"):
             file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
             keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
+
+    def report_model(self):
+        model_settings = self.model.get_settings()
+        model_settings.update(self.model.compile_options)
+        df = pd.DataFrame(columns=["model setting"])
+        for k, v in model_settings.items():
+            if v is None:
+                continue
+            if isinstance(v, list):
+                v = ",".join(self._clean_name(str(u)) for u in v)
+            if "<" in str(v):
+                v = self._clean_name(str(v))
+            df.loc[k] = str(v)
+        df.sort_index(inplace=True)
+        column_format = "ll"
+        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+        path_config.check_path_and_create(path)
+        df.to_latex(os.path.join(path, "model_settings.tex"), na_rep='---', column_format=column_format)
+        df.to_markdown(open(os.path.join(path, "model_settings.md"), mode="w", encoding='utf-8'),
+                       tablefmt="github")
+
+    @staticmethod
+    def _clean_name(orig_name: str):
+        mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
+        mod_name = mod_name[1] if any(map(lambda x: x in mod_name[0], ["class", "function", "method"])) else mod_name[0]
+        return mod_name[:-1] if mod_name[-1] == ">" else mod_name
+
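
`_clean_name` exists to turn repr-style strings (losses, optimizers, classes) into readable names for the new model settings report. A standalone check of the same logic, copied from the hunk above:

```python
# Standalone check of the _clean_name logic from ModelSetup.
import re

def clean_name(orig_name: str):
    mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
    mod_name = mod_name[1] if any(map(lambda x: x in mod_name[0], ["class", "function", "method"])) else mod_name[0]
    return mod_name[:-1] if mod_name[-1] == ">" else mod_name

print(clean_name("<function mean_squared_error at 0x7f...>"))   # mean_squared_error
print(clean_name("<class 'keras.optimizers.SGD'>"))             # keras.optimizers.SGD
```
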
diff --git a/src/run_modules/partition_check.py b/mlair/run_modules/partition_check.py
similarity index 93%
rename from src/run_modules/partition_check.py
rename to mlair/run_modules/partition_check.py
index 8f4c703e6b94f11905121d93c44dd8bf583abdec..c45f350079756282fbb43a1732d256c960f9e274 100644
--- a/src/run_modules/partition_check.py
+++ b/mlair/run_modules/partition_check.py
@@ -1,7 +1,7 @@
 __author__ = "Felix Kleinert"
 __date__ = '2020-04-07'
 
-from src.run_modules.run_environment import RunEnvironment
+from mlair.run_modules.run_environment import RunEnvironment
 
 
 class PartitionCheck(RunEnvironment):
diff --git a/src/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
similarity index 76%
rename from src/run_modules/post_processing.py
rename to mlair/run_modules/post_processing.py
index b97d28c1cf71d35526207450d6b0bb386ddefdb7..d4f409ec503ba0ae37bdd1d1bec4b0207eec453c 100644
--- a/src/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -13,14 +13,14 @@ import numpy as np
 import pandas as pd
 import xarray as xr
 
-from src.data_handling import BootStraps, Distributor, DataGenerator, DataPrepJoin
-from src.helpers.datastore import NameNotFoundInDataStore
-from src.helpers import TimeTracking, statistics
-from src.model_modules.linear_model import OrdinaryLeastSquaredModel
-from src.model_modules.model_class import AbstractModelClass
-from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
+from mlair.data_handler import BootStraps, KerasIterator
+from mlair.helpers.datastore import NameNotFoundInDataStore
+from mlair.helpers import TimeTracking, statistics, extract_value
+from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
+from mlair.model_modules.model_class import AbstractModelClass
+from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
     PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles
-from src.run_modules.run_environment import RunEnvironment
+from mlair.run_modules.run_environment import RunEnvironment
 
 
 class PostProcessing(RunEnvironment):
@@ -42,7 +42,7 @@ class PostProcessing(RunEnvironment):
         * `model_path` [.]
         * `target_var` [.]
         * `sampling` [.]
-        * `window_lead_time` [.]
+        * `shape_outputs` [model]
         * `evaluate_bootstraps` [postprocessing] and if enabled:
 
             * `create_new_bootstraps` [postprocessing]
@@ -65,14 +65,16 @@ class PostProcessing(RunEnvironment):
         self.model: keras.Model = self._load_model()
         self.ols_model = None
         self.batch_size: int = self.data_store.get_default("batch_size", "model", 64)
-        self.test_data: DataGenerator = self.data_store.get("generator", "test")
-        self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size)
-        self.train_data: DataGenerator = self.data_store.get("generator", "train")
-        self.val_data: DataGenerator = self.data_store.get("generator", "val")
-        self.train_val_data: DataGenerator = self.data_store.get("generator", "train_val")
+        self.test_data = self.data_store.get("data_collection", "test")
+        batch_path = self.data_store.get("batch_path", scope="test")
+        self.test_data_distributed = KerasIterator(self.test_data, self.batch_size, model=self.model, name="test",
+                                                   batch_path=batch_path)
+        self.train_data = self.data_store.get("data_collection", "train")
+        self.val_data = self.data_store.get("data_collection", "val")
+        self.train_val_data = self.data_store.get("data_collection", "train_val")
         self.plot_path: str = self.data_store.get("plot_path")
         self.target_var = self.data_store.get("target_var")
         self._sampling = self.data_store.get("sampling")
+        self.window_lead_time = extract_value(self.data_store.get("shape_outputs", "model"))
         self.skill_scores = None
         self.bootstrap_skill_scores = None
         self._run()
@@ -141,34 +143,29 @@ class PostProcessing(RunEnvironment):
             bootstrap_path = self.data_store.get("bootstrap_path")
             forecast_path = self.data_store.get("forecast_path")
             number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
-
-            # set bootstrap class
-            bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
-
-            # create bootstrapped predictions for all stations and variables and save it to disk
             dims = ["index", "ahead", "type"]
-            for station in bootstraps.stations:
-                with TimeTracking(name=station):
-                    logging.info(station)
-                    for var in bootstraps.variables:
-                        station_bootstrap = bootstraps.get_generator(station, var)
-
-                        # make bootstrap predictions
-                        bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap,
-                                                                             workers=2,
-                                                                             use_multiprocessing=True)
-                        if isinstance(bootstrap_predictions, list):  # if model is branched model
-                            bootstrap_predictions = bootstrap_predictions[-1]
-                        # save bootstrap predictions separately for each station and variable combination
-                        bootstrap_predictions = np.expand_dims(bootstrap_predictions, axis=-1)
-                        shape = bootstrap_predictions.shape
-                        coords = (range(shape[0]), range(1, shape[1] + 1))
-                        tmp = xr.DataArray(bootstrap_predictions, coords=(*coords, [var]), dims=dims)
-                        file_name = os.path.join(forecast_path, f"bootstraps_{var}_{station}.nc")
-                        tmp.to_netcdf(file_name)
+            for station in self.test_data:
+                logging.info(str(station))
+                X, Y = None, None
+                bootstraps = BootStraps(station, number_of_bootstraps)
+                for boot in bootstraps:
+                    X, Y, (index, dimension) = boot
+                    # make bootstrap predictions
+                    bootstrap_predictions = self.model.predict(X)
+                    if isinstance(bootstrap_predictions, list):  # if model is branched model
+                        bootstrap_predictions = bootstrap_predictions[-1]
+                    # save bootstrap predictions separately for each station and variable combination
+                    bootstrap_predictions = np.expand_dims(bootstrap_predictions, axis=-1)
+                    shape = bootstrap_predictions.shape
+                    coords = (range(shape[0]), range(1, shape[1] + 1))
+                    var = f"{index}_{dimension}"
+                    tmp = xr.DataArray(bootstrap_predictions, coords=(*coords, [var]), dims=dims)
+                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_{var}.nc")
+                    tmp.to_netcdf(file_name)
+                else:
                     # store also true labels for each station
-                    labels = np.expand_dims(bootstraps.get_labels(station), axis=-1)
-                    file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
+                    labels = np.expand_dims(Y, axis=-1)
+                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_labels.nc")
                     labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=dims)
                     labels.to_netcdf(file_name)
 
@@ -186,42 +183,50 @@ class PostProcessing(RunEnvironment):
             # extract all requirements from data store
             bootstrap_path = self.data_store.get("bootstrap_path")
             forecast_path = self.data_store.get("forecast_path")
-            window_lead_time = self.data_store.get("window_lead_time")
             number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
-            bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
-
+            forecast_file = "forecasts_norm_%s_test.nc"
+            bootstraps = BootStraps(self.test_data[0], number_of_bootstraps).bootstraps()
             skill_scores = statistics.SkillScores(None)
             score = {}
-            for station in self.test_data.stations:
+            for station in self.test_data:
                 logging.info(station)
 
                 # get station labels
-                file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
+                file_name = os.path.join(forecast_path, f"bootstraps_{str(station)}_labels.nc")
                 labels = xr.open_dataarray(file_name)
                 shape = labels.shape
 
                 # get original forecasts
-                orig = bootstraps.get_orig_prediction(forecast_path, f"forecasts_norm_{station}_test.nc").reshape(shape)
+                orig = self.get_orig_prediction(forecast_path, forecast_file % str(station), number_of_bootstraps)
+                orig = orig.reshape(shape)
                 coords = (range(shape[0]), range(1, shape[1] + 1), ["orig"])
                 orig = xr.DataArray(orig, coords=coords, dims=["index", "ahead", "type"])
 
                 # calculate skill scores for each variable
-                skill = pd.DataFrame(columns=range(1, window_lead_time + 1))
-                for boot in self.test_data.variables:
-                    file_name = os.path.join(forecast_path, f"bootstraps_{boot}_{station}.nc")
+                skill = pd.DataFrame(columns=range(1, self.window_lead_time + 1))
+                for boot_set in bootstraps:
+                    boot_var = f"{boot_set[0]}_{boot_set[1]}"
+                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_{boot_var}.nc")
                     boot_data = xr.open_dataarray(file_name)
                     boot_data = boot_data.combine_first(labels).combine_first(orig)
                     boot_scores = []
-                    for ahead in range(1, window_lead_time + 1):
+                    for ahead in range(1, self.window_lead_time + 1):
                         data = boot_data.sel(ahead=ahead)
                         boot_scores.append(
-                            skill_scores.general_skill_score(data, forecast_name=boot, reference_name="orig"))
-                    skill.loc[boot] = np.array(boot_scores)
+                            skill_scores.general_skill_score(data, forecast_name=boot_var, reference_name="orig"))
+                    skill.loc[boot_var] = np.array(boot_scores)
 
                 # collect all results in single dictionary
-                score[station] = xr.DataArray(skill, dims=["boot_var", "ahead"])
+                score[str(station)] = xr.DataArray(skill, dims=["boot_var", "ahead"])
             return score
 
+    @staticmethod
+    def get_orig_prediction(path, file_name, number_of_bootstraps, prediction_name="CNN"):
+        file = os.path.join(path, file_name)
+        prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
+        vals = np.tile(prediction.data, (number_of_bootstraps, 1))
+        return vals[~np.isnan(vals).any(axis=1), :]
+
     def _load_model(self) -> keras.models:
         """
         Load NN model either from data store or from local path.
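
The new `get_orig_prediction` helper tiles the stored CNN forecast once per bootstrap and drops rows containing NaN. A small numpy sketch of exactly that step (toy values):

```python
# Sketch of the tile-and-filter step in get_orig_prediction (toy values).
import numpy as np

prediction = np.array([[1.0, 2.0, 3.0], [np.nan, 2.0, 3.0]])   # (samples, ahead)
number_of_bootstraps = 2
vals = np.tile(prediction, (number_of_bootstraps, 1))           # repeat per bootstrap
print(vals[~np.isnan(vals).any(axis=1), :])                     # NaN rows removed
```
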
@@ -259,12 +264,13 @@ class PostProcessing(RunEnvironment):
         path = self.data_store.get("forecast_path")
 
         plot_list = self.data_store.get("plot_list", "postprocessing")
+        time_dimension = self.data_store.get("time_dim")
 
         if self.bootstrap_skill_scores is not None and "PlotBootstrapSkillScore" in plot_list:
             PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="CNN")
 
         if "PlotConditionalQuantiles" in plot_list:
-            PlotConditionalQuantiles(self.test_data.stations, data_pred_path=path, plot_folder=self.plot_path)
+            PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=path, plot_folder=self.plot_path)
         if "PlotStationMap" in plot_list:
             if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get(
                     "hostname")[:6] in self.data_store.get("hpc_hosts"):
@@ -273,7 +279,7 @@ class PostProcessing(RunEnvironment):
             else:
                 PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
         if "PlotMonthlySummary" in plot_list:
-            PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", self.target_var,
+            PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var,
                                plot_folder=self.plot_path)
         if "PlotClimatologicalSkillScore" in plot_list:
             PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN")
@@ -282,16 +288,16 @@ class PostProcessing(RunEnvironment):
         if "PlotCompetitiveSkillScore" in plot_list:
             PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN")
         if "PlotTimeSeries" in plot_list:
-            PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path,
+            PlotTimeSeries(self.test_data.keys(), path, r"forecasts_%s_test.nc", plot_folder=self.plot_path,
                            sampling=self._sampling)
         if "PlotAvailability" in plot_list:
             avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data}
-            PlotAvailability(avail_data, plot_folder=self.plot_path)
+            PlotAvailability(avail_data, plot_folder=self.plot_path, time_dimension=time_dimension)
 
     def calculate_test_score(self):
         """Evaluate test score of model and save locally."""
-        test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),
-                                                   use_multiprocessing=False, verbose=0, steps=1)
+        test_score = self.model.evaluate_generator(generator=self.test_data_distributed,
+                                                   use_multiprocessing=True, verbose=0, steps=1)
         path = self.data_store.get("model_path")
         with open(os.path.join(path, "test_scores.txt"), "a") as f:
             for index, item in enumerate(test_score):
@@ -311,24 +317,26 @@ class PostProcessing(RunEnvironment):
         be found inside `forecast_path`.
         """
         logging.debug("start make_prediction")
-        for i, _ in enumerate(self.test_data):
-            data = self.test_data.get_data_generator(i)
-            input_data = data.get_transposed_history()
+        time_dimension = self.data_store.get("time_dim")
+        for i, data in enumerate(self.test_data):
+            input_data = data.get_X()
+            target_data = data.get_Y(as_numpy=False)
+            observation_data = data.get_observation()
 
             # get scaling parameters
-            mean, std, transformation_method = data.get_transformation_information(variable=self.target_var)
+            mean, std, transformation_method = data.get_transformation_Y()
 
             for normalised in [True, False]:
                 # create empty arrays
                 nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(
-                    data, count=4)
+                    target_data, count=4)
 
                 # nn forecast
                 nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method,
                                                          normalised)
 
                 # persistence
-                persistence_prediction = self._create_persistence_forecast(data, persistence_prediction, mean, std,
+                persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, mean, std,
                                                                            transformation_method, normalised)
 
                 # ols
@@ -336,11 +344,12 @@ class PostProcessing(RunEnvironment):
                                                            normalised)
 
                 # observation
-                observation = self._create_observation(data, observation, mean, std, transformation_method, normalised)
+                observation = self._create_observation(target_data, observation, mean, std, transformation_method, normalised)
 
                 # merge all predictions
-                full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency())
-                all_predictions = self.create_forecast_arrays(full_index, list(data.label.indexes['window']),
+                full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency())
+                all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes['window']),
+                                                              time_dimension,
                                                               CNN=nn_prediction,
                                                               persi=persistence_prediction,
                                                               obs=observation,
@@ -349,7 +358,7 @@ class PostProcessing(RunEnvironment):
                 # save all forecasts locally
                 path = self.data_store.get("forecast_path")
                 prefix = "forecasts_norm" if normalised else "forecasts"
-                file = os.path.join(path, f"{prefix}_{data.station[0]}_test.nc")
+                file = os.path.join(path, f"{prefix}_{str(data)}_test.nc")
                 all_predictions.to_netcdf(file)
 
     def _get_frequency(self) -> str:
@@ -358,14 +367,14 @@ class PostProcessing(RunEnvironment):
         return getter.get(self._sampling, None)
 
     @staticmethod
-    def _create_observation(data: DataPrepJoin, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str,
+    def _create_observation(data, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str,
                             normalised: bool) -> xr.DataArray:
         """
         Create observation as ground truth from given data.
 
         Inverse transformation is applied to the ground truth to get the output in the original space.
 
-        :param data: transposed observation from DataPrep
+        :param data: observation
         :param mean: mean of target value transformation
         :param std: standard deviation of target value transformation
         :param transformation_method: target values transformation method
@@ -373,10 +382,9 @@ class PostProcessing(RunEnvironment):
 
         :return: filled data array with observation
         """
-        obs = data.label.copy()
         if not normalised:
-            obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method)
-        return obs
+            data = statistics.apply_inverse_transformation(data, mean, std, transformation_method)
+        return data
 
     def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray, mean: xr.DataArray,
                              std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray:
@@ -397,12 +405,11 @@ class PostProcessing(RunEnvironment):
         tmp_ols = self.ols_model.predict(input_data)
         if not normalised:
             tmp_ols = statistics.apply_inverse_transformation(tmp_ols, mean, std, transformation_method)
-        tmp_ols = np.expand_dims(tmp_ols, axis=1)
         target_shape = ols_prediction.values.shape
         ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
         return ols_prediction
 
-    def _create_persistence_forecast(self, data: DataPrepJoin, persistence_prediction: xr.DataArray, mean: xr.DataArray,
+    def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, mean: xr.DataArray,
                                      std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray:
         """
         Create persistence forecast with given data.
@@ -410,7 +417,7 @@ class PostProcessing(RunEnvironment):
         Persistence is deviated from the value at t=0 and applied to all following time steps (t+1, ..., t+window).
         Inverse transformation is applied to the forecast to get the output in the original space.
 
-        :param data: DataPrep
+        :param data: observation
         :param persistence_prediction: empty array in right shape to fill with data
         :param mean: mean of target value transformation
         :param std: standard deviation of target value transformation
@@ -419,12 +426,10 @@ class PostProcessing(RunEnvironment):
 
         :return: filled data array with persistence predictions
         """
-        tmp_persi = data.observation.copy().sel({'window': 0})
+        tmp_persi = data.copy()
         if not normalised:
             tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
-        window_lead_time = self.data_store.get("window_lead_time")
-        persistence_prediction.values = np.expand_dims(np.tile(tmp_persi.squeeze('Stations'), (window_lead_time, 1)),
-                                                       axis=1)
+        persistence_prediction.values = np.tile(tmp_persi, (self.window_lead_time, 1)).T
         return persistence_prediction
 
     def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, mean: xr.DataArray,
@@ -449,18 +454,20 @@ class PostProcessing(RunEnvironment):
         if not normalised:
             tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method)
         if isinstance(tmp_nn, list):
-            nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1], axis=1), 2, 0)
+            nn_prediction.values = tmp_nn[-1]
         elif tmp_nn.ndim == 3:
-            nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0)
+            nn_prediction.values = tmp_nn[-1, ...]
         elif tmp_nn.ndim == 2:
-            nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0)
+            nn_prediction.values = tmp_nn
         else:
             raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.")
         return nn_prediction
 
     @staticmethod
-    def _create_empty_prediction_arrays(generator, count=1):
-        return [generator.label.copy() for _ in range(count)]
+    def _create_empty_prediction_arrays(target_data, count=1):
+        """
+        Create arrays to collect all predictions. Returns `count` copies of the given target data. """
+        return [target_data.copy() for _ in range(count)]
 
     @staticmethod
     def create_fullindex(df: Union[xr.DataArray, pd.DataFrame, pd.DatetimeIndex], freq: str) -> pd.DataFrame:
@@ -488,7 +495,7 @@ class PostProcessing(RunEnvironment):
         return index
 
     @staticmethod
-    def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], **kwargs):
+    def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, **kwargs):
         """
         Combine different forecast types into single xarray.
 
@@ -503,12 +510,8 @@ class PostProcessing(RunEnvironment):
         res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
                            coords=[index.index, ahead_names, keys], dims=['index', 'ahead', 'type'])
         for k, v in kwargs.items():
-            try:
-                match_index = np.stack(set(res.index.values) & set(v.index.values))
-                res.loc[match_index, :, k] = v.loc[match_index]
-            except AttributeError:  # v is xarray type and has no attribute .index
-                match_index = np.stack(set(res.index.values) & set(v.indexes['datetime'].values))
-                res.loc[match_index, :, k] = v.sel({'datetime': match_index}).squeeze('Stations').transpose()
+            match_index = np.stack(set(res.index.values) & set(v.indexes[time_dimension].values))
+            res.loc[match_index, :, k] = v.loc[match_index]
         return res
 
     def _get_external_data(self, station: str) -> Union[xr.DataArray, None]:
@@ -521,12 +524,15 @@ class PostProcessing(RunEnvironment):
         :param station: name of station to load external data.
         """
         try:
-            data = self.train_val_data.get_data_generator(station)
-            mean, std, transformation_method = data.get_transformation_information(variable=self.target_var)
-            external_data = self._create_observation(data, None, mean, std, transformation_method, normalised=False)
-            external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"])
-            return external_data.rename({'datetime': 'index'})
-        except KeyError:
+            data = self.train_val_data[station]
+            # target_data = data.get_Y(as_numpy=False)
+            observation = data.get_observation()
+            mean, std, transformation_method = data.get_transformation_Y()
+            # external_data = self._create_observation(target_data, None, mean, std, transformation_method, normalised=False)
+            # external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"])
+            external_data = self._create_observation(observation, None, mean, std, transformation_method, normalised=False)
+            return external_data.rename({external_data.dims[0]: 'index'})
+        except IndexError:
             return None
 
     def calculate_skill_scores(self) -> Tuple[Dict, Dict]:
@@ -540,15 +546,14 @@ class PostProcessing(RunEnvironment):
         :return: competitive and climatological skill scores
         """
         path = self.data_store.get("forecast_path")
-        window_lead_time = self.data_store.get("window_lead_time")
         skill_score_competitive = {}
         skill_score_climatological = {}
-        for station in self.test_data.stations:
-            file = os.path.join(path, f"forecasts_{station}_test.nc")
+        for station in self.test_data:
+            file = os.path.join(path, f"forecasts_{str(station)}_test.nc")
             data = xr.open_dataarray(file)
             skill_score = statistics.SkillScores(data)
             external_data = self._get_external_data(station)
-            skill_score_competitive[station] = skill_score.skill_scores(window_lead_time)
+            skill_score_competitive[station] = skill_score.skill_scores(self.window_lead_time)
             skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data,
-                                                                                          window_lead_time)
+                                                                                          self.window_lead_time)
         return skill_score_competitive, skill_score_climatological
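The reworked post-processing above matches each forecast type onto a shared, configurable time dimension before writing a single netCDF per station. A minimal, self-contained sketch of that merge step follows; it is not the MLAir implementation, and the synthetic index, the `make_series` helper, and the forecast names are placeholders that only mirror the keys used in the hunk above.

```python
# Minimal sketch of the merge step, not the MLAir implementation. Forecast
# series that share a (configurable) time dimension are written into one
# DataArray with dimensions (index, ahead, type); dates without forecasts
# simply stay NaN.
import numpy as np
import pandas as pd
import xarray as xr

time_dimension = "datetime"                                    # assumed name
full_index = pd.date_range("2020-01-01", periods=10, freq="D")
ahead = [1, 2, 3]

def make_series(seed):
    """Create a synthetic forecast covering only the first 8 days."""
    rng = np.random.default_rng(seed)
    return xr.DataArray(rng.normal(size=(8, len(ahead))),
                        coords={time_dimension: full_index[:8], "ahead": ahead},
                        dims=[time_dimension, "ahead"])

forecasts = {"CNN": make_series(0), "persi": make_series(1), "obs": make_series(2)}

res = xr.DataArray(np.full((len(full_index), len(ahead), len(forecasts)), np.nan),
                   coords=[full_index, ahead, list(forecasts)],
                   dims=["index", "ahead", "type"])
for name, values in forecasts.items():
    # align on the dates both arrays have in common
    match = np.stack(list(set(res.index.values) & set(values.indexes[time_dimension].values)))
    res.loc[match, :, name] = values.sel({time_dimension: match}).values

print(res.sel(type="CNN").shape)  # (10, 3); the last two dates remain NaN
```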
diff --git a/src/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
similarity index 67%
rename from src/run_modules/pre_processing.py
rename to mlair/run_modules/pre_processing.py
index db7fff2ab9e385ce769f86ef95d1565ea783cc95..b4185df2f6699cb20ac96e32661433e7a6164abc 100644
--- a/src/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -5,21 +5,16 @@ __date__ = '2019-11-25'
 
 import logging
 import os
-from typing import Tuple, Dict, List
+from typing import Tuple
 
 import numpy as np
 import pandas as pd
 
-from src.data_handling import DataGenerator
-from src.helpers import TimeTracking
-from src.configuration import path_config
-from src.helpers.join import EmptyQueryResult
-from src.run_modules.run_environment import RunEnvironment
-
-DEFAULT_ARGS_LIST = ["data_path", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
-DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "min_length",
-                       "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation",
-                       "extreme_values", "extremes_on_right_tail_only", "network", "data_preparation"]
+from mlair.data_handler import DataCollection
+from mlair.helpers import TimeTracking
+from mlair.configuration import path_config
+from mlair.helpers.join import EmptyQueryResult
+from mlair.run_modules.run_environment import RunEnvironment
 
 
 class PreProcessing(RunEnvironment):
@@ -59,10 +54,11 @@ class PreProcessing(RunEnvironment):
         self._run()
 
     def _run(self):
-        args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="preprocessing")
-        kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="preprocessing")
         stations = self.data_store.get("stations")
-        valid_stations = self.check_valid_stations(args, kwargs, stations, load_tmp=False, save_tmp=False, name="all")
+        data_preparation = self.data_store.get("data_preparation")
+        _, valid_stations = self.validate_station(data_preparation, stations, "preprocessing", overwrite_local_data=True)
+        if len(valid_stations) == 0:
+            raise ValueError("Couldn't find any valid data according to given parameters. Abort experiment run.")
         self.data_store.set("stations", valid_stations)
         self.split_train_val_test()
         self.report_pre_processing()
@@ -70,16 +66,14 @@ class PreProcessing(RunEnvironment):
     def report_pre_processing(self):
         """Log some metrics on data and create latex report."""
         logging.debug(20 * '##')
-        n_train = len(self.data_store.get('generator', 'train'))
-        n_val = len(self.data_store.get('generator', 'val'))
-        n_test = len(self.data_store.get('generator', 'test'))
+        n_train = len(self.data_store.get('data_collection', 'train'))
+        n_val = len(self.data_store.get('data_collection', 'val'))
+        n_test = len(self.data_store.get('data_collection', 'test'))
         n_total = n_train + n_val + n_test
         logging.debug(f"Number of all stations: {n_total}")
         logging.debug(f"Number of training stations: {n_train}")
         logging.debug(f"Number of val stations: {n_val}")
         logging.debug(f"Number of test stations: {n_test}")
-        logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'test')[0][0].shape}"
-                      f"{self.data_store.get('generator', 'test')[0][1].shape}")
         self.create_latex_report()
 
     def create_latex_report(self):
@@ -121,11 +115,12 @@ class PreProcessing(RunEnvironment):
         set_names = ["train", "val", "test"]
         df = pd.DataFrame(columns=meta_data + set_names)
         for set_name in set_names:
-            data: DataGenerator = self.data_store.get("generator", set_name)
-            for station in data.stations:
-                df.loc[station, set_name] = data.get_data_generator(station).get_transposed_label().shape[0]
-                if df.loc[station, meta_data].isnull().any():
-                    df.loc[station, meta_data] = data.get_data_generator(station).meta.loc[meta_data].values.flatten()
+            data = self.data_store.get("data_collection", set_name)
+            for station in data:
+                station_name = str(station.id_class)
+                df.loc[station_name, set_name] = station.get_Y()[0].shape[0]
+                if df.loc[station_name, meta_data].isnull().any():
+                    df.loc[station_name, meta_data] = station.id_class.meta.loc[meta_data].values.flatten()
             df.loc["# Samples", set_name] = df.loc[:, set_name].sum()
             df.loc["# Stations", set_name] = df.loc[:, set_name].count()
         df[meta_round] = df[meta_round].astype(float).round(precision)
@@ -147,7 +142,7 @@ class PreProcessing(RunEnvironment):
         Split data into subsets.
 
         Currently: train, val, test and train_val (actually this is only the merge of train and val, but as an separate
-        generator). IMPORTANT: Do not change to order of the execution of create_set_split. The train subset needs
+        data_collection). IMPORTANT: Do not change the order of the execution of create_set_split. The train subset needs
         always to be executed at first, to set a proper transformation.
         """
         fraction_of_training = self.data_store.get("fraction_of_training")
@@ -184,40 +179,20 @@ class PreProcessing(RunEnvironment):
         return train_index, val_index, test_index, train_val_index
 
     def create_set_split(self, index_list: slice, set_name: str) -> None:
-        """
-        Create subsets and store in data store.
-
-        Create the subset for given split index and stores the DataGenerator with given set name in data store as
-        `generator`. Check for all valid stations using the default (kw)args for given scope and create the
-        DataGenerator for all valid stations. Also set all transformation information, if subset is training set. Make
-        sure, that the train set is executed first, and all other subsets afterwards.
-
-        :param index_list: list of all stations to use for the set. If attribute use_all_stations_on_all_data_sets=True,
-            this list is ignored.
-        :param set_name: name to load/save all information from/to data store.
-        """
-        args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name)
-        kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope=set_name)
-        stations = args["stations"]
+        # get set stations
+        stations = self.data_store.get("stations", scope=set_name)
         if self.data_store.get("use_all_stations_on_all_data_sets"):
             set_stations = stations
         else:
             set_stations = stations[index_list]
         logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
-        # validate set
-        set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False, name=set_name)
-        self.data_store.set("stations", set_stations, scope=set_name)
-        # create set generator and store
-        set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name)
-        data_set = DataGenerator(**set_args, **kwargs)
-        self.data_store.set("generator", data_set, scope=set_name)
-        # extract transformation from train set
-        if set_name == "train":
-            self.data_store.set("transformation", data_set.transformation)
+        # create set data_collection and store
+        data_preparation = self.data_store.get("data_preparation")
+        collection, valid_stations = self.validate_station(data_preparation, set_stations, set_name)
+        self.data_store.set("stations", valid_stations, scope=set_name)
+        self.data_store.set("data_collection", collection, scope=set_name)
 
-    @staticmethod
-    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True,
-                             name=None):
+    def validate_station(self, data_preparation, set_stations, set_name=None, overwrite_local_data=False):
         """
         Check if all given stations in `all_stations` are valid.
 
@@ -225,8 +200,8 @@ class PreProcessing(RunEnvironment):
         loading time are logged in debug mode.
 
         :param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`,
-            `variables`, `interpolate_dim`, `target_dim`, `target_var`).
-        :param kwargs: positional parameters for the DataGenerator class (e.g. `start`, `interpolate_method`,
+            `variables`, `time_dim`, `target_dim`, `target_var`).
+        :param kwargs: keyword parameters for the DataGenerator class (e.g. `start`, `interpolation_method`,
             `window_lead_time`).
         :param all_stations: All stations to check.
         :param name: name to display in the logging info message
@@ -234,26 +209,31 @@ class PreProcessing(RunEnvironment):
         :return: Corrected list containing only valid station IDs.
         """
         t_outer = TimeTracking()
-        t_inner = TimeTracking(start=False)
-        logging.info(f"check valid stations started{' (%s)' % name if name else ''}")
+        logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}")
+        # calculate transformation using train data
+        if set_name == "train":
+            self.transformation(data_preparation, set_stations)
+        # start station check
+        collection = DataCollection()
         valid_stations = []
-
-        # all required arguments of the DataGenerator can be found in args, positional arguments in args and kwargs
-        data_gen = DataGenerator(**args, **kwargs)
-        for pos, station in enumerate(all_stations):
-            t_inner.run()
-            logging.info(f"check station {station} ({pos + 1} / {len(all_stations)})")
+        kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope=set_name)
+        for station in set_stations:
             try:
-                data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp,
-                                                   save_local_tmp_storage=save_tmp)
-                if data.history is None:
-                    raise AttributeError
+                dp = data_preparation.build(station, name_affix=set_name, **kwargs)
+                collection.add(dp)
                 valid_stations.append(station)
-                logging.debug(
-                    f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
-                logging.debug(f"{station}: loading time = {t_inner}")
             except (AttributeError, EmptyQueryResult):
                 continue
-        logging.info(f"run for {t_outer} to check {len(all_stations)} station(s). Found {len(valid_stations)}/"
-                     f"{len(all_stations)} valid stations.")
-        return valid_stations
+        logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/"
+                     f"{len(set_stations)} valid stations.")
+        return collection, valid_stations
+
+    def transformation(self, data_preparation, stations):
+        if hasattr(data_preparation, "transformation"):
+            kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope="train")
+            transformation_dict = data_preparation.transformation(stations, **kwargs)
+            if transformation_dict is not None:
+                self.data_store.set("transformation", transformation_dict)
+
+
+
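The new `validate_station` builds one data handler per station and silently drops stations whose build fails. The following toy sketch shows that pattern; `DummyDataHandler`, its failure condition, and `validate_stations` are made-up stand-ins, not MLAir classes.

```python
# Illustrative sketch only: each station is validated by trying to build a data
# handler for it; successful builds go into the collection, failures are skipped.
class DummyDataHandler:
    def __init__(self, station):
        if station.startswith("BAD"):  # stand-in for "no valid data found"
            raise AttributeError(f"no data for {station}")
        self.station = station

    @classmethod
    def build(cls, station, **kwargs):
        return cls(station)


def validate_stations(stations):
    collection, valid_stations = [], []
    for station in stations:
        try:
            collection.append(DummyDataHandler.build(station))
            valid_stations.append(station)
        except AttributeError:
            continue  # station is dropped, the run continues with the rest
    return collection, valid_stations


collection, valid = validate_stations(["DEBW107", "BAD001", "DEBY081"])
print(valid)  # ['DEBW107', 'DEBY081']
```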
diff --git a/src/run_modules/run_environment.py b/mlair/run_modules/run_environment.py
similarity index 86%
rename from src/run_modules/run_environment.py
rename to mlair/run_modules/run_environment.py
index 45d0a4a019b305d477838bd9ec4c5b6f920ac6fb..ecb55282f25c369d6f5eddd81907a7d28ec7d62b 100644
--- a/src/run_modules/run_environment.py
+++ b/mlair/run_modules/run_environment.py
@@ -9,11 +9,11 @@ import os
 import shutil
 import time
 
-from src.helpers.datastore import DataStoreByScope as DataStoreObject
-from src.helpers.datastore import NameNotFoundInDataStore
-from src.helpers import Logger
-from src.helpers import TimeTracking
-from src.plotting.tracker_plot import TrackPlot
+from mlair.helpers.datastore import DataStoreByScope as DataStoreObject
+from mlair.helpers.datastore import NameNotFoundInDataStore
+from mlair.helpers import Logger
+from mlair.helpers import TimeTracking
+from mlair.plotting.tracker_plot import TrackPlot
 
 
 class RunEnvironment(object):
@@ -88,12 +88,16 @@ class RunEnvironment(object):
 
     # set data store and logger (both are mutable!)
     del_by_exit = False
-    data_store = DataStoreObject()
-    logger = Logger()
+    data_store = None
+    logger = None
     tracker_list = []
 
     def __init__(self):
         """Start time tracking automatically and logs as info."""
+        if RunEnvironment.data_store is None:
+            RunEnvironment.data_store = DataStoreObject()
+        if RunEnvironment.logger is None:
+            RunEnvironment.logger = Logger()
         self.time = TimeTracking()
         logging.info(f"{self.__class__.__name__} started")
         # atexit.register(self.__del__)
@@ -117,7 +121,7 @@ class RunEnvironment(object):
                 try:
                     self.__plot_tracking()
                     self.__save_tracking()
-                    self.__copy_log_file()
+                    self.__move_log_file()
                 except FileNotFoundError:
                     pass
                 self.data_store.clear_data_store()
@@ -132,11 +136,15 @@ class RunEnvironment(object):
             logging.error(exc_val, exc_info=(exc_type, exc_val, exc_tb))
         self.__del__()
 
-    def __copy_log_file(self):
+    def __move_log_file(self):
         try:
             new_file = self.__find_file_pattern("logging_%03i.log")
-            logging.info(f"Copy log file to {new_file}")
-            shutil.copyfile(self.logger.log_file, new_file)
+            logging.info(f"Move log file to {new_file}")
+            shutil.move(self.logger.log_file, new_file)
+            try:
+                os.rmdir(os.path.dirname(self.logger.log_file))
+            except (OSError, FileNotFoundError):
+                pass
         except (NameNotFoundInDataStore, FileNotFoundError):
             pass
 
diff --git a/src/run_modules/training.py b/mlair/run_modules/training.py
similarity index 84%
rename from src/run_modules/training.py
rename to mlair/run_modules/training.py
index 1a0d7beb1ec37bb5e59a4129da58572d79a73636..f8909e15341f959455b1e8da0b0cb7502bdfa81b 100644
--- a/src/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -11,10 +11,11 @@ from typing import Union
 import keras
 from keras.callbacks import Callback, History
 
-from src.data_handling import Distributor
-from src.model_modules.keras_extensions import CallbackHandler
-from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
-from src.run_modules.run_environment import RunEnvironment
+from mlair.data_handler import KerasIterator
+from mlair.model_modules.keras_extensions import CallbackHandler
+from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
+from mlair.run_modules.run_environment import RunEnvironment
+from mlair.configuration import path_config
 
 
 class Training(RunEnvironment):
@@ -64,9 +65,9 @@ class Training(RunEnvironment):
         """Set up and run training."""
         super().__init__()
         self.model: keras.Model = self.data_store.get("model", "model")
-        self.train_set: Union[Distributor, None] = None
-        self.val_set: Union[Distributor, None] = None
-        self.test_set: Union[Distributor, None] = None
+        self.train_set: Union[KerasIterator, None] = None
+        self.val_set: Union[KerasIterator, None] = None
+        self.test_set: Union[KerasIterator, None] = None
         self.batch_size = self.data_store.get("batch_size")
         self.epochs = self.data_store.get("epochs")
         self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model")
@@ -82,6 +83,7 @@ class Training(RunEnvironment):
         if self._trainable:
             self.train()
             self.save_model()
+            self.report_training()
         else:
             logging.info("No training has started, because trainable parameter was false.")
 
@@ -102,9 +104,9 @@ class Training(RunEnvironment):
 
         :param mode: name of set, should be from ["train", "val", "test"]
         """
-        gen = self.data_store.get("generator", mode)
-        kwargs = self.data_store.create_args_dict(["permute_data", "upsampling"], scope=mode)
-        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, **kwargs))
+        collection = self.data_store.get("data_collection", mode)
+        kwargs = self.data_store.create_args_dict(["upsampling", "shuffle_batches", "batch_path"], scope=mode)
+        setattr(self, f"{mode}_set", KerasIterator(collection, self.batch_size, model=self.model, name=mode, **kwargs))
 
     def set_generators(self) -> None:
         """
@@ -128,15 +130,15 @@ class Training(RunEnvironment):
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
         logging.info(f"Train with option upsampling={self.train_set.upsampling}.")
-        logging.info(f"Train with option data_permutation={self.train_set.do_data_permutation}.")
+        logging.info(f"Train with option shuffle={self.train_set.shuffle}.")
 
         checkpoint = self.callbacks.get_checkpoint()
         if not os.path.exists(checkpoint.filepath) or self._create_new_model:
-            history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+            history = self.model.fit_generator(generator=self.train_set,
                                                steps_per_epoch=len(self.train_set),
                                                epochs=self.epochs,
                                                verbose=2,
-                                               validation_data=self.val_set.distribute_on_batches(),
+                                               validation_data=self.val_set,
                                                validation_steps=len(self.val_set),
                                                callbacks=self.callbacks.get_callbacks(as_dict=False))
         else:
@@ -146,11 +148,11 @@ class Training(RunEnvironment):
             self.model = keras.models.load_model(checkpoint.filepath)
             hist: History = self.callbacks.get_callback_by_name("hist")
             initial_epoch = max(hist.epoch) + 1
-            _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+            _ = self.model.fit_generator(generator=self.train_set,
                                          steps_per_epoch=len(self.train_set),
                                          epochs=self.epochs,
                                          verbose=2,
-                                         validation_data=self.val_set.distribute_on_batches(),
+                                         validation_data=self.val_set,
                                          validation_steps=len(self.val_set),
                                          callbacks=self.callbacks.get_callbacks(as_dict=False),
                                          initial_epoch=initial_epoch)
@@ -228,3 +230,20 @@ class Training(RunEnvironment):
         # plot learning rate
         if lr_sc:
             PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
+
+    def report_training(self):
+        data = {"mini batches": len(self.train_set),
+                "upsampling extremes": self.train_set.upsampling,
+                "shuffling": self.train_set.shuffle,
+                "created new model": self._create_new_model,
+                "epochs": self.epochs,
+                "batch size": self.batch_size}
+        import pandas as pd
+        df = pd.DataFrame.from_dict(data, orient="index", columns=["training setting"])
+        df.sort_index(inplace=True)
+        column_format = "ll"
+        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+        path_config.check_path_and_create(path)
+        df.to_latex(os.path.join(path, "training_settings.tex"), na_rep='---', column_format=column_format)
+        df.to_markdown(open(os.path.join(path, "training_settings.md"), mode="w", encoding='utf-8'),
+                       tablefmt="github")
\ No newline at end of file
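Training now consumes a `KerasIterator` directly instead of the removed `Distributor.distribute_on_batches()` generator. As a rough orientation, a minimal `keras.utils.Sequence` that could be passed to `fit_generator` in the same way might look as follows; `BatchSequence` and its parameters are illustrative, not the actual `KerasIterator` interface.

```python
# Rough orientation only: a minimal keras.utils.Sequence yielding
# (inputs, targets) mini batches, usable directly as fit_generator input.
import math

import keras
import numpy as np


class BatchSequence(keras.utils.Sequence):
    def __init__(self, x, y, batch_size=32, shuffle=False):
        self.x, self.y = x, y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(x))

    def __len__(self):  # number of mini batches per epoch
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):  # return one (inputs, targets) mini batch
        sel = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self.x[sel], self.y[sel]

    def on_epoch_end(self):  # reshuffle between epochs if requested
        if self.shuffle:
            np.random.shuffle(self.indices)


# usage sketch: model.fit_generator(generator=train_seq, steps_per_epoch=len(train_seq), ...)
```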
diff --git a/src/run.py b/mlair/run_script.py
similarity index 68%
rename from src/run.py
rename to mlair/run_script.py
index 7e262dd769204077697b7df3f3fbaedb4c012257..00a28f686bf392f76787b56a48790999e9fa5c05 100644
--- a/src/run.py
+++ b/mlair/run_script.py
@@ -1,22 +1,20 @@
 __author__ = "Lukas Leufen"
 __date__ = '2020-06-29'
 
-from src.workflows import DefaultWorkflow
+from mlair.workflows import DefaultWorkflow
 import inspect
 
 
 def run(stations=None,
-        station_type=None,
         trainable=None, create_new_model=None,
         window_history_size=None,
         experiment_date="testrun",
-        network=None,
         variables=None, statistics_per_var=None,
         start=None, end=None,
         target_var=None, target_dim=None,
         window_lead_time=None,
         dimensions=None,
-        interpolate_method=None, interpolate_dim=None, limit_nan_fill=None,
+        interpolation_method=None, interpolation_dim=None, interpolation_limit=None,
         train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None,
         use_all_stations_on_all_data_sets=None, fraction_of_train=None,
         experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None, overwrite_local_data=None,
@@ -29,15 +27,17 @@ def run(stations=None,
         model=None,
         batch_size=None,
         epochs=None,
-        data_preparation=None):
+        data_preparation=None,
+        **kwargs):
 
     params = inspect.getfullargspec(DefaultWorkflow).args
-    kwargs = {k: v for k, v in locals().items() if k in params and v is not None}
+    kwargs_default = {k: v for k, v in locals().items() if k in params and v is not None}
 
-    workflow = DefaultWorkflow(**kwargs)
+    workflow = DefaultWorkflow(**kwargs_default, **kwargs)
     workflow.run()
 
 
 if __name__ == "__main__":
-
-    run()
+    from mlair.model_modules.model_class import MyBranchedModel
+    run(statistics_per_var={'o3': 'dma8eu', "temp": "maximum"}, trainable=True,
+        create_new_model=True, model=MyBranchedModel, station_type="background")
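`run_script.py` now forwards any additional keyword arguments to the workflow on top of the explicitly named defaults. A reduced sketch of that forwarding pattern is shown below; `build_workflow` stands in for `DefaultWorkflow`, and the argument names are illustrative.

```python
# Reduced sketch of the forwarding pattern in run(): named defaults that were
# actually set are collected from locals(), then merged with the free **kwargs.
def build_workflow(**kwargs):
    return kwargs  # the real code constructs DefaultWorkflow(**kwargs)


def run(stations=None, epochs=None, batch_size=None, **kwargs):
    named = {k: v for k, v in locals().items() if k != "kwargs" and v is not None}
    return build_workflow(**named, **kwargs)


print(run(stations=["DEBW107"], epochs=10, my_custom_flag=True))
# {'stations': ['DEBW107'], 'epochs': 10, 'my_custom_flag': True}
```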
diff --git a/mlair/workflows/__init__.py b/mlair/workflows/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..27c060f10975d86aa35c1f2d45e66966002ecd63
--- /dev/null
+++ b/mlair/workflows/__init__.py
@@ -0,0 +1,2 @@
+from mlair.workflows.abstract_workflow import Workflow
+from mlair.workflows.default_workflow import DefaultWorkflow, DefaultWorkflowHPC
\ No newline at end of file
diff --git a/src/workflows/abstract_workflow.py b/mlair/workflows/abstract_workflow.py
similarity index 94%
rename from src/workflows/abstract_workflow.py
rename to mlair/workflows/abstract_workflow.py
index 5d4e62c8a2e409e865f43412a6757a9cb4e4b1f3..d3fe480fdfe09393fbf2051d8795735e9217a8ad 100644
--- a/src/workflows/abstract_workflow.py
+++ b/mlair/workflows/abstract_workflow.py
@@ -5,7 +5,7 @@ __date__ = '2020-06-26'
 
 from collections import OrderedDict
 
-from src import RunEnvironment
+from mlair import RunEnvironment
 
 
 class Workflow:
@@ -26,4 +26,4 @@ class Workflow:
         """Run workflow embedded in a run environment and according to the stage's ordering."""
         with RunEnvironment():
             for stage, kwargs in self._registry.items():
-                stage(**kwargs)
\ No newline at end of file
+                stage(**kwargs)
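The abstract workflow keeps its stages in an ordered registry and simply instantiates them in insertion order inside a run environment. The toy `MiniWorkflow` below illustrates the idea; its `add` method and the stage classes are stand-ins, since only the `run` loop of the real `Workflow` is visible in this diff.

```python
# Toy illustration of the registry idea with stand-in stage classes.
from collections import OrderedDict


class StageA:
    def __init__(self, **kwargs):
        print("StageA", kwargs)


class StageB:
    def __init__(self, **kwargs):
        print("StageB", kwargs)


class MiniWorkflow:
    def __init__(self):
        self._registry = OrderedDict()

    def add(self, stage, **kwargs):
        self._registry[stage] = kwargs

    def run(self):
        for stage, kwargs in self._registry.items():
            stage(**kwargs)  # same execution loop as above


wf = MiniWorkflow()
wf.add(StageA, epochs=2)
wf.add(StageB)
wf.run()
```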
diff --git a/src/workflows/default_workflow.py b/mlair/workflows/default_workflow.py
similarity index 81%
rename from src/workflows/default_workflow.py
rename to mlair/workflows/default_workflow.py
index bbad7428cb4ffa81e968420332caaaca7925fdc5..3dba7e6c5c5773fa4d74860b2cba67a5804123b7 100644
--- a/src/workflows/default_workflow.py
+++ b/mlair/workflows/default_workflow.py
@@ -4,9 +4,9 @@ __author__ = "Lukas Leufen"
 __date__ = '2020-06-26'
 
 import inspect
-from src.helpers import remove_items
-from src.run_modules import ExperimentSetup, PreProcessing, PartitionCheck, ModelSetup, Training, PostProcessing
-from src.workflows.abstract_workflow import Workflow
+from mlair.helpers import remove_items
+from mlair.run_modules import ExperimentSetup, PreProcessing, PartitionCheck, ModelSetup, Training, PostProcessing
+from mlair.workflows.abstract_workflow import Workflow
 
 
 class DefaultWorkflow(Workflow):
@@ -14,17 +14,15 @@ class DefaultWorkflow(Workflow):
     the mentioned ordering."""
 
     def __init__(self, stations=None,
-        station_type=None,
         trainable=None, create_new_model=None,
         window_history_size=None,
         experiment_date="testrun",
-        network=None,
         variables=None, statistics_per_var=None,
         start=None, end=None,
         target_var=None, target_dim=None,
         window_lead_time=None,
         dimensions=None,
-        interpolate_method=None, interpolate_dim=None, limit_nan_fill=None,
+        interpolation_method=None, time_dim=None, limit_nan_fill=None,
         train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None,
         use_all_stations_on_all_data_sets=None, fraction_of_train=None,
         experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None, overwrite_local_data=None,
@@ -37,13 +35,14 @@ class DefaultWorkflow(Workflow):
         model=None,
         batch_size=None,
         epochs=None,
-        data_preparation=None):
+        data_preparation=None,
+        **kwargs):
         super().__init__()
 
         # extract all given kwargs arguments
         params = remove_items(inspect.getfullargspec(self.__init__).args, "self")
-        kwargs = {k: v for k, v in locals().items() if k in params and v is not None}
-        self._setup(**kwargs)
+        kwargs_default = {k: v for k, v in locals().items() if k in params and v is not None}
+        self._setup(**kwargs_default, **kwargs)
 
     def _setup(self, **kwargs):
         """Set up default workflow."""
@@ -59,17 +58,15 @@ class DefaultWorkflowHPC(Workflow):
     Training and PostProcessing in exact the mentioned ordering."""
 
     def __init__(self, stations=None,
-        station_type=None,
         trainable=None, create_new_model=None,
         window_history_size=None,
         experiment_date="testrun",
-        network=None,
         variables=None, statistics_per_var=None,
         start=None, end=None,
         target_var=None, target_dim=None,
         window_lead_time=None,
         dimensions=None,
-        interpolate_method=None, interpolate_dim=None, limit_nan_fill=None,
+        interpolation_method=None, time_dim=None, limit_nan_fill=None,
         train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None,
         use_all_stations_on_all_data_sets=None, fraction_of_train=None,
         experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None, overwrite_local_data=None,
@@ -82,13 +79,13 @@ class DefaultWorkflowHPC(Workflow):
         model=None,
         batch_size=None,
         epochs=None,
-        data_preparation=None):
+        data_preparation=None, **kwargs):
         super().__init__()
 
         # extract all given kwargs arguments
         params = remove_items(inspect.getfullargspec(self.__init__).args, "self")
-        kwargs = {k: v for k, v in locals().items() if k in params and v is not None}
-        self._setup(**kwargs)
+        kwargs_default = {k: v for k, v in locals().items() if k in params and v is not None}
+        self._setup(**kwargs_default, **kwargs)
 
     def _setup(self, **kwargs):
         """Set up default workflow."""
diff --git a/requirements.txt b/requirements.txt
index 71bb1338effff38092510982d4a2c1f37f7b026a..7da29a05b748531fd4ec327ff17f432ff1ecaabb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -38,9 +38,9 @@ pydot==1.4.1
 pyparsing==2.4.6
 pyproj==2.5.0
 pyshp==2.1.0
-pytest==5.3.5
-pytest-cov==2.8.1
-pytest-html==2.0.1
+pytest==6.0.0
+pytest-cov==2.10.0
+pytest-html==2.1.1
 pytest-lazy-fixture==0.6.3
 pytest-metadata==1.8.0
 pytest-sugar
diff --git a/run.py b/run.py
index a9d8190628e1692c4b2812d3c8790bccd6b1b589..15f30a7ee775948fa744832a464562cd40c3e460 100644
--- a/run.py
+++ b/run.py
@@ -2,7 +2,7 @@ __author__ = "Lukas Leufen"
 __date__ = '2020-06-29'
 
 import argparse
-from src.workflows import DefaultWorkflow
+from mlair.workflows import DefaultWorkflow
 
 
 def main(parser_args):
diff --git a/run_HPC.py b/run_HPC.py
index fc2ead406469f0a254f5819e43c1e0d3542bb8d9..d6dbb4dc61e88a1e139b3cbe549bc6a3f2f0ab8a 100644
--- a/run_HPC.py
+++ b/run_HPC.py
@@ -2,7 +2,7 @@ __author__ = "Lukas Leufen"
 __date__ = '2020-06-29'
 
 import argparse
-from src.workflows import DefaultWorkflowHPC
+from mlair.workflows import DefaultWorkflowHPC
 
 
 def main(parser_args):
diff --git a/run_hourly.py b/run_hourly.py
index 682988f6f730d02be713c074dd63fc732e2868dc..b831cf1e1ee733a3c652c6cea364013b44cf2c0d 100644
--- a/run_hourly.py
+++ b/run_hourly.py
@@ -3,7 +3,7 @@ __date__ = '2019-11-14'
 
 import argparse
 
-from src.workflows import DefaultWorkflow
+from mlair.workflows import DefaultWorkflow
 
 
 def main(parser_args):
diff --git a/run_zam347.py b/run_zam347.py
index 2d351a8925e67b0bdfc010e92a3937435e160b2f..9027bec807ad9beafcdac573a70aa32d34491034 100644
--- a/run_zam347.py
+++ b/run_zam347.py
@@ -5,13 +5,13 @@ import argparse
 import json
 import logging
 
-from src.run_modules.experiment_setup import ExperimentSetup
-from src.run_modules.model_setup import ModelSetup
-from src.run_modules.post_processing import PostProcessing
-from src.run_modules.pre_processing import PreProcessing
-from src.run_modules.run_environment import RunEnvironment
-from src.run_modules.training import Training
-from src.workflows import DefaultWorkflowHPC
+from mlair.run_modules.experiment_setup import ExperimentSetup
+from mlair.run_modules.model_setup import ModelSetup
+from mlair.run_modules.post_processing import PostProcessing
+from mlair.run_modules.pre_processing import PreProcessing
+from mlair.run_modules.run_environment import RunEnvironment
+from mlair.run_modules.training import Training
+from mlair.workflows import DefaultWorkflowHPC
 
 
 def load_stations():
diff --git a/setup.py b/setup.py
index 8e08e921f5fb728f7b1758e4bb385efc7d71c29b..f708febb5a70c957a91059d840a1f4e140ad35c0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,7 @@
 
 import setuptools
 
-from src import __version__, __author__, __email__
+from mlair import __version__, __author__, __email__
 
 
 with open("README.md", "r") as fh:
@@ -9,7 +9,7 @@ with open("README.md", "r") as fh:
 
 
 setuptools.setup(
-    name="mlt",
+    name="mlair",
     version=__version__,
     author=__author__,
     author_email=__email__,
@@ -17,8 +17,7 @@ setuptools.setup(
     long_description=long_description,
     long_description_content_type="text/markdown",
     url="https://gitlab.version.fz-juelich.de/toar/machinelearningtools",
-    package_dir={'': 'src'},
-    packages=setuptools.find_packages(where="src"),
+    packages=setuptools.find_packages(),
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",  #  to be adjusted
diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py
deleted file mode 100644
index f50775900c053cbef0c94e6a3e2743c9a017bf88..0000000000000000000000000000000000000000
--- a/src/data_handling/bootstraps.py
+++ /dev/null
@@ -1,383 +0,0 @@
-"""
-Collections of bootstrap methods and classes.
-
-How to use
-----------
-
-test
-
-"""
-
-__author__ = 'Felix Kleinert, Lukas Leufen'
-__date__ = '2020-02-07'
-
-
-import logging
-import os
-import re
-from typing import List, Union, Pattern, Tuple
-
-import dask.array as da
-import keras
-import numpy as np
-import xarray as xr
-
-from src import helpers
-from src.data_handling.data_generator import DataGenerator
-
-
-class BootStrapGenerator(keras.utils.Sequence):
-    """
-    Generator that returns bootstrapped history objects for given boot index while iteration.
-
-    generator for bootstraps as keras sequence inheritance. Initialise with number of boots, the original history, the
-    shuffled data, all used variables and the current shuffled variable. While iterating over this generator, it returns
-    the bootstrapped history for given boot index (this is the iterator index) in the same format like the original
-    history ready to use. Note, that in some cases some samples can contain nan values (in these cases the entire data
-    row is null, not only single entries).
-    """
-
-    def __init__(self, number_of_boots: int, history: xr.DataArray, shuffled: xr.DataArray, variables: List[str],
-                 shuffled_variable: str):
-        """
-        Set up the generator.
-
-        :param number_of_boots: number of bootstrap realisations
-        :param history: original history (the ground truth)
-        :param shuffled: the shuffled history
-        :param variables: list with all variables of interest
-        :param shuffled_variable: name of the variable that shall be bootstrapped
-        """
-        self.number_of_boots = number_of_boots
-        self.variables = variables
-        self.history_orig = history
-        self.history = history.sel(variables=helpers.remove_items(self.variables, shuffled_variable))
-        self.shuffled = shuffled.sel(variables=shuffled_variable)
-
-    def __len__(self) -> int:
-        """
-        Return number of bootstraps.
-
-        :return: number of bootstraps
-        """
-        return self.number_of_boots
-
-    def __getitem__(self, index: int) -> xr.DataArray:
-        """
-        Return bootstrapped history for given bootstrap index in same index structure like the original history object.
-
-        :param index: boot index e [0, nboots-1]
-        :return: bootstrapped history ready to use
-        """
-        logging.debug(f"boot: {index}")
-        boot_hist = self.history.copy()
-        boot_hist = boot_hist.combine_first(self.__get_shuffled(index))
-        return boot_hist.reindex_like(self.history_orig)
-
-    def __get_shuffled(self, index: int) -> xr.DataArray:
-        """
-        Return shuffled data for given boot index from shuffled attribute.
-
-        :param index: boot index e [0, nboots-1]
-        :return: shuffled data
-        """
-        shuffled_var = self.shuffled.sel(boots=index).expand_dims("variables").drop("boots")
-        return shuffled_var.transpose("datetime", "window", "Stations", "variables")
-
-
-class CreateShuffledData:
-    """
-    Verify and create shuffled data for all data contained in given data generator class.
-
-    Starts automatically on initialisation, no further calls are required. Check and new creations are all performed
-    inside bootstrap_path.
-    """
-
-    def __init__(self, data: DataGenerator, number_of_bootstraps: int, bootstrap_path: str):
-        """
-        Shuffled data is automatically created in initialisation.
-
-        :param data: data to shuffle
-        :param number_of_bootstraps:
-        :param bootstrap_path: Path to find and store the bootstraps
-        """
-        self.data = data
-        self.number_of_bootstraps = number_of_bootstraps
-        self.bootstrap_path = bootstrap_path
-        self.create_shuffled_data()
-
-    def create_shuffled_data(self) -> None:
-        """
-        Create shuffled data.
-
-        Use original test data, add dimension 'boots' with length number of bootstraps and insert randomly selected
-        variables. If there is a suitable local file for requested window size and number of bootstraps, no additional
-        file will be created inside this function.
-        """
-        logging.info("create / check shuffled bootstrap data")
-        variables_str = '_'.join(sorted(self.data.variables))
-        window = self.data.window_history_size
-        for station in self.data.stations:
-            valid, nboot = self.valid_bootstrap_file(station, variables_str, window)
-            if not valid:
-                logging.info(f'create bootstap data for {station}')
-                hist = self.data.get_data_generator(station).get_transposed_history()
-                file_path = self._set_file_path(station, variables_str, window, nboot)
-                hist = hist.expand_dims({'boots': range(nboot)}, axis=-1)
-                shuffled_variable = []
-                chunks = (100, *hist.shape[1:3], hist.shape[-1])
-                for i, var in enumerate(hist.coords['variables']):
-                    single_variable = hist.sel(variables=var).values
-                    shuffled_variable.append(self.shuffle(single_variable, chunks=chunks))
-                shuffled_variable_da = da.stack(shuffled_variable, axis=-2).rechunk("auto")
-                shuffled_data = xr.DataArray(shuffled_variable_da, coords=hist.coords, dims=hist.dims)
-                shuffled_data.to_netcdf(file_path)
-
-    def _set_file_path(self, station: str, variables: str, window: int, nboots: int) -> str:
-        """
-        Set file name.
-
-        Set file name following naming convention <station>_<var1>_<var2>_..._hist<window>_nboots<nboots>_shuffled.nc
-        and create joined path using bootstrap_path attribute set on initialisation.
-
-        :param station: station name
-        :param variables: variables already preprocessed as single string with all variables seperated by underscore
-        :param window: window length
-        :param nboots: number of boots
-        :return: full file path
-        """
-        file_name = f"{station}_{variables}_hist{window}_nboots{nboots}_shuffled.nc"
-        return os.path.join(self.bootstrap_path, file_name)
-
-    def valid_bootstrap_file(self, station: str, variables: str, window: int) -> [bool, Union[None, int]]:
-        """
-        Compare local bootstrap file with given settings for station, variables, window and number of bootstraps.
-
-        If a match was found, this method returns a tuple (True, None). In any other case, it returns (False,
-        max_nboot), where max_nboot is the highest boot number found in the local storage. A match is defined so that
-        the window length is ge than given window size form args and the number of boots is also ge than the given
-        number of boots from this class. Furthermore, this functions deletes local files, if the match the station
-        pattern but don't fit the window and bootstrap condition. This is performed, because it is assumed, that the
-        corresponding file will be created with a longer or at the least same window size and numbers of bootstraps.
-
-        :param station: name of the station to validate
-        :param variables: all variables already merged in single string seperated by underscore
-        :param window: required window size
-        :return: tuple containing information if valid file was found first and second the number of boots that needs to
-            be used for the new boot creation (this is only relevant, if no valid file was found - otherwise the return
-            statement is anyway None).
-        """
-        regex = re.compile(rf"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled")
-        max_nboot = self.number_of_bootstraps
-        for file in os.listdir(self.bootstrap_path):
-            match = regex.match(file)
-            if match:
-                window_file = int(match.group(1))
-                nboot_file = int(match.group(2))
-                max_nboot = max([max_nboot, nboot_file])
-                if (window_file >= window) and (nboot_file >= self.number_of_bootstraps):
-                    return True, None
-                else:
-                    os.remove(os.path.join(self.bootstrap_path, file))
-        return False, max_nboot
-
-    @staticmethod
-    def shuffle(data: da.array, chunks: Tuple) -> da.core.Array:
-        """
-        Shuffle randomly from given data (draw elements with replacement).
-
-        :param data: data to shuffle
-        :param chunks: chunk size for dask
-        :return: shuffled data as dask core array (not computed yet)
-        """
-        size = data.shape
-        return da.random.choice(data.reshape(-1, ), size=size, chunks=chunks)
-
-
-class BootStraps:
-    """
-    Main class to perform bootstrap operations.
-
-    This class requires a DataGenerator object and a path, where to find and store all data related to the bootstrap
-    operation. In initialisation, this class will automatically call the class CreateShuffleData to set up the shuffled
-    data sets. How to use BootStraps:
-
-    * call .get_generator(<station>, <variable>) to get a generator for given station and variable combination that \
-        iterates over all bootstrap realisations (as keras sequence)
-    * call .get_labels(<station>) to get the measured observations in the same format as bootstrap predictions
-    * call .get_bootstrap_predictions(<station>, <variable>) to get the bootstrapped predictions
-    * call .get_orig_prediction(<station>) to get the non-bootstrapped predictions (referred as original predictions)
-    """
-
-    def __init__(self, data: DataGenerator, bootstrap_path: str, number_of_bootstraps: int = 10):
-        """
-        Automatically check and create (if needed) shuffled data on initialisation.
-
-        :param data: a data generator object to get data / history
-        :param bootstrap_path: path to find and store the bootstrap data
-        :param number_of_bootstraps: the number of bootstrap realisations
-        """
-        self.data = data
-        self.number_of_bootstraps = number_of_bootstraps
-        self.bootstrap_path = bootstrap_path
-        CreateShuffledData(data, number_of_bootstraps, bootstrap_path)
-
-    @property
-    def stations(self) -> List[str]:
-        """
-        Station property inherits directly from data generator object.
-
-        :return: list with all stations
-        """
-        return self.data.stations
-
-    @property
-    def variables(self) -> List[str]:
-        """
-        Variables property inherits directly from data generator object.
-
-        :return: list with all variables
-        """
-        return self.data.variables
-
-    @property
-    def window_history_size(self) -> int:
-        """
-        Window history size property inherits directly from data generator object.
-
-        :return: the window history size
-        """
-        return self.data.window_history_size
-
-    def get_generator(self, station: str, variable: str) -> BootStrapGenerator:
-        """
-        Return the actual generator to use for the bootstrap evaluation.
-
-        The generator requires information on station and bootstrapped variable. There is only a loop on the bootstrap
-        realisation and not on stations or variables.
-
-        :param station: name of the station
-        :param variable: name of the variable to bootstrap
-        :return: BootStrapGenerator class ready to use.
-        """
-        hist, _ = self.data[station]
-        shuffled_data = self._load_shuffled_data(station, self.variables).reindex_like(hist)
-        return BootStrapGenerator(self.number_of_bootstraps, hist, shuffled_data, self.variables, variable)
-
-    def get_labels(self, station: str) -> np.ndarray:
-        """
-        Repeat labels for given key by the number of boots and returns as single array.
-
-        :param station: name of station
-        :return: repeated labels as single array
-        """
-        labels = self.data[station][1]
-        return np.tile(labels.data, (self.number_of_bootstraps, 1))
-
-    def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN") -> np.ndarray:
-        """
-        Repeat predictions from given file(_name) in path by the number of boots.
-
-        :param path: path to file
-        :param file_name: file name
-        :param prediction_name: name of the prediction to select from loaded file (default CNN)
-        :return: repeated predictions
-        """
-        file = os.path.join(path, file_name)
-        prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
-        vals = np.tile(prediction.data, (self.number_of_bootstraps, 1))
-        return vals[~np.isnan(vals).any(axis=1), :]
-
-    def _load_shuffled_data(self, station: str, variables: List[str]) -> xr.DataArray:
-        """
-        Load shuffled data from bootstrap path.
-
-        Data is stored as '<station>_<var1>_<var2>_..._hist<histsize>_nboots<nboots>_shuffled.nc', e.g.
-        'DEBW107_cloudcover_no_no2_temp_u_v_hist13_nboots20_shuffled.nc'
-
-        :param station: name of station
-        :param variables: list of variables
-        :return: shuffled data as xarray
-        """
-        file_name = self._get_shuffled_data_file(station, variables)
-        shuffled_data = xr.open_dataarray(file_name, chunks=100)
-        return shuffled_data
-
-    def _get_shuffled_data_file(self, station: str, variables: List[str]) -> str:
-        """
-        Look for data file using regular expressions and returns found file or raise FileNotFoundError.
-
-        :param station: name of station
-        :param variables: name of variables
-        :return: found file with complete path
-        """
-        files = os.listdir(self.bootstrap_path)
-        regex = self._create_file_regex(station, variables)
-        file = self._filter_files(regex, files, self.window_history_size, self.number_of_bootstraps)
-        if file:
-            return os.path.join(self.bootstrap_path, file)
-        else:
-            raise FileNotFoundError(f"Could not find a file to match pattern {regex}")
-
-    @staticmethod
-    def _create_file_regex(station: str, variables: List[str]) -> Pattern:
-        """
-        Create regex for given station and variables.
-
-        With this regex, it is possible to look for shuffled data with pattern:
-        `<station>(_<var>)*_hist(<hist>)_nboots(<nboots>)_shuffled.nc`
-
-        :param station: station name to use as prefix
-        :param variables: variables to add after station
-        :return: compiled regular expression
-        """
-        var_regex = "".join([rf"(_\w+)*_{v}(_\w+)*" for v in sorted(variables)])
-        regex = re.compile(rf"{station}{var_regex}_hist(\d+)_nboots(\d+)_shuffled\.nc")
-        return regex
-
-    @staticmethod
-    def _filter_files(regex: Pattern, files: List[str], window: int, nboot: int) -> Union[str, None]:
-        """
-        Filter list of files by regex.
-
-        Regex has to be structured to match the following string structure
-        `<station>(_<var>)*_hist(<hist>)_nboots(<nboots>)_shuffled.nc`. Hist and nboots values have to be included as
-        group. All matches are compared to given window and nboot parameters. A valid file must have the same value (or
-        larger) than these parameters and contain all variables.
-
-        :param regex: compiled regular expression pattern following the style from method description
-        :param files: list of file names to filter
-        :param window: minimum length of window to look for
-        :param nboot: minimal number of boots to search
-        :return: matching file name or None, if no valid file was found
-        """
-        for f in files:
-            match = regex.match(f)
-            if match:
-                last = match.lastindex
-                if (int(match.group(last - 1)) >= window) and (int(match.group(last)) >= nboot):
-                    return f
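-
-    # A minimal usage sketch (not part of the original module; the file name and parameters are chosen
-    # for illustration and assume `BootStraps` is the enclosing class, as in the __main__ block below):
-    #
-    #   regex = BootStraps._create_file_regex("DEBW107", ["no", "temp"])
-    #   BootStraps._filter_files(regex, ["DEBW107_no_temp_hist13_nboots20_shuffled.nc"],
-    #                            window=10, nboot=20)
-    #   # -> returns the file name, because hist=13 >= 10 and nboots=20 >= 20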
-
-
-if __name__ == "__main__":
-
-    from src.run_modules.experiment_setup import ExperimentSetup
-    from src.run_modules.run_environment import RunEnvironment
-    from src.run_modules.pre_processing import PreProcessing
-
-    formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    with RunEnvironment() as run_env:
-        ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013'],
-                        station_type='background', trainable=True, window_history_size=9)
-        PreProcessing()
-
-        data = run_env.data_store.get("generator", "general.test")
-        path = run_env.data_store.get("bootstrap_path", "general")
-        number_bootstraps = 10
-
-        boots = BootStraps(data, path, number_bootstraps)
-        for b in boots.boot_strap_generator():
-            a, c = b
-        logging.info(f"len is {len(boots.get_boot_strap_meta())}")
diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
deleted file mode 100644
index 2600afcbd8948c26a2b4cf37329b424cac69f40a..0000000000000000000000000000000000000000
--- a/src/data_handling/data_distributor.py
+++ /dev/null
@@ -1,132 +0,0 @@
-"""
-Data Distribution Module.
-
-How to use
-----------
-
-Create a distributor object from a generator object and pass it to the fit_generator method. Provide the number of
-steps per epoch with the distributor's length method.
-
-.. code-block:: python
-
-    model = YourKerasModel()
-    data_generator = DataGenerator(*args, **kwargs)
-    data_distributor = Distributor(data_generator, model, **kwargs)
-    history = model.fit_generator(generator=data_distributor.distribute_on_batches(),
-                                  steps_per_epoch=len(data_distributor),
-                                  epochs=10,)
-
-Additionally, a validation data set can be passed using the length and distribute methods.
-"""
-
-from __future__ import generator_stop
-
-__author__ = "Lukas Leufen, Felix Kleinert"
-__date__ = '2019-12-05'
-
-import math
-
-import keras
-import numpy as np
-
-from src.data_handling.data_generator import DataGenerator
-
-
-class Distributor(keras.utils.Sequence):
-    """Distribute data generator elements according to mini batch size."""
-
-    def __init__(self, generator: DataGenerator, model: keras.models, batch_size: int = 256,
-                 permute_data: bool = False, upsampling: bool = False):
-        """
-        Set up distributor.
-
-        :param generator: The generator object must be iterable and return inputs and targets on each iteration
-        :param model: a keras model with one or more output branches
-        :param batch_size: batch size to use
-        :param permute_data: data is randomly permuted if enabled on each train step
-        :param upsampling: upsample data with the extremes data from the generator object and shuffle the data, or use
-            only the standard input data.
-        """
-        self.generator = generator
-        self.model = model
-        self.batch_size = batch_size
-        self.do_data_permutation = permute_data
-        self.upsampling = upsampling
-
-    def _get_model_rank(self):
-        mod_out = self.model.output_shape
-        if isinstance(mod_out, tuple):
-            # only one output branch: (None, ahead)
-            mod_rank = 1
-        elif isinstance(mod_out, list):
-            # multiple output branches, e.g.: [(None, ahead), (None, ahead)]
-            mod_rank = len(mod_out)
-        else:  # pragma: no cover
-            raise TypeError("model output shape must either be tuple or list.")
-        return mod_rank
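-
-    # Shape sketch (illustrative, not part of the original code): a single-branch model reports
-    # output_shape == (None, ahead) and yields mod_rank = 1, whereas a model with two output
-    # branches reports output_shape == [(None, ahead), (None, ahead)] and yields mod_rank = 2.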
-
-    def _get_number_of_mini_batches(self, values):
-        return math.ceil(values.shape[0] / self.batch_size)
-
-    def _permute_data(self, x, y):
-        """
-        Permute inputs x and labels y if permutation is enabled in instance.
-
-        :param x: inputs
-        :param y: labels
-        :return: permuted or original data
-        """
-        if self.do_data_permutation:
-            p = np.random.permutation(len(x))  # equiv to .shape[0]
-            x = x[p]
-            y = y[p]
-        return x, y
-
-    def distribute_on_batches(self, fit_call=True):
-        """
-        Create generator object to distribute mini batches.
-
-        Split data from the given generator object (usually for a single station) according to the given batch size.
-        Also perform upsampling if enabled and random shuffling (if either data permutation or upsampling is enabled).
-        Lastly, duplicate the targets if the provided model has multiple output branches.
-
-        :param fit_call: switch to exit the while loop after one full pass. This is used to determine the length of all
-            distributed mini batches. By default, fit_call is True to obtain an infinite loop for training.
-        :return: yields next mini batch
-        """
-        while True:
-            for k, v in enumerate(self.generator):
-                # get rank of output
-                mod_rank = self._get_model_rank()
-                # get data
-                x_total = np.copy(v[0])
-                y_total = np.copy(v[1])
-                if self.upsampling:
-                    try:
-                        s = self.generator.get_data_generator(k)
-                        x_total = np.concatenate([x_total, np.copy(s.get_extremes_history())], axis=0)
-                        y_total = np.concatenate([y_total, np.copy(s.get_extremes_label())], axis=0)
-                    except AttributeError:  # no extremes history / labels available, copy will fail
-                        pass
-                # get number of mini batches
-                num_mini_batches = self._get_number_of_mini_batches(x_total)
-                # permute order for mini-batches
-                x_total, y_total = self._permute_data(x_total, y_total)
-                for prev, curr in enumerate(range(1, num_mini_batches + 1)):
-                    x = x_total[prev * self.batch_size:curr * self.batch_size, ...]
-                    y = [y_total[prev * self.batch_size:curr * self.batch_size, ...] for _ in range(mod_rank)]
-                    if x is not None:  # pragma: no branch
-                        yield x, y
-                        if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call:
-                            return
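-
-    # Batch layout sketch (illustrative, not part of the original code): with batch_size=256 and a
-    # model with two output branches (mod_rank=2), a station holding 600 samples is split into
-    # ceil(600 / 256) = 3 mini batches; each yields x with shape (<=256, ...) and y as a list of two
-    # identical target blocks, the last batch containing the remaining 600 - 2 * 256 = 88 samples.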
-
-    def __len__(self) -> int:
-        """
-        Total number of distributed mini batches.
-
-        :return: the length of the distribute on batches object
-        """
-        num_batch = 0
-        for _ in self.distribute_on_batches(fit_call=False):
-            num_batch += 1
-        return num_batch
diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
deleted file mode 100644
index 8e14d019634d134d01b7edec92021aed20b59ecc..0000000000000000000000000000000000000000
--- a/src/data_handling/data_generator.py
+++ /dev/null
@@ -1,366 +0,0 @@
-"""Data Generator class to handle large arrays for machine learning."""
-
-__author__ = 'Felix Kleinert, Lukas Leufen'
-__date__ = '2019-11-07'
-
-import logging
-import os
-import pickle
-from typing import Union, List, Tuple, Any, Dict
-
-import dask.array as da
-import keras
-import xarray as xr
-
-from src import helpers
-from src.data_handling.data_preparation import AbstractDataPrep
-from src.helpers.join import EmptyQueryResult
-
-number = Union[float, int]
-num_or_list = Union[number, List[number]]
-data_or_none = Union[xr.DataArray, None]
-
-
-class DataGenerator(keras.utils.Sequence):
-    """
-    This class is a generator to handle large arrays for machine learning.
-
-    .. code-block:: python
-
-        data_generator = DataGenerator(**args, **kwargs)
-
-    Data generator items can be called manually by position (integer) or station id (string). Methods also accept lists
-    with exactly one integer or string entry.
-
-    .. code-block::
-
-        # select generator elements by position index
-        first_element = data_generator.get_data_generator([0])  # 1st element
-        n_element = data_generator.get_data_generator([4])  # 5th element
-
-        # select by name
-        station_xy = data_generator.get_data_generator(["station_xy"])  # will raise KeyError if not available
-
-    If used as iterator or directly called by get item method, the data generator class returns transposed labels and
-    history object from underlying data preparation class DataPrep.
-
-    .. code-block:: python
-
-        # select history and label by position
-        hist, labels = data_generator[0]
-        # by name
-        hist, labels = data_generator["station_xy"]
-        # as iterator
-        for (hist, labels) in data_generator:
-            pass
-
-    This class can also be used with keras' fit_generator and predict_generator. Individual stations are the iterables.
-    """
-
-    def __init__(self, data_path: str, stations: Union[str, List[str]], variables: List[str],
-                 interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
-                 interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
-                 window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None,
-                 data_preparation=None, **kwargs):
-        """
-        Set up data generator.
-
-        :param data_path: path to data
-        :param stations: list with all stations to include
-        :param variables: list with all used variables
-        :param interpolate_dim: dimension along which interpolation is applied
-        :param target_dim: dimension of target variable
-        :param target_var: name of target variable
-        :param station_type: TOAR station type classification (background, traffic)
-        :param interpolate_method: method of interpolation
-        :param limit_nan_fill: maximum gap in data to fill by interpolation
-        :param window_history_size: length of the history window
-        :param window_lead_time: length of the label window
-        :param transformation: transformation method to apply on data
-        :param extreme_values: set up the extreme value upsampling
-        :param kwargs: additional kwargs that are used in either DataPrep (transformation, start / stop period, ...)
-            or extreme values
-        """
-        self.data_path = os.path.abspath(data_path)
-        self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
-        if not os.path.exists(self.data_path_tmp):
-            os.makedirs(self.data_path_tmp)
-        self.stations = helpers.to_list(stations)
-        self.variables = variables
-        self.interpolate_dim = interpolate_dim
-        self.target_dim = target_dim
-        self.target_var = target_var
-        self.station_type = station_type
-        self.interpolate_method = interpolate_method
-        self.limit_nan_fill = limit_nan_fill
-        self.window_history_size = window_history_size
-        self.window_lead_time = window_lead_time
-        self.extreme_values = extreme_values
-        self.DataPrep = data_preparation if data_preparation is not None else AbstractDataPrep
-        self.kwargs = kwargs
-        self.transformation = self.setup_transformation(transformation)
-
-    def __repr__(self):
-        """Display all class attributes."""
-        return f"DataGenerator(path='{self.data_path}', stations={self.stations}, " \
-               f"variables={self.variables}, station_type={self.station_type}, " \
-               f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \
-               f"target_var='{self.target_var}', **{self.kwargs})"
-
-    def __len__(self):
-        """Return the number of stations."""
-        return len(self.stations)
-
-    def __iter__(self) -> "DataGenerator":
-        """
-        Define the __iter__ part of the iterator protocol to iterate through this generator.
-
-        Sets the private attribute `_iterator` to 0.
-        """
-        self._iterator = 0
-        return self
-
-    def __next__(self) -> Tuple[xr.DataArray, xr.DataArray]:
-        """
-        Get the data generator, and return the history and label data of this generator.
-
-        This is the implementation of the __next__ method of the iterator protocol.
-        """
-        if self._iterator < self.__len__():
-            data = self.get_data_generator()
-            self._iterator += 1
-            if data.history is not None and data.label is not None:  # pragma: no branch
-                return data.get_transposed_history(), data.get_transposed_label()
-            else:
-                self.__next__()  # pragma: no cover
-        else:
-            raise StopIteration
-
-    def __getitem__(self, item: Union[str, int]) -> Tuple[xr.DataArray, xr.DataArray]:
-        """
-        Define the get item method for this generator.
-
-        Retrieve data from generator and return history and labels.
-
-        :param item: station key to choose the data generator.
-        :return: The generator's time series of history data and its labels
-        """
-        data = self.get_data_generator(key=item)
-        return data.get_transposed_history(), data.get_transposed_label()
-
-    def setup_transformation(self, transformation: Dict):
-        """
-        Set up transformation by extracting all relevant information.
-
-        Extract all information from the transformation dictionary. Possible keys are scope, method, mean, and std.
-        Scope can either be station or data. Station scope means that the data transformation is performed for each
-        station independently (similar to batch normalisation), whereas data scope means a transformation applied to
-        the entire data set.
-
-        * If using data scope, mean and standard deviation (each only if required by the transformation method) can
-          either be calculated accurately or as an estimate (faster implementation). This must be set in the dictionary
-          either as "mean": "accurate" or "mean": "estimate". In both cases, the required statistics are calculated and
-          saved. After these calculations, the mean key is overwritten by the actual values to use.
-        * If using station scope, no additional information is required.
-        * If a transformation should be applied on the basis of existing values, these need to be provided in the
-          respective keys "mean" and "std" (again only if required for the given method).
-
-        :param transformation: the transformation dictionary as described above.
-
-        :return: updated transformation dictionary
-        """
-        if transformation is None:
-            return
-        transformation = transformation.copy()
-        scope = transformation.get("scope", "station")
-        method = transformation.get("method", "standardise")
-        mean = transformation.get("mean", None)
-        std = transformation.get("std", None)
-        if scope == "data":
-            if isinstance(mean, str):
-                if mean == "accurate":
-                    mean, std = self.calculate_accurate_transformation(method)
-                elif mean == "estimate":
-                    mean, std = self.calculate_estimated_transformation(method)
-                else:
-                    raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' or"
-                                     f"be an array with already calculated means. Given was: {mean}")
-        elif scope == "station":
-            mean, std = None, None
-        else:
-            raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}")
-        transformation["method"] = method
-        transformation["mean"] = mean
-        transformation["std"] = std
-        return transformation
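-
-    # Illustrative examples of transformation dictionaries accepted above (assumptions for
-    # demonstration only, not part of the original code):
-    #
-    #   {"scope": "station", "method": "standardise"}                    # per-station scaling
-    #   {"scope": "data", "method": "standardise", "mean": "estimate"}   # fast, estimated statistics
-    #   {"scope": "data", "method": "centre", "mean": precomputed_mean}  # reuse an existing mean array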
-
-    def calculate_accurate_transformation(self, method: str) -> Tuple[data_or_none, data_or_none]:
-        """
-        Calculate accurate transformation statistics.
-
-        Use all stations of this generator and calculate mean and standard deviation on entire data set using dask.
-        Because there can be much data, this can take a while.
-
-        :param method: name of transformation method
-
-        :return: accurate calculated mean and std (depending on transformation)
-        """
-        tmp = []
-        mean = None
-        std = None
-        for station in self.stations:
-            try:
-                data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
-                                     **self.kwargs)
-                chunks = (1, 100, data.data.shape[2])
-                tmp.append(da.from_array(data.data.data, chunks=chunks))
-            except EmptyQueryResult:
-                continue
-        tmp = da.concatenate(tmp, axis=1)
-        if method in ["standardise", "centre"]:
-            mean = da.nanmean(tmp, axis=1).compute()
-            mean = xr.DataArray(mean.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
-            if method == "standardise":
-                std = da.nanstd(tmp, axis=1).compute()
-                std = xr.DataArray(std.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
-        else:
-            raise NotImplementedError
-        return mean, std
-
-    def calculate_estimated_transformation(self, method):
-        """
-        Calculate estimated transformation statistics.
-
-        Use all stations of this generator and calculate mean and standard deviation first for each station separately.
-        Afterwards, calculate the average mean and standard deviation as estimated statistics. Because this method does
-        not consider the length of each data set, the estimated mean differs from the real data mean. Furthermore, the
-        estimated standard deviation is assumed to be the (unweighted) mean of all deviations. This is mathematically
-        not exact, but it is still a rough and faster estimate of the true standard deviation. Do not use this method
-        for further statistical calculations. However, in the scope of data preparation for machine learning, this
-        approach is decent ("it is just scaling").
-
-        :param method: name of transformation method
-
-        :return: estimated mean and std (depending on transformation)
-        """
-        data = [[]] * len(self.variables)
-        coords = {"variables": self.variables, "Stations": range(0)}
-        mean = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
-        std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
-        for station in self.stations:
-            try:
-                data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
-                                     **self.kwargs)
-                data.transform("datetime", method=method)
-                mean = mean.combine_first(data.mean)
-                std = std.combine_first(data.std)
-                data.transform("datetime", method=method, inverse=True)
-            except EmptyQueryResult:
-                continue
-        return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None
-
-    def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True,
-                           save_local_tmp_storage: bool = True) -> AbstractDataPrep:
-        """
-        Create DataPrep object and preprocess data for given key.
-
-        Select data for given key, create a DataPrep object and
-        * apply transformation (optional)
-        * interpolate
-        * make history, labels, and observation
-        * remove nans
-        * upsample extremes (optional).
-        Processed data can be stored locally in a .pickle file. If load local tmp storage is enabled, the get data
-        generator tries first to load data from local pickle file and only creates a new DataPrep object if it couldn't
-        load this data from disk.
-
-        :param key: station key to choose the data generator.
-        :param load_local_tmp_storage: whether data should be processed from scratch or loaded as already processed
-            data from a tmp pickle file to save computational time (at the cost of additional disk space).
-        :param save_local_tmp_storage: save processed data as temporal file locally (default True)
-
-        :return: preprocessed data as a DataPrep instance
-        """
-        station = self.get_station_key(key)
-        try:
-            if not load_local_tmp_storage:
-                raise FileNotFoundError
-            data = self._load_pickle_data(station, self.variables)
-        except FileNotFoundError:
-            logging.debug(f"load not pickle data for {station}")
-            data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
-                                 **self.kwargs)
-            if self.transformation is not None:
-                data.transform("datetime", **helpers.remove_items(self.transformation, "scope"))
-            data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
-            data.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim)
-            data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
-            data.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
-            data.remove_nan(self.interpolate_dim)
-            if self.extreme_values is not None:
-                kwargs = {"extremes_on_right_tail_only": self.kwargs.get("extremes_on_right_tail_only", False)}
-                data.multiply_extremes(self.extreme_values, **kwargs)
-            if save_local_tmp_storage:
-                self._save_pickle_data(data)
-        return data
-
-    def _save_pickle_data(self, data: Any):
-        """
-        Save given data locally as .pickle in self.data_path_tmp as '<station>_<var1>_..._<varX>_<start>_<end>_.pickle'.
-
-        :param data: any data, that should be saved
-        """
-        date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}"
-        vars = '_'.join(sorted(data.variables))
-        station = ''.join(data.station)
-        file = os.path.join(self.data_path_tmp, f"{station}_{vars}_{date}_.pickle")
-        with open(file, "wb") as f:
-            pickle.dump(data, f)
-        logging.debug(f"save pickle data to {file}")
-
-    def _load_pickle_data(self, station: Union[str, List[str]], variables: List[str]) -> Any:
-        """
-        Load locally saved data from self.data_path_tmp with name '<station>_<var1>_..._<varX>_<start>_<end>_.pickle'.
-
-        :param station: station to load
-        :param variables: list of variables to load
-        :return: loaded data
-        """
-        date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}"
-        vars = '_'.join(sorted(variables))
-        station = ''.join(station)
-        file = os.path.join(self.data_path_tmp, f"{station}_{vars}_{date}_.pickle")
-        with open(file, "rb") as f:
-            data = pickle.load(f)
-        logging.debug(f"load pickle data from {file}")
-        return data
-
-    def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str:
-        """
-        Return a valid station key or raise KeyError if this wasn't possible.
-
-        :param key: station key to choose the data generator.
-        :return: station key (id from database)
-        """
-        # extract value if given as list
-        if isinstance(key, list):
-            if len(key) == 1:
-                key = key[0]
-            else:
-                raise KeyError(f"More than one key was given: {key}")
-        # return station name either from key or the recent element from iterator
-        if key is None:
-            return self.stations[self._iterator]
-        else:
-            if isinstance(key, int):
-                if key < self.__len__():
-                    return self.stations[key]
-                else:
-                    raise KeyError(f"{key} is not in range(0, {self.__len__()})")
-            elif isinstance(key, str):
-                if key in self.stations:
-                    return key
-                else:
-                    raise KeyError(f"{key} is not in stations")
-            else:
-                raise KeyError(f"Key has to be from Union[str, int]. Given was {key} ({type(key)})")
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
deleted file mode 100644
index d5933f193018efb1529db2c026981e8c4d7936d2..0000000000000000000000000000000000000000
--- a/src/data_handling/data_preparation.py
+++ /dev/null
@@ -1,1108 +0,0 @@
-"""Data Preparation class to handle data processing for machine learning."""
-
-__author__ = 'Lukas Leufen'
-__date__ = '2020-06-29'
-
-import datetime as dt
-import logging
-import os
-from functools import reduce
-from typing import Union, List, Iterable, Tuple, Dict
-from src.helpers.join import EmptyQueryResult
-
-import numpy as np
-import pandas as pd
-import xarray as xr
-import dask.array as da
-
-from src.configuration import check_path_and_create
-from src import helpers
-from src.helpers import join, statistics
-
-# define a more general date type for type hinting
-date = Union[dt.date, dt.datetime]
-str_or_list = Union[str, List[str]]
-number = Union[float, int]
-num_or_list = Union[number, List[number]]
-data_or_none = Union[xr.DataArray, None]
-
-
-class AbstractStationPrep():
-    def __init__(self): #, path, station, statistics_per_var, transformation, **kwargs):
-        pass
-        # passed parameters
-        # self.path = os.path.abspath(path)
-        # self.station = helpers.to_list(station)
-        # self.statistics_per_var = statistics_per_var
-        # # self.target_dim = 'variable'
-        # self.transformation = self.setup_transformation(transformation)
-        # self.kwargs = kwargs
-        #
-        # # internal
-        # self.data = None
-        # self.meta = None
-        # self.variables = kwargs.get('variables', list(statistics_per_var.keys()))
-        # self.history = None
-        # self.label = None
-        # self.observation = None
-
-
-    def get_X(self):
-        raise NotImplementedError
-
-    def get_Y(self):
-        raise NotImplementedError
-
-    # def load_data(self):
-    #     try:
-    #         self.read_data_from_disk()
-    #     except FileNotFoundError:
-    #         self.download_data()
-    #         self.load_data()
-    #
-    # def read_data_from_disk(self):
-    #     raise NotImplementedError
-    #
-    # def download_data(self):
-    #     raise NotImplementedError
-
-class StationPrep(AbstractStationPrep):
-
-    def __init__(self, path, station, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var,
-                 interpolate_dim, window_history_size, window_lead_time, **kwargs):
-        super().__init__()  # path, station, statistics_per_var, transformation, **kwargs)
-        self.station_type = station_type
-        self.network = network
-        self.sampling = sampling
-        self.target_dim = target_dim
-        self.target_var = target_var
-        self.interpolate_dim = interpolate_dim
-        self.window_history_size = window_history_size
-        self.window_lead_time = window_lead_time
-
-        self.path = os.path.abspath(path)
-        self.station = helpers.to_list(station)
-        self.statistics_per_var = statistics_per_var
-        # self.target_dim = 'variable'
-        self.transformation = self.setup_transformation(transformation)
-        self.kwargs = kwargs
-
-        # internal
-        self.data = None
-        self.meta = None
-        self.variables = kwargs.get('variables', list(statistics_per_var.keys()))
-        self.history = None
-        self.label = None
-        self.observation = None
-
-    def __str__(self):
-        return self.station[0]
-
-    def load_data(self):
-        try:
-            self.read_data_from_disk()
-        except FileNotFoundError:
-            self.download_data()
-            self.load_data()
-        self.make_samples()
-
-    def __repr__(self):
-        return f"StationPrep(path='{self.path}', station={self.station}, statistics_per_var={self.statistics_per_var}, " \
-               f"transformation={self.transformation}, station_type='{self.station_type}', network='{self.network}', " \
-               f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \
-               f"interpolate_dim='{self.interpolate_dim}', window_history_size={self.window_history_size}, " \
-               f"window_lead_time={self.window_lead_time}, **{self.kwargs})"
-
-    def get_transposed_history(self) -> xr.DataArray:
-        """Return history.
-
-        :return: history with dimensions datetime, window, Stations, variables.
-        """
-        return self.history.transpose("datetime", "window", "Stations", "variables").copy()
-
-    def get_transposed_label(self) -> xr.DataArray:
-        """Return label.
-
-        :return: label with dimensions datetime*, window*, Stations, variables.
-        """
-        return self.label.squeeze("Stations").transpose("datetime", "window").copy()
-
-    def get_X(self):
-        return self.get_transposed_history()
-
-    def get_Y(self):
-        return self.get_transposed_label()
-
-    def make_samples(self):
-        self.load_data()
-        self.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim)
-        self.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
-        self.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
-        self.remove_nan(self.interpolate_dim)
-
-    def read_data_from_disk(self, source_name=""):
-        """
-        Load data and meta data either from local disk (preferred) or download new data by using a custom download method.
-
-        Data is downloaded if no local data is available or if the parameter overwrite_local_data is true. In both
-        cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
-        set, it is assumed that data should be saved locally.
-        """
-        source_name = source_name if len(source_name) == 0 else f" from {source_name}"
-        check_path_and_create(self.path)
-        file_name = self._set_file_name()
-        meta_file = self._set_meta_file_name()
-        if self.kwargs.get('overwrite_local_data', False):
-            logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}")
-            if os.path.exists(file_name):
-                os.remove(file_name)
-            if os.path.exists(meta_file):
-                os.remove(meta_file)
-            data, self.meta = self.download_data(file_name, meta_file)
-            logging.debug(f"loaded new data{source_name}")
-        else:
-            try:
-                logging.debug(f"try to load local data from: {file_name}")
-                data = xr.open_dataarray(file_name)
-                self.meta = pd.read_csv(meta_file, index_col=0)
-                self.check_station_meta()
-                logging.debug("loading finished")
-            except FileNotFoundError as e:
-                logging.debug(e)
-                logging.debug(f"load new data{source_name}")
-                data, self.meta = self.download_data(file_name, meta_file)
-                logging.debug("loading finished")
-        # create slices and check for negative concentration.
-        data = self._slice_prep(data)
-        self.data = self.check_for_negative_concentrations(data)
-
-    def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]:
-        """
-        Download data from TOAR database using the JOIN interface.
-
-        Data is transformed into an xarray data set. If the class attribute store_data_locally is true, data is
-        additionally stored locally using the given names for the file and meta file.
-
-        :param file_name: name of file to save data to (containing full path)
-        :param meta_file: name of the meta data file (also containing full path)
-
-        :return: downloaded data and its meta data
-        """
-        df_all = {}
-        df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var,
-                                      station_type=self.station_type, network_name=self.network, sampling=self.sampling)
-        df_all[self.station[0]] = df
-        # convert df_all to xarray
-        xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
-        xarr = xr.Dataset(xarr).to_array(dim='Stations')
-        if self.kwargs.get('store_data_locally', True):
-            # save locally as nc/csv file
-            xarr.to_netcdf(path=file_name)
-            meta.to_csv(meta_file)
-        return xarr, meta
-
-    def download_data(self, file_name, meta_file):
-        data, meta = self.download_data_from_join(file_name, meta_file)
-        return data, meta
-
-    def check_station_meta(self):
-        """
-        Search for the entries in meta data and compare the value with the requested values.
-
-        Raises a FileNotFoundError if the values do not match.
-        """
-        if self.station_type is not None:
-            check_dict = {"station_type": self.station_type, "network_name": self.network}
-            for (k, v) in check_dict.items():
-                if v is None:
-                    continue
-                if self.meta.at[k, self.station[0]] != v:
-                    logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
-                                  f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
-                                  f"grapping from web.")
-                    raise FileNotFoundError
-
-    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
-        """
-        Set all negative concentrations to zero.
-
-        Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
-        #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
-        "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".
-
-        :param data: data array containing variables to check
-        :param minimum: minimum value, by default this should be 0
-
-        :return: corrected data
-        """
-        chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
-                     "propane", "so2", "toluene"]
-        used_chem_vars = list(set(chem_vars) & set(self.variables))
-        data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
-        return data
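-
-    # Clipping sketch (illustrative): only chemical variables present in self.variables are
-    # clipped at the given minimum, e.g. an "o3" value of -0.3 becomes 0.0, while meteorological
-    # variables such as "temp" are left untouched.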
-
-    def shift(self, dim: str, window: int) -> xr.DataArray:
-        """
-        Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).
-
-        :param dim: dimension along shift is applied
-        :param window: number of steps to shift (corresponds to the window length)
-
-        :return: shifted data
-        """
-        start = 1
-        end = 1
-        if window <= 0:
-            start = window
-        else:
-            end = window + 1
-        res = []
-        for w in range(start, end):
-            res.append(self.data.shift({dim: -w}))
-        window_array = self.create_index_array('window', range(start, end), squeeze_dim=self.target_dim)
-        res = xr.concat(res, dim=window_array)
-        return res
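-
-    # Shift sketch (illustrative, not part of the original code): with daily data,
-    #   shift("datetime", -3) -> new "window" dimension with entries [-3, -2, -1, 0]  (history incl. t)
-    #   shift("datetime", 3)  -> new "window" dimension with entries [1, 2, 3]        (lead times)
-    # where each entry holds the data shifted by -w steps along "datetime".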
-
-    @staticmethod
-    def create_index_array(index_name: str, index_value: Iterable[int], squeeze_dim: str) -> xr.DataArray:
-        """
-        Create a 1D xr.DataArray with given index name and value.
-
-        :param index_name: name of dimension
-        :param index_value: values of this dimension
-        :param squeeze_dim: name of the dimension to squeeze out after conversion
-
-        :return: this array
-        """
-        ind = pd.DataFrame({'val': index_value}, index=index_value)
-        res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze(
-            dim=squeeze_dim,
-            drop=True
-        )
-        res.name = index_name
-        return res
-
-    def _set_file_name(self):
-        all_vars = sorted(self.statistics_per_var.keys())
-        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}.nc")
-
-    def _set_meta_file_name(self):
-        all_vars = sorted(self.statistics_per_var.keys())
-        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv")
-
-    def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
-                    **kwargs):
-        """
-        Interpolate values according to different methods.
-
-        (Copy paste from dataarray.interpolate_na)
-
-        :param dim:
-                Specifies the dimension along which to interpolate.
-        :param method:
-                {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
-                          'polynomial', 'barycentric', 'krog', 'pchip',
-                          'spline', 'akima'}, optional
-                    String indicating which method to use for interpolation:
-
-                    - 'linear': linear interpolation (Default). Additional keyword
-                      arguments are passed to ``numpy.interp``
-                    - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
-                      'polynomial': are passed to ``scipy.interpolate.interp1d``. If
-                      method=='polynomial', the ``order`` keyword argument must also be
-                      provided.
-                    - 'barycentric', 'krog', 'pchip', 'spline', and `akima`: use their
-                      respective``scipy.interpolate`` classes.
-        :param limit:
-                    default None
-                    Maximum number of consecutive NaNs to fill. Must be greater than 0
-                    or None for no limit.
-        :param use_coordinate:
-                default True
-                    Specifies which index to use as the x values in the interpolation
-                    formulated as `y = f(x)`. If False, values are treated as if
-                    equally-spaced along `dim`. If True, the IndexVariable `dim` is
-                    used. If use_coordinate is a string, it specifies the name of a
-                    coordinate variable to use as the index.
-        :param kwargs:
-
-        :return: xarray.DataArray
-        """
-        self.data = self.data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate,
-                                             **kwargs)
-
-    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
-        """
-        Create a xr.DataArray containing history data.
-
-        Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted
-        data. This is used to represent history in the data. Results are stored in history attribute.
-
-        :param dim_name_of_inputs: Name of dimension which contains the input variables
-        :param window: number of time steps to look back in history
-                Note: window will be treated as a negative value. This is in agreement with looking back on
-                a time line. Nonetheless, positive values are allowed, but they are converted to their negative
-                counterpart.
-        :param dim_name_of_shift: Dimension along shift will be applied
-        """
-        window = -abs(window)
-        self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables})
-
-    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
-                    window: int) -> None:
-        """
-        Create a xr.DataArray containing labels.
-
-        Labels are defined as the consecutive target values (t+1, ...t+n) following the current time step t. Set label
-        attribute.
-
-        :param dim_name_of_target: Name of dimension which contains the target variable
-        :param target_var: Name of target variable in 'dimension'
-        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
-        :param window: lead time of label
-        """
-        window = abs(window)
-        self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var})
-
-    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
-        """
-        Create a xr.DataArray containing observations.
-
-        Observations are defined as value of the current time step t. Set observation attribute.
-
-        :param dim_name_of_target: Name of dimension which contains the observation variable
-        :param target_var: Name of observation variable(s) in 'dimension'
-        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
-        """
-        self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var})
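-
-    # Relationship sketch (illustrative): for a current time step t,
-    #   history     -> X, input variables at t - window_history_size ... t
-    #   label       -> y, target values at t + 1 ... t + window_lead_time
-    #   observation ->    target value at t itself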
-
-    def remove_nan(self, dim: str) -> None:
-        """
-        Remove all slices along dim which contain NaNs in history, label, or observation.
-
-        This is done to present only a full matrix to keras.fit. Updates the history, label, and observation attributes.
-
-        :param dim: dimension along which the removal is performed.
-        """
-        intersect = []
-        if (self.history is not None) and (self.label is not None):
-            non_nan_history = self.history.dropna(dim=dim)
-            non_nan_label = self.label.dropna(dim=dim)
-            non_nan_observation = self.observation.dropna(dim=dim)
-            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
-                                                non_nan_observation.coords[dim].values))
-
-        min_length = self.kwargs.get("min_length", 0)
-        if len(intersect) < max(min_length, 1):
-            self.history = None
-            self.label = None
-            self.observation = None
-        else:
-            self.history = self.history.sel({dim: intersect})
-            self.label = self.label.sel({dim: intersect})
-            self.observation = self.observation.sel({dim: intersect})
-
-    def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray:
-        """
-        Set start and end date for slicing and execute self._slice().
-
-        :param data: data to slice
-        :param coord: name of axis to slice
-
-        :return: sliced data
-        """
-        start = self.kwargs.get('start', data.coords[coord][0].values)
-        end = self.kwargs.get('end', data.coords[coord][-1].values)
-        return self._slice(data, start, end, coord)
-
-    @staticmethod
-    def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
-        """
-        Slice through a given data_item (for example select only values of 2011).
-
-        :param data: data to slice
-        :param start: start date of slice
-        :param end: end date of slice
-        :param coord: name of axis to slice
-
-        :return: sliced data
-        """
-        return data.loc[{coord: slice(str(start), str(end))}]
-
-    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
-        """
-        Set all negative concentrations to zero.
-
-        Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
-        #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
-        "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".
-
-        :param data: data array containing variables to check
-        :param minimum: minimum value, by default this should be 0
-
-        :return: corrected data
-        """
-        chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
-                     "propane", "so2", "toluene"]
-        used_chem_vars = list(set(chem_vars) & set(self.variables))
-        data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
-        return data
-
-    def setup_transformation(self, transformation: Dict):
-        """
-        Set up transformation by extracting all relevant information.
-
-        Extract all information from the transformation dictionary. Possible keys are scope, method, mean, and std.
-        Scope can either be station or data. Station scope means that the data transformation is performed for each
-        station independently (similar to batch normalisation), whereas data scope means a transformation applied to
-        the entire data set.
-
-        * If using data scope, mean and standard deviation (each only if required by the transformation method) can
-          either be calculated accurately or as an estimate (faster implementation). This must be set in the dictionary
-          either as "mean": "accurate" or "mean": "estimate". In both cases, the required statistics are calculated and
-          saved. After these calculations, the mean key is overwritten by the actual values to use.
-        * If using station scope, no additional information is required.
-        * If a transformation should be applied on the basis of existing values, these need to be provided in the
-          respective keys "mean" and "std" (again only if required for the given method).
-
-        :param transformation: the transformation dictionary as described above.
-
-        :return: updated transformation dictionary
-        """
-        if transformation is None:
-            return
-        transformation = transformation.copy()
-        scope = transformation.get("scope", "station")
-        method = transformation.get("method", "standardise")
-        mean = transformation.get("mean", None)
-        std = transformation.get("std", None)
-        if scope == "data":
-            if isinstance(mean, str):
-                if mean == "accurate":
-                    mean, std = self.calculate_accurate_transformation(method)
-                elif mean == "estimate":
-                    mean, std = self.calculate_estimated_transformation(method)
-                else:
-                    raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' or"
-                                     f"be an array with already calculated means. Given was: {mean}")
-        elif scope == "station":
-            mean, std = None, None
-        else:
-            raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}")
-        transformation["method"] = method
-        transformation["mean"] = mean
-        transformation["std"] = std
-        return transformation
-
-    def calculate_accurate_transformation(self, method: str) -> Tuple[data_or_none, data_or_none]:
-        """
-        Calculate accurate transformation statistics.
-
-        Use all stations of this generator and calculate mean and standard deviation on entire data set using dask.
-        Because there can be much data, this can take a while.
-
-        :param method: name of transformation method
-
-        :return: accurate calculated mean and std (depending on transformation)
-        """
-        tmp = []
-        mean = None
-        std = None
-        for station in self.stations:
-            try:
-                data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
-                                     **self.kwargs)
-                chunks = (1, 100, data.data.shape[2])
-                tmp.append(da.from_array(data.data.data, chunks=chunks))
-            except EmptyQueryResult:
-                continue
-        tmp = da.concatenate(tmp, axis=1)
-        if method in ["standardise", "centre"]:
-            mean = da.nanmean(tmp, axis=1).compute()
-            mean = xr.DataArray(mean.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
-            if method == "standardise":
-                std = da.nanstd(tmp, axis=1).compute()
-                std = xr.DataArray(std.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
-        else:
-            raise NotImplementedError
-        return mean, std
-
-    def calculate_estimated_transformation(self, method):
-        """
-        Calculate estimated transformation statistics.
-
-        Use all stations of this generator and calculate mean and standard deviation first for each station separately.
-        Afterwards, calculate the average mean and standard deviation as estimated statistics. Because this method does
-        not consider the length of each data set, the estimated mean differs from the real data mean. Furthermore, the
-        estimated standard deviation is assumed to be the (unweighted) mean of all deviations. This is mathematically
-        not exact, but it is still a rough and faster estimate of the true standard deviation. Do not use this method
-        for further statistical calculations. However, in the scope of data preparation for machine learning, this
-        approach is decent ("it is just scaling").
-
-        :param method: name of transformation method
-
-        :return: estimated mean and std (depending on transformation)
-        """
-        data = [[]] * len(self.variables)
-        coords = {"variables": self.variables, "Stations": range(0)}
-        mean = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
-        std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
-        for station in self.stations:
-            try:
-                data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
-                                     **self.kwargs)
-                data.transform("datetime", method=method)
-                mean = mean.combine_first(data.mean)
-                std = std.combine_first(data.std)
-                data.transform("datetime", method=method, inverse=True)
-            except EmptyQueryResult:
-                continue
-        return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None
-
-    def load_data(self):
-        try:
-            self.read_data_from_disk()
-        except FileNotFoundError:
-            self.download_data()
-            self.load_data()
-
-
-class AbstractDataPrep(object):
-    """
-    This class prepares data to be used in neural networks.
-
-    The instance searches for locally stored data that meets the given demands. If no local data is found, the DataPrep
-    instance will load data from the TOAR database and store this data locally to use the next time. For the moment,
-    there is only support for daily aggregated time series. The aggregation can be set manually and differ for each
-    variable.
-
-    After data loading, different data pre-processing steps can be executed to prepare the data for further
-    applications. Especially the following methods can be used for the pre-processing step:
-
-    - interpolate: interpolate between data points by using xarray's interpolation method
-    - standardise: standardise data to mean=0 and std=1, centralise to mean=0, additional methods like normalise on \
-        the interval [0, 1] are not implemented yet.
-    - make window history: represent the history (time steps before) for training/ testing; X
-    - make labels: create target vector with given leading time steps for training/ testing; y
-    - remove NaNs jointly from desired input and output; only keep time steps where no NaNs are present in X AND y. \
-        Use this method after the creation of the window history and labels to clean up the data cube.
-
-    To create a DataPrep instance, it is necessary to specify the station by id (e.g. "DEBW107"), its network (e.g.
-    UBA, "Umweltbundesamt") and the variables to use. Further options can be set in the instance.
-
-    * `statistics_per_var`: define a specific statistic to extract from the TOAR database for each variable.
-    * `start`: define a start date for the data cube creation. Default: Use the first entry in time series
-    * `end`: set the end date for the data cube. Default: Use last date in time series.
-    * `store_data_locally`: store recently downloaded data on local disk. Default: True
-    * set further parameters for xarray's interpolation methods to modify the interpolation scheme
-
-    """
-
-    def __init__(self, path: str, station: Union[str, List[str]], variables: List[str], **kwargs):
-        """Construct instance."""
-        self.path = os.path.abspath(path)
-        self.station = helpers.to_list(station)
-        self.variables = variables
-        self.mean: data_or_none = None
-        self.std: data_or_none = None
-        self.history: data_or_none = None
-        self.label: data_or_none = None
-        self.observation: data_or_none = None
-        self.extremes_history: data_or_none = None
-        self.extremes_label: data_or_none = None
-        self.kwargs = kwargs
-        self.data = None
-        self.meta = None
-        self._transform_method = None
-        self.statistics_per_var = kwargs.get("statistics_per_var", None)
-        self.sampling = kwargs.get("sampling", "daily")
-        if self.statistics_per_var is not None or self.sampling == "hourly":
-            self.load_data()
-        else:
-            raise NotImplementedError("Either select hourly data or provide statistics_per_var.")
-
-    def load_data(self, source_name=""):
-        """
-        Load data and meta data either from local disk (preferred) or download new data by using a custom download method.
-
-        Data is downloaded if no local data is available or if the parameter overwrite_local_data is true. In both
-        cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
-        set, it is assumed that data should be saved locally.
-        """
-        source_name = source_name if len(source_name) == 0 else f" from {source_name}"
-        check_path_and_create(self.path)
-        file_name = self._set_file_name()
-        meta_file = self._set_meta_file_name()
-        if self.kwargs.get('overwrite_local_data', False):
-            logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}")
-            if os.path.exists(file_name):
-                os.remove(file_name)
-            if os.path.exists(meta_file):
-                os.remove(meta_file)
-            data, self.meta = self.download_data(file_name, meta_file)
-            logging.debug(f"loaded new data{source_name}")
-        else:
-            try:
-                logging.debug(f"try to load local data from: {file_name}")
-                data = xr.open_dataarray(file_name)
-                self.meta = pd.read_csv(meta_file, index_col=0)
-                self.check_station_meta()
-                logging.debug("loading finished")
-            except FileNotFoundError as e:
-                logging.debug(e)
-                logging.debug(f"load new data{source_name}")
-                data, self.meta = self.download_data(file_name, meta_file)
-                logging.debug("loading finished")
-        # create slices and check for negative concentration.
-        data = self._slice_prep(data)
-        self.data = self.check_for_negative_concentrations(data)
-
-    def download_data(self, file_name, meta_file) -> [xr.DataArray, pd.DataFrame]:
-        """
-        Download data and meta.
-
-        :param file_name: name of file to save data to (containing full path)
-        :param meta_file: name of the meta data file (also containing full path)
-        """
-        raise NotImplementedError
-
-    def check_station_meta(self):
-        """
-        Placeholder function to implement some additional station meta data check if desired.
-
-        Ideally, this method should raise a FileNotFoundError on a value mismatch to trigger loading fresh data from
-        the source. If this method is not required for your application, just override it with a `pass` statement. The
-        NotImplementedError is merely a reminder that you could use it.
-        """
-        raise NotImplementedError
-
-    def _set_file_name(self):
-        all_vars = sorted(self.statistics_per_var.keys())
-        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}.nc")
-
-    def _set_meta_file_name(self):
-        all_vars = sorted(self.statistics_per_var.keys())
-        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv")
-
-    def __repr__(self):
-        """Represent class attributes."""
-        return f"AbstractDataPrep(path='{self.path}', station={self.station}, variables={self.variables}, " \
-               f"**{self.kwargs})"
-
-    def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
-                    **kwargs):
-        """
-        Interpolate values according to different methods.
-
-        (Copy paste from dataarray.interpolate_na)
-
-        :param dim:
-                Specifies the dimension along which to interpolate.
-        :param method:
-                {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
-                          'polynomial', 'barycentric', 'krog', 'pchip',
-                          'spline', 'akima'}, optional
-                    String indicating which method to use for interpolation:
-
-                    - 'linear': linear interpolation (Default). Additional keyword
-                      arguments are passed to ``numpy.interp``
-                    - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
-                      'polynomial': are passed to ``scipy.interpolate.interp1d``. If
-                      method=='polynomial', the ``order`` keyword argument must also be
-                      provided.
-                    - 'barycentric', 'krog', 'pchip', 'spline', and `akima`: use their
-                      respective ``scipy.interpolate`` classes.
-        :param limit:
-                    default None
-                    Maximum number of consecutive NaNs to fill. Must be greater than 0
-                    or None for no limit.
-        :param use_coordinate:
-                default True
-                    Specifies which index to use as the x values in the interpolation
-                    formulated as `y = f(x)`. If False, values are treated as if
-                    equally-spaced along `dim`. If True, the IndexVariable `dim` is
-                    used. If use_coordinate is a string, it specifies the name of a
-                    coordinate variable to use as the index.
-        :param kwargs:
-
-        :return: xarray.DataArray
-        """
-        self.data = self.data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate,
-                                             **kwargs)
-
-    @staticmethod
-    def check_inverse_transform_params(mean: data_or_none, std: data_or_none, method: str) -> None:
-        """
-        Support inverse_transformation method.
-
-        Validate if all required statistics are available for the given method. E.g. centering requires mean only,
-        whereas standardisation requires mean and standard deviation. Will raise an AttributeError on missing requirements.
-
-        :param mean: data with all mean values
-        :param std: data with all standard deviation values
-        :param method: name of transformation method
-        """
-        msg = ""
-        if method in ['standardise', 'centre'] and mean is None:
-            msg += "mean, "
-        if method == 'standardise' and std is None:
-            msg += "std, "
-        if len(msg) > 0:
-            raise AttributeError(f"Inverse transform {method} can not be executed because following is None: {msg}")
-
-    def inverse_transform(self) -> None:
-        """
-        Perform inverse transformation.
-
-        Will raise an AssertionError if no transformation was performed before. Checks first if all required
-        statistics are available for inverse transformation. Class attributes data, mean and std are overwritten by
-        the new data afterwards. Thereby, mean, std, and the private transform method are set to None to indicate that
-        the current data is not transformed.
-        """
-
-        def f_inverse(data, mean, std, method_inverse):
-            if method_inverse == 'standardise':
-                return statistics.standardise_inverse(data, mean, std), None, None
-            elif method_inverse == 'centre':
-                return statistics.centre_inverse(data, mean), None, None
-            elif method_inverse == 'normalise':
-                raise NotImplementedError
-            else:
-                raise NotImplementedError
-
-        if self._transform_method is None:
-            raise AssertionError("Inverse transformation method is not set. Data cannot be inverse transformed.")
-        self.check_inverse_transform_params(self.mean, self.std, self._transform_method)
-        self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method)
-        self._transform_method = None
-
-    def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean=None,
-                  std=None) -> None:
-        """
-        Transform data according to given transformation settings.
-
-        This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either with mean=0
-        and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale
-        (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation. This
-        method will raise an AssertionError if an internal transform method was already set ('inverse=False') or if the
-        internal transform method, internal mean and internal standard deviation weren't set ('inverse=True').
-
-        :param string/int dim: This param is not used for inverse transformation.
-                | for xarray.DataArray as string: name of dimension which should be standardised
-                | for pandas.DataFrame as int: axis of dimension which should be standardised
-        :param method: Choose the transformation method from 'standardise' and 'centre'. 'normalise' is not implemented
-                    yet. This param is not used for inverse transformation.
-        :param inverse: Switch between transformation and inverse transformation.
-
-        :return: xarray.DataArrays or pandas.DataFrames:
-                #. mean: Mean of data
-                #. std: Standard deviation of data
-                #. data: Standardised data
-        """
-
-        def f(data):
-            if method == 'standardise':
-                return statistics.standardise(data, dim)
-            elif method == 'centre':
-                return statistics.centre(data, dim)
-            elif method == 'normalise':
-                # use min/max of data or given min/max
-                raise NotImplementedError
-            else:
-                raise NotImplementedError
-
-        def f_apply(data):
-            if method == "standardise":
-                return mean, std, statistics.standardise_apply(data, mean, std)
-            elif method == "centre":
-                return mean, None, statistics.centre_apply(data, mean)
-            else:
-                raise NotImplementedError
-
-        if not inverse:
-            if self._transform_method is not None:
-                raise AssertionError(f"Transform method is already set. Therefore, data was already transformed with "
-                                     f"{self._transform_method}. Please perform inverse transformation of data first.")
-            self.mean, self.std, self.data = locals()["f" if mean is None else "f_apply"](self.data)
-            self._transform_method = method
-        else:
-            self.inverse_transform()
-
-    def get_transformation_information(self, variable: str) -> Tuple[data_or_none, data_or_none, str]:
-        """
-        Extract transformation statistics and method.
-
-        Get mean and standard deviation for given variable and the transformation method if set. If a transformation
-        depends only on particular statistics (e.g. only mean is required for centering), the remaining statistics are
-        returned with None as fill value.
-
-        :param variable: Variable for which the information on transformation is requested.
-
-        :return: mean, standard deviation and transformation method
-        """
-        try:
-            mean = self.mean.sel({'variables': variable}).values
-        except AttributeError:
-            mean = None
-        try:
-            std = self.std.sel({'variables': variable}).values
-        except AttributeError:
-            std = None
-        return mean, std, self._transform_method
-
-    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
-        """
-        Create a xr.DataArray containing history data.
-
-        Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted
-        data. This is used to represent history in the data. Results are stored in history attribute.
-
-        :param dim_name_of_inputs: Name of dimension which contains the input variables
-        :param window: number of time steps to look back in history
-                Note: window will be treated as a negative value to be consistent with looking back on
-                a time line. Positive values are therefore allowed, but they are converted to their
-                negative counterpart.
-        :param dim_name_of_shift: Dimension along shift will be applied
-        """
-        window = -abs(window)
-        self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables})
-
-    def shift(self, dim: str, window: int) -> xr.DataArray:
-        """
-        Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).
-
-        :param dim: dimension along shift is applied
-        :param window: number of steps to shift (corresponds to the window length)
-
-        :return: shifted data
-        """
-        start = 1
-        end = 1
-        if window <= 0:
-            start = window
-        else:
-            end = window + 1
-        res = []
-        for w in range(start, end):
-            res.append(self.data.shift({dim: -w}))
-        window_array = self.create_index_array('window', range(start, end))
-        res = xr.concat(res, dim=window_array)
-        return res
-
-    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
-                    window: int) -> None:
-        """
-        Create a xr.DataArray containing labels.
-
-        Labels are defined as the consecutive target values (t+1, ...t+n) following the current time step t. Set label
-        attribute.
-
-        :param dim_name_of_target: Name of dimension which contains the target variable
-        :param target_var: Name of target variable in 'dimension'
-        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
-        :param window: lead time of label
-        """
-        window = abs(window)
-        self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var})
-
-    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
-        """
-        Create a xr.DataArray containing observations.
-
-        Observations are defined as value of the current time step t. Set observation attribute.
-
-        :param dim_name_of_target: Name of dimension which contains the observation variable
-        :param target_var: Name of observation variable(s) in 'dimension'
-        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
-        """
-        self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var})
-
-    def remove_nan(self, dim: str) -> None:
-        """
-        Remove all slices along dim that contain NaNs in history, label, or observation.
-
-        This is done to present only a full matrix to keras.fit. Update history, label, and observation attribute.
-
-        :param dim: dimension along the remove is performed.
-        """
-        intersect = []
-        if (self.history is not None) and (self.label is not None):
-            non_nan_history = self.history.dropna(dim=dim)
-            non_nan_label = self.label.dropna(dim=dim)
-            non_nan_observation = self.observation.dropna(dim=dim)
-            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
-                                                non_nan_observation.coords[dim].values))
-
-        min_length = self.kwargs.get("min_length", 0)
-        if len(intersect) < max(min_length, 1):
-            self.history = None
-            self.label = None
-            self.observation = None
-        else:
-            self.history = self.history.sel({dim: intersect})
-            self.label = self.label.sel({dim: intersect})
-            self.observation = self.observation.sel({dim: intersect})
-
-    @staticmethod
-    def create_index_array(index_name: str, index_value: Iterable[int]) -> xr.DataArray:
-        """
-        Create a 1D xr.DataArray with given index name and value.
-
-        :param index_name: name of dimension
-        :param index_value: values of this dimension
-
-        :return: this array
-        """
-        ind = pd.DataFrame({'val': index_value}, index=index_value)
-        res = xr.Dataset.from_dataframe(ind).to_array().rename({'index': index_name}).squeeze(dim='variable', drop=True)
-        res.name = index_name
-        return res
-
-    def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray:
-        """
-        Set start and end date for slicing and execute self._slice().
-
-        :param data: data to slice
-        :param coord: name of axis to slice
-
-        :return: sliced data
-        """
-        start = self.kwargs.get('start', data.coords[coord][0].values)
-        end = self.kwargs.get('end', data.coords[coord][-1].values)
-        return self._slice(data, start, end, coord)
-
-    @staticmethod
-    def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
-        """
-        Slice through a given data_item (for example select only values of 2011).
-
-        :param data: data to slice
-        :param start: start date of slice
-        :param end: end date of slice
-        :param coord: name of axis to slice
-
-        :return: sliced data
-        """
-        return data.loc[{coord: slice(str(start), str(end))}]
-
-    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
-        """
-        Set all negative concentrations to zero.
-
-        Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
-        #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
-        "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".
-
-        :param data: data array containing variables to check
-        :param minimum: minimum value, by default this should be 0
-
-        :return: corrected data
-        """
-        chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
-                     "propane", "so2", "toluene"]
-        used_chem_vars = list(set(chem_vars) & set(self.variables))
-        data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
-        return data
-
-    def get_transposed_history(self) -> xr.DataArray:
-        """Return history.
-
-        :return: history with dimensions datetime, window, Stations, variables.
-        """
-        return self.history.transpose("datetime", "window", "Stations", "variables").copy()
-
-    def get_transposed_label(self) -> xr.DataArray:
-        """Return label.
-
-        :return: label with dimensions datetime, window (the Stations dimension is squeezed out).
-        """
-        return self.label.squeeze("Stations").transpose("datetime", "window").copy()
-
-    def get_extremes_history(self) -> xr.DataArray:
-        """Return extremes history.
-
-        :return: extremes history with dimensions datetime, window, Stations, variables.
-        """
-        return self.extremes_history.transpose("datetime", "window", "Stations", "variables").copy()
-
-    def get_extremes_label(self) -> xr.DataArray:
-        """Return extremes label.
-
-        :return: extremes label with dimensions datetime, window (the Stations dimension is squeezed out).
-        """
-        return self.extremes_label.squeeze("Stations").transpose("datetime", "window").copy()
-
-    def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
-                          timedelta: Tuple[int, str] = (1, 'm')):
-        """
-        Multiply extremes.
-
-        This method extracts extreme values from self.label which are defined in the argument extreme_values. One can
-        also decide to only extract extremes on the right tail of the distribution. When extreme_values is a list of
-        floats/ints, all values larger than extreme_values (and smaller than -extreme_values; extraction is performed
-        in standardised space) are extracted iteratively. If for example extreme_values = [1., 2.], then a value of 1.5
-        would be extracted once (for the 0th entry in the list), while a 2.5 would be extracted twice (once for each
-        entry). Timedelta is used to mark those extracted values by adding one minute to each timestamp. As TOAR data
-        are hourly, one can easily identify those "artificial" data points later. Extreme inputs and labels are stored
-        in self.extremes_history and self.extremes_label, respectively.
-
-        :param extreme_values: user definition of extreme
-        :param extremes_on_right_tail_only: if False, also multiply values which are smaller than -extreme_values;
-            if True, only extract values larger than extreme_values
-        :param timedelta: used as arguments for np.timedelta64 in order to mark extreme values on datetime
-        """
-        # check if labels or history is None
-        if (self.label is None) or (self.history is None):
-            logging.debug(f"{self.station} has `None' labels, skip multiply extremes")
-            return
-
-        # check types of inputs
-        extreme_values = helpers.to_list(extreme_values)
-        for i in extreme_values:
-            if not isinstance(i, number.__args__):
-                raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
-                                f"{i} is type {type(i)}")
-
-        for extr_val in sorted(extreme_values):
-            # check if some extreme values are already extracted
-            if (self.extremes_label is None) or (self.extremes_history is None):
-                # extract extremes based on occurrence in labels
-                if extremes_on_right_tail_only:
-                    extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1, )
-                else:
-                    extreme_label_idx = np.concatenate(((self.label < -extr_val).any(axis=0).values.reshape(-1, 1),
-                                                        (self.label > extr_val).any(axis=0).values.reshape(-1, 1)),
-                                                       axis=1).any(axis=1)
-                extremes_label = self.label[..., extreme_label_idx]
-                extremes_history = self.history[..., extreme_label_idx, :]
-                extremes_label.datetime.values += np.timedelta64(*timedelta)
-                extremes_history.datetime.values += np.timedelta64(*timedelta)
-                self.extremes_label = extremes_label  # .squeeze('Stations').transpose('datetime', 'window')
-                self.extremes_history = extremes_history  # .transpose('datetime', 'window', 'Stations', 'variables')
-            else:  # one extr value iteration is done already: self.extremes_label is NOT None...
-                if extremes_on_right_tail_only:
-                    extreme_label_idx = (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, )
-                else:
-                    extreme_label_idx = np.concatenate(
-                        ((self.extremes_label < -extr_val).any(axis=0).values.reshape(-1, 1),
-                         (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, 1)
-                         ), axis=1).any(axis=1)
-                # check on existing extracted extremes to minimise computational costs for comparison
-                extremes_label = self.extremes_label[..., extreme_label_idx]
-                extremes_history = self.extremes_history[..., extreme_label_idx, :]
-                extremes_label.datetime.values += np.timedelta64(*timedelta)
-                extremes_history.datetime.values += np.timedelta64(*timedelta)
-                self.extremes_label = xr.concat([self.extremes_label, extremes_label], dim='datetime')
-                self.extremes_history = xr.concat([self.extremes_history, extremes_history], dim='datetime')
-
-
-if __name__ == "__main__":
-    # dp = AbstractDataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
-    # print(dp)
-    statistics_per_var = {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}
-    sp = StationPrep(path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122',
-                     statistics_per_var=statistics_per_var, transformation={}, station_type='background',
-                     network='UBA', sampling='daily', target_dim='variables', target_var='o3',
-                     interpolate_dim='datetime', window_history_size=7, window_lead_time=3)
-    sp.get_X()
-    sp.get_Y()
-    print(sp)
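
Note on the window logic removed above: `shift`, `make_history_window` and `make_labels` all reduce to repeatedly shifting the data along the time axis and concatenating the shifted copies along a new `window` dimension (negative offsets for history, positive offsets for lead times). A minimal standalone sketch of that idea with plain xarray, an illustration rather than the removed class:

import numpy as np
import pandas as pd
import xarray as xr

data = xr.DataArray(np.arange(10.), dims="datetime",
                    coords={"datetime": pd.date_range("2020-01-01", periods=10)})

def shift(data, dim, window):
    # negative window -> history offsets (window..0), positive window -> lead times (1..window)
    start, end = (window, 1) if window <= 0 else (1, window + 1)
    shifted = [data.shift({dim: -w}) for w in range(start, end)]
    index = xr.DataArray(np.arange(start, end), dims="window", name="window")
    return xr.concat(shifted, dim=index)

history = shift(data, "datetime", -3)  # X: offsets -3, -2, -1, 0
labels = shift(data, "datetime", 2)    # y: offsets 1, 2
print(history.sel(datetime="2020-01-05").values, labels.sel(datetime="2020-01-05").values)
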
diff --git a/src/data_handling/data_preparation_join.py b/src/data_handling/data_preparation_join.py
deleted file mode 100644
index 86c7dee055c8258069307567b28ffcd113e13477..0000000000000000000000000000000000000000
--- a/src/data_handling/data_preparation_join.py
+++ /dev/null
@@ -1,124 +0,0 @@
-"""Data Preparation class to handle data processing for machine learning."""
-
-__author__ = 'Felix Kleinert, Lukas Leufen'
-__date__ = '2019-10-16'
-
-import datetime as dt
-import inspect
-import logging
-from typing import Union, List
-
-import pandas as pd
-import xarray as xr
-
-from src import helpers
-from src.helpers import join
-from src.data_handling.data_preparation import AbstractDataPrep
-
-# define a more general date type for type hinting
-date = Union[dt.date, dt.datetime]
-str_or_list = Union[str, List[str]]
-number = Union[float, int]
-num_or_list = Union[number, List[number]]
-data_or_none = Union[xr.DataArray, None]
-
-
-class DataPrepJoin(AbstractDataPrep):
-    """
-    This class prepares data to be used in neural networks.
-
-    The instance searches for locally stored data that meets the given demands. If no local data is found, the DataPrep
-    instance will load data from TOAR database and store this data locally to use the next time. For the moment, there
-    is only support for daily aggregated time series. The aggregation can be set manually and differ for each variable.
-
-    After data loading, different data pre-processing steps can be executed to prepare the data for further
-    applications. Especially the following methods can be used for the pre-processing step:
-
-    - interpolate: interpolate between data points by using xarray's interpolation method
-    - standardise: standardise data to mean=0 and std=1, centralise to mean=0, additional methods like normalise on \
-        interval [0, 1] are not implemented yet.
-    - make window history: represent the history (time steps before) for training/ testing; X
-    - make labels: create target vector with given lead time steps for training/ testing; y
-    - remove NaNs jointly from desired input and output, keeping only time steps where no NaNs are present in X AND y. \
-        Use this method after the creation of the window history and labels to clean up the data cube.
-
-    To create a DataPrep instance, you need to specify the station by id (e.g. "DEBW107"), its network (e.g. UBA,
-    "Umweltbundesamt") and the variables to use. Further options can be set in the instance.
-
-    * `statistics_per_var`: define a specific statistic to extract from the TOAR database for each variable.
-    * `start`: define a start date for the data cube creation. Default: Use the first entry in time series
-    * `end`: set the end date for the data cube. Default: Use last date in time series.
-    * `store_data_locally`: store recently downloaded data on local disk. Default: True
-    * set further parameters for xarray's interpolation methods to modify the interpolation scheme
-
-    """
-
-    def __init__(self, path: str, station: Union[str, List[str]], variables: List[str], network: str = None,
-                 station_type: str = None, **kwargs):
-        self.network = network
-        self.station_type = station_type
-        params = helpers.remove_items(inspect.getfullargspec(AbstractDataPrep.__init__).args, "self")
-        kwargs = {**{k: v for k, v in locals().items() if k in params and v is not None}, **kwargs}
-        super().__init__(**kwargs)
-
-    def download_data(self, file_name, meta_file):
-        """
-        Download data and meta from join.
-
-        :param file_name: name of file to save data to (containing full path)
-        :param meta_file: name of the meta data file (also containing full path)
-        """
-        data, meta = self.download_data_from_join(file_name, meta_file)
-        return data, meta
-
-    def check_station_meta(self):
-        """
-        Search for the entries in the meta data and compare the values with the requested ones.
-
-        Will raise a FileNotFoundError if the values mismatch.
-        """
-        if self.station_type is not None:
-            check_dict = {"station_type": self.station_type, "network_name": self.network}
-            for (k, v) in check_dict.items():
-                if v is None:
-                    continue
-                if self.meta.at[k, self.station[0]] != v:
-                    logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
-                                  f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
-                                  f"grabbing from the web.")
-                    raise FileNotFoundError
-
-    def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]:
-        """
-        Download data from TOAR database using the JOIN interface.
-
-        Data is transformed to a xarray dataset. If class attribute store_data_locally is true, data is additionally
-        stored locally using given names for file and meta file.
-
-        :param file_name: name of file to save data to (containing full path)
-        :param meta_file: name of the meta data file (also containing full path)
-
-        :return: downloaded data and its meta data
-        """
-        df_all = {}
-        df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var,
-                                      station_type=self.station_type, network_name=self.network, sampling=self.sampling)
-        df_all[self.station[0]] = df
-        # convert df_all to xarray
-        xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
-        xarr = xr.Dataset(xarr).to_array(dim='Stations')
-        if self.kwargs.get('store_data_locally', True):
-            # save locally as nc/csv file
-            xarr.to_netcdf(path=file_name)
-            meta.to_csv(meta_file)
-        return xarr, meta
-
-    def __repr__(self):
-        """Represent class attributes."""
-        return f"Dataprep(path='{self.path}', network='{self.network}', station={self.station}, " \
-               f"variables={self.variables}, station_type={self.station_type}, **{self.kwargs})"
-
-
-if __name__ == "__main__":
-    dp = DataPrepJoin('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
-    print(dp)
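
The deleted `load_data`/`download_data_from_join` pair boils down to a cache-or-download pattern: reuse a local netCDF/CSV pair if present, otherwise fetch the data and persist it. A compressed sketch of that control flow with a placeholder download callback (not the actual JOIN interface):

import os
import pandas as pd
import xarray as xr

def load_or_download(file_name, meta_file, download_fallback, overwrite_local_data=False):
    # drop local copies first if a reload is forced
    if overwrite_local_data:
        for f in (file_name, meta_file):
            if os.path.exists(f):
                os.remove(f)
    try:
        data = xr.open_dataarray(file_name)
        meta = pd.read_csv(meta_file, index_col=0)
    except FileNotFoundError:
        data, meta = download_fallback()
        data.to_netcdf(path=file_name)  # store for the next run
        meta.to_csv(meta_file)
    return data, meta
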
diff --git a/src/run_modules/__init__.py b/src/run_modules/__init__.py
deleted file mode 100644
index 0c70ae4205ff38fdc876538c42c44ca0bc8cb9c0..0000000000000000000000000000000000000000
--- a/src/run_modules/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from src.run_modules.experiment_setup import ExperimentSetup
-from src.run_modules.model_setup import ModelSetup
-from src.run_modules.partition_check import PartitionCheck
-from src.run_modules.post_processing import PostProcessing
-from src.run_modules.pre_processing import PreProcessing
-from src.run_modules.run_environment import RunEnvironment
-from src.run_modules.training import Training
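
This deleted __init__ re-exported the run modules that the old run scripts chained inside a RunEnvironment context. A minimal sketch of that pattern, assuming the same classes are re-exported under mlair.run_modules and keep this calling convention (argument values are purely illustrative):

from mlair.run_modules import ExperimentSetup, PreProcessing, ModelSetup, Training, PostProcessing
from mlair.run_modules import RunEnvironment

def main():
    with RunEnvironment():                     # shared data store and logging for all modules
        ExperimentSetup(stations=["DEBW107"])  # write experiment settings to the data store
        PreProcessing()                        # build train/val/test data
        ModelSetup()                           # create and compile the model
        Training()
        PostProcessing()

if __name__ == "__main__":
    main()
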
diff --git a/src/workflows/__init__.py b/src/workflows/__init__.py
deleted file mode 100644
index 57e514cd9ced32fbf1dbb290b1008deffcec52d3..0000000000000000000000000000000000000000
--- a/src/workflows/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from src.workflows.abstract_workflow import Workflow
-from src.workflows.default_workflow import DefaultWorkflow, DefaultWorkflowHPC
\ No newline at end of file
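
The workflow classes replace this manual chaining of run modules. Assuming the renamed package exposes the same classes, the default sequence could be triggered roughly like this (keyword arguments are illustrative and passed on to the experiment setup):

from mlair.workflows import DefaultWorkflow

workflow = DefaultWorkflow(stations=["DEBW107", "DEBW013"])
workflow.run()
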
diff --git a/test/test_configuration/test_path_config.py b/test/test_configuration/test_path_config.py
index acb43676cb86ca76aded88aa0d46f62dd78d9992..b97763632922fc2aaffaf267cfbc76ff99e25b6f 100644
--- a/test/test_configuration/test_path_config.py
+++ b/test/test_configuration/test_path_config.py
@@ -4,9 +4,9 @@ import os
 import mock
 import pytest
 
-from src.configuration import prepare_host, set_experiment_name, set_bootstrap_path, check_path_and_create, \
+from mlair.configuration import prepare_host, set_experiment_name, set_bootstrap_path, check_path_and_create, \
     set_experiment_path, ROOT_PATH
-from src.helpers import PyTestRegex
+from mlair.helpers import PyTestRegex
 
 
 class TestPrepareHost:
@@ -16,22 +16,20 @@ class TestPrepareHost:
     @mock.patch("getpass.getuser", return_value="testUser")
     @mock.patch("os.path.exists", return_value=True)
     def test_prepare_host(self, mock_host, mock_user, mock_path):
-        assert prepare_host() == "/home/testUser/machinelearningtools/data/toar_daily/"
+        assert prepare_host() == "/home/testUser/mlair/data/toar_daily/"
         assert prepare_host() == "/home/testUser/Data/toar_daily/"
         assert prepare_host() == "/home/testUser/Data/toar_daily/"
         assert prepare_host() == "/p/project/cjjsc42/testUser/DATA/toar_daily/"
         assert prepare_host() == "/p/project/deepacf/intelliaq/testUser/DATA/toar_daily/"
-        assert prepare_host() == '/home/testUser/machinelearningtools/data/toar_daily/'
+        assert prepare_host() == '/home/testUser/mlair/data/toar_daily/'
 
     @mock.patch("socket.gethostname", return_value="NotExistingHostName")
     @mock.patch("getpass.getuser", return_value="zombie21")
-    def test_error_handling_unknown_host(self, mock_user, mock_host):
-        with pytest.raises(OSError) as e:
-            prepare_host()
-        assert "unknown host 'NotExistingHostName'" in e.value.args[0]
+    def test_prepare_host_unknown(self, mock_user, mock_host):
+        assert prepare_host() == os.path.join(os.path.abspath(os.getcwd()), 'data', 'daily')
 
     @mock.patch("getpass.getuser", return_value="zombie21")
-    @mock.patch("src.configuration.path_config.check_path_and_create", side_effect=PermissionError)
+    @mock.patch("mlair.configuration.path_config.check_path_and_create", side_effect=PermissionError)
     @mock.patch("os.path.exists", return_value=False)
     def test_error_handling(self, mock_path_exists, mock_cpath, mock_user):
         # if "runner-6HmDp9Qd-project-2411-concurrent" not in platform.node():
@@ -50,7 +48,7 @@ class TestPrepareHost:
     @mock.patch("os.makedirs", side_effect=None)
     def test_os_path_exists(self, mock_host, mock_user, mock_path, mock_check):
         path = prepare_host()
-        assert path == "/home/testUser/machinelearningtools/data/toar_daily/"
+        assert path == "/home/testUser/mlair/data/toar_daily/"
 
 
 class TestSetExperimentName:
diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handler/old_t_bootstraps.py
similarity index 96%
rename from test/test_data_handling/test_bootstraps.py
rename to test/test_data_handler/old_t_bootstraps.py
index 839b02203b22c2f5538613601aa125ed30455b0b..9616ed3f457d74e44e8a9eae5a3ed862fa804011 100644
--- a/test/test_data_handling/test_bootstraps.py
+++ b/test/test_data_handler/old_t_bootstraps.py
@@ -7,9 +7,8 @@ import numpy as np
 import pytest
 import xarray as xr
 
-from src.data_handling.bootstraps import BootStraps, CreateShuffledData, BootStrapGenerator
-from src.data_handling.data_generator import DataGenerator
-from src.data_handling import DataPrepJoin
+from mlair.data_handler.bootstraps import BootStraps
+from mlair.data_handler import DataPrepJoin
 
 
 @pytest.fixture
@@ -74,7 +73,7 @@ class TestCreateShuffledData:
         return CreateShuffledData(orig_generator, 20, data_path)
 
     @pytest.fixture
-    @mock.patch("src.data_handling.bootstraps.CreateShuffledData.create_shuffled_data", return_value=None)
+    @mock.patch("mlair.data_handling.bootstraps.CreateShuffledData.create_shuffled_data", return_value=None)
     def shuffled_data_no_creation(self, mock_create_shuffle_data, orig_generator, data_path):
         return CreateShuffledData(orig_generator, 20, data_path)
 
@@ -175,7 +174,7 @@ class TestBootStraps:
         return BootStraps(orig_generator, data_path, 20)
 
     @pytest.fixture
-    @mock.patch("src.data_handling.bootstraps.CreateShuffledData", return_value=None)
+    @mock.patch("mlair.data_handling.bootstraps.CreateShuffledData", return_value=None)
     def bootstrap_no_shuffling(self, mock_create_shuffle_data, orig_generator, data_path):
         shutil.rmtree(data_path)
         return BootStraps(orig_generator, data_path, 20)
@@ -212,7 +211,7 @@ class TestBootStraps:
         assert xr.testing.assert_equal(gen.history, expected.sel(variables=var_others)) is None
         assert gen.shuffled.variables == "o3"
 
-    @mock.patch("src.data_handling.data_generator.DataGenerator._load_pickle_data", side_effect=FileNotFoundError)
+    @mock.patch("mlair.data_handling.data_generator.DataGenerator._load_pickle_data", side_effect=FileNotFoundError)
     def test_get_generator_different_generator(self, mock_load_pickle, data_path, orig_generator):
         BootStraps(orig_generator, data_path, 20)  # to create
         orig_generator.window_history_size = 4
diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handler/old_t_data_generator.py
similarity index 97%
rename from test/test_data_handling/test_data_generator.py
rename to test/test_data_handler/old_t_data_generator.py
index 3144bde3440d861e109c4a3b0da8b77d317faa2b..9198923e2f75601f2ce7e6dc18a663da647eaadb 100644
--- a/test/test_data_handling/test_data_generator.py
+++ b/test/test_data_handler/old_t_data_generator.py
@@ -6,9 +6,8 @@ import numpy as np
 import pytest
 import xarray as xr
 
-from src.data_handling.data_generator import DataGenerator
-from src.data_handling import DataPrepJoin
-from src.helpers.join import EmptyQueryResult
+from mlair.data_handler import DataPrepJoin
+from mlair.helpers.join import EmptyQueryResult
 
 
 class TestDataGenerator:
@@ -80,10 +79,10 @@ class TestDataGenerator:
         assert gen.stations == ['DEBW107']
         assert gen.variables == ['o3', 'temp']
         assert gen.station_type is None
-        assert gen.interpolate_dim == 'datetime'
+        assert gen.time_dim == 'datetime'
         assert gen.target_dim == 'variables'
         assert gen.target_var == 'o3'
-        assert gen.interpolate_method == "linear"
+        assert gen.interpolation_method == "linear"
         assert gen.limit_nan_fill == 1
         assert gen.window_history_size == 7
         assert gen.window_lead_time == 4
@@ -93,7 +92,7 @@ class TestDataGenerator:
     def test_repr(self, gen):
         path = os.path.join(os.path.dirname(__file__), 'data')
         assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', stations=['DEBW107'], " \
-                                          f"variables=['o3', 'temp'], station_type=None, interpolate_dim='datetime', " \
+                                          f"variables=['o3', 'temp'], station_type=None, interpolation_dim='datetime', " \
                                           f"target_dim='variables', target_var='o3', **{{'start': 2010, 'end': 2014}})" \
             .rstrip()
 
diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handler/old_t_data_preparation.py
similarity index 99%
rename from test/test_data_handling/test_data_preparation.py
rename to test/test_data_handler/old_t_data_preparation.py
index 3af8a04561b26c67b5b4e9d35fcb08d6d0cfe0cb..586e17158a93880e2a98bf64189fa947299a64f3 100644
--- a/test/test_data_handling/test_data_preparation.py
+++ b/test/test_data_handler/old_t_data_preparation.py
@@ -8,9 +8,9 @@ import pandas as pd
 import pytest
 import xarray as xr
 
-from src.data_handling.data_preparation import AbstractDataPrep
-from src.data_handling import DataPrepJoin as DataPrep
-from src.helpers.join import EmptyQueryResult
+from mlair.data_handler.data_preparation import AbstractDataPrep
+from mlair.data_handler import DataPrepJoin as DataPrep
+from mlair.helpers.join import EmptyQueryResult
 
 
 class TestAbstractDataPrep:
diff --git a/test/test_data_handling/test_iterator.py b/test/test_data_handler/test_iterator.py
similarity index 82%
rename from test/test_data_handling/test_iterator.py
rename to test/test_data_handler/test_iterator.py
index 3f1cf683d627495cf958b6c2376a5c42a4c6e61f..ff81fc7b89b2cede0f47cdf209e77e373cd0d656 100644
--- a/test/test_data_handling/test_iterator.py
+++ b/test/test_data_handler/test_iterator.py
@@ -1,9 +1,11 @@
 
-from src.data_handling.iterator import DataCollection, StandardIterator, KerasIterator
-from src.helpers.testing import PyTestAllEqual
+from mlair.data_handler.iterator import DataCollection, StandardIterator, KerasIterator
+from mlair.helpers.testing import PyTestAllEqual
+from mlair.model_modules.model_class import MyLittleModel, MyBranchedModel
 
 import numpy as np
 import pytest
+import mock
 import os
 import shutil
 
@@ -56,13 +58,13 @@ class DummyData:
     def __init__(self, number_of_samples=np.random.randint(100, 150)):
         self.number_of_samples = number_of_samples
 
-    def get_X(self):
+    def get_X(self, upsampling=False, as_numpy=True):
         X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))  # samples, window, variables
         X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2))  # samples, window, variables
         X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2))  # samples, window, variables
         return [X1, X2, X3]
 
-    def get_Y(self):
+    def get_Y(self, upsampling=False, as_numpy=True):
         Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1))  # samples, window, variables
         Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1))  # samples, window, variables
         return [Y1, Y2]
@@ -88,7 +90,7 @@ class TestKerasIterator:
     def test_init(self, collection, path):
         iterator = KerasIterator(collection, 25, path)
         assert isinstance(iterator._collection, DataCollection)
-        assert iterator._path == os.path.join(path, "%i.pickle")
+        assert iterator._path == os.path.join(path, str(id(iterator)), "%i.pickle")
         assert iterator.batch_size == 25
         assert iterator.shuffle is False
 
@@ -149,6 +151,8 @@ class TestKerasIterator:
         iterator._collection = collection
         iterator.batch_size = 50
         iterator.indexes = []
+        iterator.model = None
+        iterator.upsampling = False
         iterator._path = os.path.join(path, "%i.pickle")
         os.makedirs(path)
         iterator._prepare_batches()
@@ -162,6 +166,8 @@ class TestKerasIterator:
         iterator._collection = DataCollection([DummyData(50)])
         iterator.batch_size = 50
         iterator.indexes = []
+        iterator.model = None
+        iterator.upsampling = False
         iterator._path = os.path.join(path, "%i.pickle")
         os.makedirs(path)
         iterator._prepare_batches()
@@ -198,4 +204,25 @@ class TestKerasIterator:
         while iterator.indexes == sorted(iterator.indexes):
             iterator.on_epoch_end()
         assert iterator.indexes != [0, 1, 2, 3, 4]
-        assert sorted(iterator.indexes) == [0, 1, 2, 3, 4]
\ No newline at end of file
+        assert sorted(iterator.indexes) == [0, 1, 2, 3, 4]
+
+    def test_get_model_rank_no_model(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.model = None
+        assert iterator._get_model_rank() == 1
+
+    def test_get_model_rank_single_output_branch(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.model = MyLittleModel(shape_inputs=[(14, 1, 2)], shape_outputs=[(3,)])
+        assert iterator._get_model_rank() == 1
+
+    def test_get_model_rank_multiple_output_branch(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.model = MyBranchedModel(shape_inputs=[(14, 1, 2)], shape_outputs=[(3,)])
+        assert iterator._get_model_rank() == 3
+
+    def test_get_model_rank_error(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.model = mock.MagicMock(return_value=1)
+        with pytest.raises(TypeError):
+            iterator._get_model_rank()
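
For orientation, the model rank asserted in the new tests above is simply the number of output branches of the attached model, with a default of 1 when no model is set. A standalone keras sketch of that idea, independent of the KerasIterator internals:

import keras

inputs = keras.layers.Input(shape=(14, 1, 2))
x = keras.layers.Flatten()(inputs)
single = keras.Model(inputs, keras.layers.Dense(3)(x))
branched = keras.Model(inputs, [keras.layers.Dense(3)(x) for _ in range(3)])

def model_rank(model):
    """Return 1 if no model is given, otherwise the number of output branches."""
    return 1 if model is None else len(model.outputs)

assert model_rank(None) == 1
assert model_rank(single) == 1
assert model_rank(branched) == 3
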
diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
deleted file mode 100644
index 43c61be2134d68e1f81ed50420e2a801c9e63646..0000000000000000000000000000000000000000
--- a/test/test_data_handling/test_data_distributor.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import math
-import os
-
-import keras
-import numpy as np
-import pytest
-
-from src.data_handling.data_distributor import Distributor
-from src.data_handling.data_generator import DataGenerator
-from src.data_handling import DataPrepJoin
-from test.test_modules.test_training import my_test_model
-
-
-class TestDistributor:
-
-    @pytest.fixture
-    def generator(self):
-        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'],
-                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
-                             data_preparation=DataPrepJoin)
-
-    @pytest.fixture
-    def generator_two_stations(self):
-        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), ['DEBW107', 'DEBW013'],
-                             ['o3', 'temp'], 'datetime', 'variables', 'o3',
-                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
-                             data_preparation=DataPrepJoin)
-
-    @pytest.fixture
-    def model(self):
-        return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False)
-
-    @pytest.fixture
-    def model_with_minor_branch(self):
-        return my_test_model(keras.layers.PReLU, 5, 3, 0.1, True)
-
-    @pytest.fixture
-    def distributor(self, generator, model):
-        return Distributor(generator, model)
-
-    def test_init_defaults(self, distributor):
-        assert distributor.batch_size == 256
-        assert distributor.do_data_permutation is False
-
-    def test_get_model_rank(self, distributor, model_with_minor_branch):
-        assert distributor._get_model_rank() == 1
-        distributor.model = model_with_minor_branch
-        assert distributor._get_model_rank() == 2
-        distributor.model = 1
-
-    def test_get_number_of_mini_batches(self, distributor):
-        values = np.zeros((2311, 19))
-        assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size)
-
-    def test_distribute_on_batches_single_loop(self, generator_two_stations, model):
-        d = Distributor(generator_two_stations, model)
-        for e in d.distribute_on_batches(fit_call=False):
-            assert e[0].shape[0] <= d.batch_size
-
-    def test_distribute_on_batches_infinite_loop(self, generator_two_stations, model):
-        d = Distributor(generator_two_stations, model)
-        elements = []
-        for i, e in enumerate(d.distribute_on_batches()):
-            if i < len(d):
-                elements.append(e[0])
-            elif i == 2 * len(d):  # check if all elements are repeated
-                assert np.testing.assert_array_equal(e[0], elements[i - len(d)]) is None
-            else:  # break when 3rd iteration starts (is called as infinite loop)
-                break
-
-    def test_len(self, distributor):
-        assert len(distributor) == math.ceil(len(distributor.generator[0][0]) / 256)
-
-    def test_len_two_stations(self, generator_two_stations, model):
-        gen = generator_two_stations
-        d = Distributor(gen, model)
-        expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
-        assert len(d) == expected
-
-    def test_permute_data_no_permutation(self, distributor):
-        x = np.array(range(20)).reshape(2, 10).T
-        y = np.array(range(10)).reshape(10, 1)
-        x_perm, y_perm = distributor._permute_data(x, y)
-        assert np.testing.assert_equal(x, x_perm) is None
-        assert np.testing.assert_equal(y, y_perm) is None
-
-    def test_permute_data(self, distributor):
-        x = np.array(range(20)).reshape(2, 10).T
-        y = np.array(range(10)).reshape(10, 1)
-        distributor.do_data_permutation = True
-        x_perm, y_perm = distributor._permute_data(x, y)
-        assert x_perm[0, 0] == y_perm[0]
-        assert x_perm[0, 1] == y_perm[0] + 10
-        assert x_perm[5, 0] == y_perm[5]
-        assert x_perm[5, 1] == y_perm[5] + 10
-        assert x_perm[-1, 0] == y_perm[-1]
-        assert x_perm[-1, 1] == y_perm[-1] + 10
-        # resort x_perm and compare if equal to x
-        x_perm.sort(axis=0)
-        y_perm.sort(axis=0)
-        assert np.testing.assert_equal(x, x_perm) is None
-        assert np.testing.assert_equal(y, y_perm) is None
-
-    def test_distribute_on_batches_upsampling_no_extremes_given(self, generator, model):
-        d = Distributor(generator, model, upsampling=True)
-        gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0]
-        num_mini_batches = math.ceil(gen_len / d.batch_size)
-        i = 0
-        for i, e in enumerate(d.distribute_on_batches(fit_call=False)):
-            assert e[0].shape[0] <= d.batch_size
-        assert i + 1 == num_mini_batches
-
-    def test_distribute_on_batches_upsampling(self, generator, model):
-        generator.extreme_values = [1]
-        d = Distributor(generator, model, upsampling=True)
-        gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0]
-        extr_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_extremes_label().shape[0]
-        i = 0
-        for i, e in enumerate(d.distribute_on_batches(fit_call=False)):
-            assert e[0].shape[0] <= d.batch_size
-        assert i + 1 == math.ceil((gen_len + extr_len) / d.batch_size)
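
The permutation behaviour exercised by the deleted tests keeps inputs and labels aligned while shuffling the sample axis. A short numpy sketch with a hypothetical helper name (not the removed Distributor method):

import numpy as np

def permute_jointly(x, y):
    """Shuffle along the sample axis while keeping x[i] and y[i] paired."""
    p = np.random.permutation(len(y))
    return x[p], y[p]

x = np.array(range(20)).reshape(2, 10).T  # 10 samples, 2 features
y = np.array(range(10)).reshape(10, 1)
x_perm, y_perm = permute_jointly(x, y)
assert np.array_equal(np.sort(x_perm, axis=0), x)  # same rows, new order
assert np.array_equal(x_perm[:, 0:1], y_perm)      # pairing preserved (x[:, 0] equals y here)
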
diff --git a/test/test_datastore.py b/test/test_datastore.py
index 9aca1eef35927242df0b5f659eece716f81f6c13..662c90bf04e11b8b4ff9647506c1981c8883f30b 100644
--- a/test/test_datastore.py
+++ b/test/test_datastore.py
@@ -3,8 +3,8 @@ __date__ = '2019-11-22'
 
 import pytest
 
-from src.helpers.datastore import AbstractDataStore, DataStoreByVariable, DataStoreByScope, CorrectScope
-from src.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope, EmptyScope
+from mlair.helpers.datastore import AbstractDataStore, DataStoreByVariable, DataStoreByScope, CorrectScope
+from mlair.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope, EmptyScope
 
 
 class TestAbstractDataStore:
diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py
index 28a8bf6e421d62d58d76e7a32906f8a594f16ed7..281d60e07463c6b5118f36714d80144443a03050 100644
--- a/test/test_helpers/test_helpers.py
+++ b/test/test_helpers/test_helpers.py
@@ -5,13 +5,14 @@ import datetime as dt
 import logging
 import math
 import time
+import os
 
 import mock
 import pytest
 
-from src.helpers import to_list, dict_to_xarray, float_round, remove_items
-from src.helpers import PyTestRegex
-from src.helpers import Logger, TimeTracking
+from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items
+from mlair.helpers import PyTestRegex
+from mlair.helpers import Logger, TimeTracking
 
 
 class TestToList:
@@ -236,8 +237,8 @@ class TestLogger:
 
     def test_setup_logging_path_none(self):
         log_file = Logger.setup_logging_path(None)
-        assert PyTestRegex(
-            ".*machinelearningtools/logging/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_file
+        test_regex = os.getcwd() + r"/logging/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log"
+        assert PyTestRegex(test_regex) == log_file
 
     @mock.patch("os.makedirs", side_effect=None)
     def test_setup_logging_path_given(self, mock_makedirs):
diff --git a/test/test_join.py b/test/test_join.py
index 5adc013cfbd446c4feaf4a2b344f07d6f170077d..791723335e16cf2124512629414ebe626bc20e9c 100644
--- a/test/test_join.py
+++ b/test/test_join.py
@@ -2,9 +2,9 @@ from typing import Iterable
 
 import pytest
 
-from src.helpers.join import *
-from src.helpers.join import _save_to_pandas, _correct_stat_name, _lower_list
-from src.configuration.join_settings import join_settings
+from mlair.helpers.join import *
+from mlair.helpers.join import _save_to_pandas, _correct_stat_name, _lower_list
+from mlair.configuration.join_settings import join_settings
 
 
 class TestJoinUrlBase:
diff --git a/test/test_model_modules/test_advanced_paddings.py b/test/test_model_modules/test_advanced_paddings.py
index 8c7cae91ad12cc2b06ec82ba64f91c792a620756..8ca81c42c0b807b28c444badba8d92a255341eb4 100644
--- a/test/test_model_modules/test_advanced_paddings.py
+++ b/test/test_model_modules/test_advanced_paddings.py
@@ -1,7 +1,7 @@
 import keras
 import pytest
 
-from src.model_modules.advanced_paddings import *
+from mlair.model_modules.advanced_paddings import *
 
 
 class TestPadUtils:
diff --git a/test/test_model_modules/test_flatten_tail.py b/test/test_model_modules/test_flatten_tail.py
index 0de138ec2323aea3409d5deadfb26c9741b89f50..623d51c07f6b27c8d6238d8a5189dea33837115e 100644
--- a/test/test_model_modules/test_flatten_tail.py
+++ b/test/test_model_modules/test_flatten_tail.py
@@ -1,6 +1,6 @@
 import keras
 import pytest
-from src.model_modules.flatten import flatten_tail, get_activation
+from mlair.model_modules.flatten import flatten_tail, get_activation
 
 
 class TestGetActivation:
diff --git a/test/test_model_modules/test_inception_model.py b/test/test_model_modules/test_inception_model.py
index ca0126a44fa0f8ccd2ed2a7ea79c872c4731fea1..2dfc2c9c1c0510355216769b2ab83152a0a02118 100644
--- a/test/test_model_modules/test_inception_model.py
+++ b/test/test_model_modules/test_inception_model.py
@@ -1,9 +1,9 @@
 import keras
 import pytest
 
-from src.helpers import PyTestRegex
-from src.model_modules.advanced_paddings import ReflectionPadding2D, SymmetricPadding2D
-from src.model_modules.inception_model import InceptionModelBase
+from mlair.helpers import PyTestRegex
+from mlair.model_modules.advanced_paddings import ReflectionPadding2D, SymmetricPadding2D
+from mlair.model_modules.inception_model import InceptionModelBase
 
 
 class TestInceptionModelBase:
diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py
index 56c60ec43173e9fdd438214862603caba632bc65..78559ee0e54c725d242194133549d8b17699b729 100644
--- a/test/test_model_modules/test_keras_extensions.py
+++ b/test/test_model_modules/test_keras_extensions.py
@@ -4,8 +4,8 @@ import keras
 import mock
 import pytest
 
-from src.model_modules.loss import l_p_loss
-from src.model_modules.keras_extensions import *
+from mlair.model_modules.loss import l_p_loss
+from mlair.model_modules.keras_extensions import *
 
 
 class TestHistoryAdvanced:
diff --git a/test/test_model_modules/test_loss.py b/test/test_model_modules/test_loss.py
index c47f3f188a4b360bda08470fb00fd1d88a9f754c..e54e0b00de4a71d241f30e0b6b0c1a2e8fa1a19c 100644
--- a/test/test_model_modules/test_loss.py
+++ b/test/test_model_modules/test_loss.py
@@ -1,7 +1,7 @@
 import keras
 import numpy as np
 
-from src.model_modules.loss import l_p_loss
+from mlair.model_modules.loss import l_p_loss
 
 
 class TestLoss:
diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py
index 0ee2eb7e5d439c76888f1f05e238bb5507db6a7a..3e77fd17c4cd8151fe76816abf0bef323adb2e96 100644
--- a/test/test_model_modules/test_model_class.py
+++ b/test/test_model_modules/test_model_class.py
@@ -1,8 +1,8 @@
 import keras
 import pytest
 
-from src.model_modules.model_class import AbstractModelClass
-from src.model_modules.model_class import MyPaperModel
+from mlair.model_modules.model_class import AbstractModelClass
+from mlair.model_modules.model_class import MyPaperModel
 
 
 class Paddings:
@@ -12,7 +12,7 @@ class Paddings:
 class AbstractModelSubClass(AbstractModelClass):
 
     def __init__(self):
-        super().__init__()
+        super().__init__(shape_inputs=(12, 1, 2), shape_outputs=3)
         self.test_attr = "testAttr"
 
 
@@ -20,7 +20,7 @@ class TestAbstractModelClass:
 
     @pytest.fixture
     def amc(self):
-        return AbstractModelClass()
+        return AbstractModelClass(shape_inputs=(14, 1, 2), shape_outputs=(3,))
 
     @pytest.fixture
     def amsc(self):
@@ -31,6 +31,8 @@ class TestAbstractModelClass:
         # assert amc.loss is None
         assert amc.model_name == "AbstractModelClass"
         assert amc.custom_objects == {}
+        assert amc.shape_inputs == (14, 1, 2)
+        assert amc.shape_outputs == 3
 
     def test_model_property(self, amc):
         amc.model = keras.Model()
@@ -179,8 +181,10 @@ class TestAbstractModelClass:
         assert amc.compile == amc.model.compile
 
     def test_get_settings(self, amc, amsc):
-        assert amc.get_settings() == {"model_name": "AbstractModelClass"}
-        assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass"}
+        assert amc.get_settings() == {"model_name": "AbstractModelClass", "shape_inputs": (14, 1, 2),
+                                      "shape_outputs": 3}
+        assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass",
+                                       "shape_inputs": (12, 1, 2), "shape_outputs": 3}
 
     def test_custom_objects(self, amc):
         amc.custom_objects = {"Test": 123}
@@ -200,7 +204,7 @@ class TestMyPaperModel:
 
     @pytest.fixture
     def mpm(self):
-        return MyPaperModel(window_history_size=6, window_lead_time=4, channels=9)
+        return MyPaperModel(shape_inputs=[(7, 1, 9)], shape_outputs=[(4,)])
 
     def test_init(self, mpm):
         # check if loss number of loss functions fit to model outputs
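
The tests above reflect the new plug-in contract for model classes: instead of `window_history_size`, `window_lead_time` and `channels`, a model is now constructed from `shape_inputs` and `shape_outputs`, and `get_settings()` reports both. A minimal sketch of a custom model built on that contract follows; the class name, the layer wiring and the `set_model` hook are illustrative assumptions, only the constructor signature mirrors the tests.

```python
import keras

from mlair.model_modules.model_class import AbstractModelClass


class MyTinyModel(AbstractModelClass):
    """Hypothetical example model; only the constructor contract follows the tests above."""

    def __init__(self, shape_inputs: list, shape_outputs: list):
        # shapes arrive as lists of tuples, e.g. shape_inputs=[(7, 1, 9)], shape_outputs=[(4,)],
        # as in the MyPaperModel fixture above
        super().__init__(shape_inputs=shape_inputs[0], shape_outputs=shape_outputs[0])
        self.set_model()

    def set_model(self):
        # single dense tail mapping the flattened input window onto the forecast horizon
        x_input = keras.layers.Input(shape=self.shape_inputs)
        x = keras.layers.Flatten()(x_input)
        out = keras.layers.Dense(self.shape_outputs, activation="linear")(x)
        self.model = keras.Model(inputs=x_input, outputs=[out])
```
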
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
deleted file mode 100644
index 6de61b2dbe88e24eb3caccf6de575d6340129b91..0000000000000000000000000000000000000000
--- a/test/test_modules/test_model_setup.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import os
-
-import pytest
-
-from src.data_handling import DataPrepJoin
-from src.data_handling.data_generator import DataGenerator
-from src.helpers.datastore import EmptyScope
-from src.model_modules.keras_extensions import CallbackHandler
-from src.model_modules.model_class import AbstractModelClass, MyLittleModel
-from src.run_modules.model_setup import ModelSetup
-from src.run_modules.run_environment import RunEnvironment
-
-
-class TestModelSetup:
-
-    @pytest.fixture
-    def setup(self):
-        obj = object.__new__(ModelSetup)
-        super(ModelSetup, obj).__init__()
-        obj.scope = "general.model"
-        obj.model = None
-        obj.callbacks_name = "placeholder_%s_str.pickle"
-        obj.data_store.set("model_class", MyLittleModel)
-        obj.data_store.set("lr_decay", "dummy_str", "general.model")
-        obj.data_store.set("hist", "dummy_str", "general.model")
-        obj.data_store.set("epochs", 2)
-        obj.model_name = "%s.h5"
-        yield obj
-        RunEnvironment().__del__()
-
-    @pytest.fixture
-    def gen(self):
-        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'],
-                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
-                             data_preparation=DataPrepJoin)
-
-    @pytest.fixture
-    def setup_with_gen(self, setup, gen):
-        setup.data_store.set("generator", gen, "general.train")
-        setup.data_store.set("window_history_size", gen.window_history_size, "general")
-        setup.data_store.set("window_lead_time", gen.window_lead_time, "general")
-        setup.data_store.set("channels", 2, "general")
-        yield setup
-        RunEnvironment().__del__()
-
-    @pytest.fixture
-    def setup_with_gen_tiny(self, setup, gen):
-        setup.data_store.set("generator", gen, "general.train")
-        yield setup
-        RunEnvironment().__del__()
-
-    @pytest.fixture
-    def setup_with_model(self, setup):
-        setup.model = AbstractModelClass()
-        setup.model.test_param = "42"
-        yield setup
-        RunEnvironment().__del__()
-
-    @staticmethod
-    def current_scope_as_set(model_cls):
-        return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))
-
-    def test_set_callbacks(self, setup):
-        assert "general.model" not in setup.data_store.search_name("callbacks")
-        setup.checkpoint_name = "TestName"
-        setup._set_callbacks()
-        assert "general.model" in setup.data_store.search_name("callbacks")
-        callbacks = setup.data_store.get("callbacks", "general.model")
-        assert len(callbacks.get_callbacks()) == 3
-
-    def test_set_callbacks_no_lr_decay(self, setup):
-        setup.data_store.set("lr_decay", None, "general.model")
-        assert "general.model" not in setup.data_store.search_name("callbacks")
-        setup.checkpoint_name = "TestName"
-        setup._set_callbacks()
-        callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
-        assert len(callbacks.get_callbacks()) == 2
-        with pytest.raises(IndexError):
-            callbacks.get_callback_by_name("lr_decay")
-
-    def test_get_model_settings(self, setup_with_model):
-        setup_with_model.scope = "model_test"
-        with pytest.raises(EmptyScope):
-            self.current_scope_as_set(setup_with_model)  # will fail because scope is not created
-        setup_with_model.get_model_settings()  # this saves now the parameter test_param into scope
-        assert {"test_param", "model_name"} <= self.current_scope_as_set(setup_with_model)
-
-    def test_build_model(self, setup_with_gen):
-        assert setup_with_gen.model is None
-        setup_with_gen.build_model()
-        assert isinstance(setup_with_gen.model, AbstractModelClass)
-        expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr",
-                    "optimizer", "activation"}
-        assert expected <= self.current_scope_as_set(setup_with_gen)
-
-    def test_set_channels(self, setup_with_gen_tiny):
-        assert len(setup_with_gen_tiny.data_store.search_name("channels")) == 0
-        setup_with_gen_tiny._set_channels()
-        assert setup_with_gen_tiny.data_store.get("channels", setup_with_gen_tiny.scope) == 2
-
-    def test_load_weights(self):
-        pass
-
-    def test_compile_model(self):
-        pass
-
-    def test_run(self):
-        pass
-
-    def test_init(self):
-        pass
diff --git a/test/test_plotting/test_tracker_plot.py b/test/test_plotting/test_tracker_plot.py
index 9a92360a819c130c213d06b89a48a896e082adad..196879657452fe12238c990fc419cb0848c9ec9c 100644
--- a/test/test_plotting/test_tracker_plot.py
+++ b/test/test_plotting/test_tracker_plot.py
@@ -7,8 +7,8 @@ import shutil
 from matplotlib import pyplot as plt
 import numpy as np
 
-from src.plotting.tracker_plot import TrackObject, TrackChain, TrackPlot
-from src.helpers import PyTestAllEqual
+from mlair.plotting.tracker_plot import TrackObject, TrackChain, TrackPlot
+from mlair.helpers import PyTestAllEqual
 
 
 class TestTrackObject:
diff --git a/test/test_plotting/test_training_monitoring.py b/test/test_plotting/test_training_monitoring.py
index 6e5e0abbc5da0978e200f19019700c4dedd14ad0..18009bc19947bd3318c6f1d220d303c1efeec972 100644
--- a/test/test_plotting/test_training_monitoring.py
+++ b/test/test_plotting/test_training_monitoring.py
@@ -3,8 +3,8 @@ import os
 import keras
 import pytest
 
-from src.model_modules.keras_extensions import LearningRateDecay
-from src.plotting.training_monitoring import PlotModelLearningRate, PlotModelHistory
+from mlair.model_modules.keras_extensions import LearningRateDecay
+from mlair.plotting.training_monitoring import PlotModelLearningRate, PlotModelHistory
 
 
 @pytest.fixture
diff --git a/test/test_modules/test_experiment_setup.py b/test/test_run_modules/test_experiment_setup.py
similarity index 90%
rename from test/test_modules/test_experiment_setup.py
rename to test/test_run_modules/test_experiment_setup.py
index 5b7d517e658de6bd71e1b4190bb5114dc005216e..abd265f5815d974d6edb474e5a03ed08dc5843cc 100644
--- a/test/test_modules/test_experiment_setup.py
+++ b/test/test_run_modules/test_experiment_setup.py
@@ -4,9 +4,9 @@ import os
 
 import pytest
 
-from src.helpers import TimeTracking
-from src.configuration.path_config import prepare_host
-from src.run_modules.experiment_setup import ExperimentSetup
+from mlair.helpers import TimeTracking
+from mlair.configuration.path_config import prepare_host
+from mlair.run_modules.experiment_setup import ExperimentSetup
 
 
 class TestExperimentSetup:
@@ -14,7 +14,7 @@ class TestExperimentSetup:
     @pytest.fixture
     def empty_obj(self, caplog):
         obj = object.__new__(ExperimentSetup)
-        obj.time = TimeTracking()
+        super(ExperimentSetup, obj).__init__()
         caplog.set_level(logging.DEBUG)
         return obj
 
@@ -43,7 +43,7 @@ class TestExperimentSetup:
         assert data_store.get("fraction_of_training", "general") == 0.8
         # set experiment name
         assert data_store.get("experiment_name", "general") == "TestExperiment_daily"
-        path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment_daily"))
+        path = os.path.abspath(os.path.join(os.getcwd(), "TestExperiment_daily"))
         assert data_store.get("experiment_path", "general") == path
         default_statistics_per_var = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum',
                                       'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu',
@@ -51,8 +51,6 @@ class TestExperimentSetup:
         # setup for data
         default_stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
         assert data_store.get("stations", "general") == default_stations
-        assert data_store.get("network", "general") == "AIRBASE"
-        assert data_store.get("station_type", "general") == "background"
         assert data_store.get("variables", "general") == list(default_statistics_per_var.keys())
         assert data_store.get("statistics_per_var", "general") == default_statistics_per_var
         assert data_store.get("start", "general") == "1997-01-01"
@@ -64,9 +62,9 @@ class TestExperimentSetup:
         assert data_store.get("window_lead_time", "general") == 3
         # interpolation
         assert data_store.get("dimensions", "general") == {'new_index': ['datetime', 'Stations']}
-        assert data_store.get("interpolate_dim", "general") == "datetime"
-        assert data_store.get("interpolate_method", "general") == "linear"
-        assert data_store.get("limit_nan_fill", "general") == 1
+        assert data_store.get("time_dim", "general") == "datetime"
+        assert data_store.get("interpolation_method", "general") == "linear"
+        assert data_store.get("interpolation_limit", "general") == 1
         # train parameters
         assert data_store.get("start", "general.train") == "1997-01-01"
         assert data_store.get("end", "general.train") == "2007-12-31"
@@ -93,7 +91,7 @@ class TestExperimentSetup:
                       stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", station_type="background",
                       variables=["o3", "temp"], start="1999-01-01", end="2001-01-01", window_history_size=4,
                       target_var="relhum", target_dim="target", window_lead_time=10, dimensions="dim1",
-                      interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01",
+                      time_dim="int_dim", interpolation_method="cubic", interpolation_limit=5, train_start="2000-01-01",
                       train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05",
                       test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False,
                       fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True, val_min_length=20)
@@ -125,9 +123,9 @@ class TestExperimentSetup:
         assert data_store.get("window_lead_time", "general") == 10
         # interpolation
         assert data_store.get("dimensions", "general") == "dim1"
-        assert data_store.get("interpolate_dim", "general") == "int_dim"
-        assert data_store.get("interpolate_method", "general") == "cubic"
-        assert data_store.get("limit_nan_fill", "general") == 5
+        assert data_store.get("time_dim", "general") == "int_dim"
+        assert data_store.get("interpolation_method", "general") == "cubic"
+        assert data_store.get("interpolation_limit", "general") == 5
         # train parameters
         assert data_store.get("start", "general.train") == "2000-01-01"
         assert data_store.get("end", "general.train") == "2000-01-02"
diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b3e43b2bbfda44f1a5b5463e876adc578360ff3
--- /dev/null
+++ b/test/test_run_modules/test_model_setup.py
@@ -0,0 +1,144 @@
+import os
+import numpy as np
+import shutil
+
+import pytest
+
+from mlair.data_handler import KerasIterator
+from mlair.data_handler import DataCollection
+from mlair.helpers.datastore import EmptyScope
+from mlair.model_modules.keras_extensions import CallbackHandler
+from mlair.model_modules.model_class import AbstractModelClass, MyLittleModel
+from mlair.run_modules.model_setup import ModelSetup
+from mlair.run_modules.run_environment import RunEnvironment
+
+
+class TestModelSetup:
+
+    @pytest.fixture
+    def setup(self):
+        obj = object.__new__(ModelSetup)
+        super(ModelSetup, obj).__init__()
+        obj.scope = "general.model"
+        obj.model = None
+        obj.callbacks_name = "placeholder_%s_str.pickle"
+        obj.data_store.set("model_class", MyLittleModel)
+        obj.data_store.set("lr_decay", "dummy_str", "general.model")
+        obj.data_store.set("hist", "dummy_str", "general.model")
+        obj.data_store.set("epochs", 2)
+        obj.model_name = "%s.h5"
+        yield obj
+        RunEnvironment().__del__()
+
+    @pytest.fixture
+    def path(self):
+        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+        shutil.rmtree(p, ignore_errors=True)
+        yield p
+        shutil.rmtree(p, ignore_errors=True)
+
+    @pytest.fixture
+    def keras_iterator(self, path):
+        coll = []
+        for i in range(3):
+            coll.append(DummyData(50 + i))
+        data_coll = DataCollection(collection=coll)
+        KerasIterator(data_coll, 25, path)
+        return data_coll
+
+    @pytest.fixture
+    def setup_with_gen(self, setup, keras_iterator):
+        setup.data_store.set("data_collection", keras_iterator, "train")
+        shape_inputs = [keras_iterator[0].get_X()[0].shape[1:]]
+        setup.data_store.set("shape_inputs", shape_inputs, "model")
+        shape_outputs = [keras_iterator[0].get_Y()[0].shape[1:]]
+        setup.data_store.set("shape_outputs", shape_outputs, "model")
+        yield setup
+        RunEnvironment().__del__()
+
+    @pytest.fixture
+    def setup_with_gen_tiny(self, setup, keras_iterator):
+        setup.data_store.set("data_collection", keras_iterator, "train")
+        yield setup
+        RunEnvironment().__del__()
+
+    @pytest.fixture
+    def setup_with_model(self, setup):
+        setup.model = AbstractModelClass(shape_inputs=(12, 1), shape_outputs=2)
+        setup.model.test_param = "42"
+        yield setup
+        RunEnvironment().__del__()
+
+    @staticmethod
+    def current_scope_as_set(model_cls):
+        return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))
+
+    def test_set_callbacks(self, setup):
+        assert "general.model" not in setup.data_store.search_name("callbacks")
+        setup.checkpoint_name = "TestName"
+        setup._set_callbacks()
+        assert "general.model" in setup.data_store.search_name("callbacks")
+        callbacks = setup.data_store.get("callbacks", "general.model")
+        assert len(callbacks.get_callbacks()) == 3
+
+    def test_set_callbacks_no_lr_decay(self, setup):
+        setup.data_store.set("lr_decay", None, "general.model")
+        assert "general.model" not in setup.data_store.search_name("callbacks")
+        setup.checkpoint_name = "TestName"
+        setup._set_callbacks()
+        callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
+        assert len(callbacks.get_callbacks()) == 2
+        with pytest.raises(IndexError):
+            callbacks.get_callback_by_name("lr_decay")
+
+    def test_get_model_settings(self, setup_with_model):
+        setup_with_model.scope = "model_test"
+        with pytest.raises(EmptyScope):
+            self.current_scope_as_set(setup_with_model)  # will fail because scope is not created
+        setup_with_model.get_model_settings()  # this now saves the parameter test_param into the scope
+        assert {"test_param", "model_name"} <= self.current_scope_as_set(setup_with_model)
+
+    def test_build_model(self, setup_with_gen):
+        assert setup_with_gen.model is None
+        setup_with_gen.build_model()
+        assert isinstance(setup_with_gen.model, AbstractModelClass)
+        expected = {"lr_decay", "model_name", "dropout_rate", "regularizer", "initial_lr", "optimizer", "activation",
+                    "shape_inputs", "shape_outputs"}
+        assert expected <= self.current_scope_as_set(setup_with_gen)
+
+    def test_set_shapes(self, setup_with_gen_tiny):
+        assert len(setup_with_gen_tiny.data_store.search_name("shape_inputs")) == 0
+        assert len(setup_with_gen_tiny.data_store.search_name("shape_outputs")) == 0
+        setup_with_gen_tiny._set_shapes()
+        assert setup_with_gen_tiny.data_store.get("shape_inputs", setup_with_gen_tiny.scope) == [(14, 1, 5), (10, 1, 2),
+                                                                                                 (1, 1, 2)]
+        assert setup_with_gen_tiny.data_store.get("shape_outputs", setup_with_gen_tiny.scope) == [(5,), (3,)]
+
+    def test_load_weights(self):
+        pass
+
+    def test_compile_model(self):
+        pass
+
+    def test_run(self):
+        pass
+
+    def test_init(self):
+        pass
+
+
+class DummyData:
+
+    def __init__(self, number_of_samples=np.random.randint(100, 150)):
+        self.number_of_samples = number_of_samples
+
+    def get_X(self, upsampling=False, as_numpy=True):
+        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 1, 5))  # samples, window, 1, variables
+        X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 1, 2))  # samples, window, 1, variables
+        X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 1, 2))  # samples, window, 1, variables
+        return [X1, X2, X3]
+
+    def get_Y(self, upsampling=False, as_numpy=True):
+        Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5))  # samples, window
+        Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 3))  # samples, window
+        return [Y1, Y2]
\ No newline at end of file
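
The expected values in `test_set_shapes` follow directly from `DummyData`: input and output shapes are obtained by stripping the sample axis, the same `shape[1:]` pattern used in the `setup_with_gen` fixture. A quick sanity check, reusing the `DummyData` class defined above:

```python
# DummyData is the helper class from the new test file above
data = DummyData(number_of_samples=50)

shape_inputs = [x.shape[1:] for x in data.get_X()]
shape_outputs = [y.shape[1:] for y in data.get_Y()]

assert shape_inputs == [(14, 1, 5), (10, 1, 2), (1, 1, 2)]
assert shape_outputs == [(5,), (3,)]
```
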
diff --git a/test/test_modules/test_partition_check.py b/test/test_run_modules/test_partition_check.py
similarity index 90%
rename from test/test_modules/test_partition_check.py
rename to test/test_run_modules/test_partition_check.py
index b04e01d13e9e160553f8ff66af8d97f65aa24bf0..ba5b3d7ef127258eaa6c4f2a1a0b4d0b531eeac5 100644
--- a/test/test_modules/test_partition_check.py
+++ b/test/test_run_modules/test_partition_check.py
@@ -2,10 +2,9 @@ import logging
 
 import pytest
 import mock
-from src.run_modules.experiment_setup import ExperimentSetup
-from src.run_modules.partition_check import PartitionCheck
-from src.run_modules.run_environment import RunEnvironment
-from src.configuration import get_host
+from mlair.run_modules.experiment_setup import ExperimentSetup
+from mlair.run_modules.partition_check import PartitionCheck
+from mlair.run_modules.run_environment import RunEnvironment
 
 
 class TestPartitionCheck:
@@ -24,6 +23,7 @@ class TestPartitionCheck:
     @mock.patch("os.path.exists", return_value=False)
     @mock.patch("os.makedirs", side_effect=None)
     def obj_with_exp_setup_login(self, mock_host, mock_user,  mock_path, mock_check):
+        RunEnvironment().__del__()
         ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
                         statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background")
         pre = object.__new__(PartitionCheck)
@@ -37,6 +37,7 @@ class TestPartitionCheck:
     @mock.patch("os.path.exists", return_value=False)
     @mock.patch("os.makedirs", side_effect=None)
     def obj_with_exp_setup_compute(self, mock_host, mock_user,  mock_path, mock_check):
+        RunEnvironment().__del__()
         ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
                         statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background")
         pre = object.__new__(PartitionCheck)
@@ -71,5 +72,5 @@ class TestPartitionCheck:
     @mock.patch("os.path.exists", return_value=False)
     @mock.patch("os.makedirs", side_effect=None)
     def test_run_compute(self, mock_host, mock_user, mock_path, mock_check, obj_with_exp_setup_compute, caplog):
-
-        assert obj_with_exp_setup_compute.__next__()._run() is None
+        obj = obj_with_exp_setup_compute.__next__()
+        assert obj._run() is None
diff --git a/test/test_modules/test_post_processing.py b/test/test_run_modules/test_post_processing.py
similarity index 100%
rename from test/test_modules/test_post_processing.py
rename to test/test_run_modules/test_post_processing.py
diff --git a/test/test_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py
similarity index 65%
rename from test/test_modules/test_pre_processing.py
rename to test/test_run_modules/test_pre_processing.py
index 0b439e9e9ad54ca3aef70e27b2017482706383c0..97e73204068d334590ee98271080acddf29dfc5f 100644
--- a/test/test_modules/test_pre_processing.py
+++ b/test/test_run_modules/test_pre_processing.py
@@ -2,13 +2,12 @@ import logging
 
 import pytest
 
-from src.data_handling import DataPrepJoin
-from src.data_handling.data_generator import DataGenerator
-from src.helpers.datastore import NameNotFoundInScope
-from src.helpers import PyTestRegex
-from src.run_modules.experiment_setup import ExperimentSetup
-from src.run_modules.pre_processing import PreProcessing, DEFAULT_ARGS_LIST, DEFAULT_KWARGS_LIST
-from src.run_modules.run_environment import RunEnvironment
+from mlair.data_handler import DefaultDataPreparation, DataCollection, AbstractDataPreparation
+from mlair.helpers.datastore import NameNotFoundInScope
+from mlair.helpers import PyTestRegex
+from mlair.run_modules.experiment_setup import ExperimentSetup
+from mlair.run_modules.pre_processing import PreProcessing
+from mlair.run_modules.run_environment import RunEnvironment
 
 
 class TestPreProcessing:
@@ -29,7 +28,7 @@ class TestPreProcessing:
     def obj_with_exp_setup(self):
         ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
                         statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background",
-                        data_preparation=DataPrepJoin)
+                        data_preparation=DefaultDataPreparation)
         pre = object.__new__(PreProcessing)
         super(PreProcessing, pre).__init__()
         yield pre
@@ -42,25 +41,26 @@ class TestPreProcessing:
         caplog.set_level(logging.INFO)
         with PreProcessing():
             assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
-            assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started (all)')
+            assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started (preprocessing)')
             assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 '
                                                                         r'station\(s\). Found 5/5 valid stations.'))
         RunEnvironment().__del__()
 
     def test_run(self, obj_with_exp_setup):
-        assert obj_with_exp_setup.data_store.search_name("generator") == []
+        assert obj_with_exp_setup.data_store.search_name("data_collection") == []
         assert obj_with_exp_setup._run() is None
-        assert obj_with_exp_setup.data_store.search_name("generator") == sorted(["general.train", "general.val",
-                                                                                 "general.train_val", "general.test"])
+        assert obj_with_exp_setup.data_store.search_name("data_collection") == sorted(["general.train", "general.val",
+                                                                                       "general.train_val",
+                                                                                       "general.test"])
 
     def test_split_train_val_test(self, obj_with_exp_setup):
-        assert obj_with_exp_setup.data_store.search_name("generator") == []
+        assert obj_with_exp_setup.data_store.search_name("data_collection") == []
         obj_with_exp_setup.split_train_val_test()
         data_store = obj_with_exp_setup.data_store
-        expected_params = ["generator", "start", "end", "stations", "permute_data", "min_length", "extreme_values",
-                           "extremes_on_right_tail_only", "upsampling"]
+        expected_params = ["data_collection", "start", "end", "stations", "permute_data", "min_length",
+                           "extreme_values", "extremes_on_right_tail_only", "upsampling"]
         assert data_store.search_scope("general.train") == sorted(expected_params)
-        assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test",
+        assert data_store.search_name("data_collection") == sorted(["general.train", "general.val", "general.test",
                                                               "general.train_val"])
 
     def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup):
@@ -69,9 +69,9 @@ class TestPreProcessing:
         obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
         assert ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") in caplog.record_tuples
         data_store = obj_with_exp_setup.data_store
-        assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator)
+        assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
         with pytest.raises(NameNotFoundInScope):
-            data_store.get("generator", "general")
+            data_store.get("data_collection", "general")
         assert data_store.get("stations", "general.awesome") == ["DEBW107", "DEBY081"]
 
     def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup):
@@ -80,22 +80,22 @@ class TestPreProcessing:
         message = "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']"
         assert ('root', 10, message) in caplog.record_tuples
         data_store = obj_with_exp_setup.data_store
-        assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator)
+        assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
         with pytest.raises(NameNotFoundInScope):
-            data_store.get("generator", "general")
+            data_store.get("data_collection", "general")
         assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
 
     @pytest.mark.parametrize("name", (None, "tester"))
-    def test_check_valid_stations(self, caplog, obj_with_exp_setup, name):
+    def test_validate_station(self, caplog, obj_with_exp_setup, name):
         pre = obj_with_exp_setup
         caplog.set_level(logging.INFO)
-        args = pre.data_store.create_args_dict(DEFAULT_ARGS_LIST)
-        kwargs = pre.data_store.create_args_dict(DEFAULT_KWARGS_LIST)
         stations = pre.data_store.get("stations", "general")
-        valid_stations = pre.check_valid_stations(args, kwargs, stations, name=name)
+        data_preparation = pre.data_store.get("data_preparation")
+        collection, valid_stations = pre.validate_station(data_preparation, stations, set_name=name)
+        assert isinstance(collection, DataCollection)
         assert len(valid_stations) < len(stations)
         assert valid_stations == stations[:-1]
-        expected = 'check valid stations started (tester)' if name else 'check valid stations started'
+        expected = "check valid stations started" + ' (%s)' % (name if name else 'all')
         assert caplog.record_tuples[0] == ('root', 20, expected)
         assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 '
                                                                     r'station\(s\). Found 5/6 valid stations.'))
@@ -107,3 +107,11 @@ class TestPreProcessing:
         assert dummy_list[val] == list(range(10, 13))
         assert dummy_list[test] == list(range(13, 15))
         assert dummy_list[train_val] == list(range(0, 13))
+
+    def test_transformation(self):
+        pre = object.__new__(PreProcessing)
+        data_preparation = AbstractDataPreparation
+        stations = ['DEBW107', 'DEBY081']
+        assert pre.transformation(data_preparation, stations) is None
+        class data_preparation_no_trans: pass
+        assert pre.transformation(data_preparation_no_trans, stations) is None
diff --git a/test/test_modules/test_run_environment.py b/test/test_run_modules/test_run_environment.py
similarity index 90%
rename from test/test_modules/test_run_environment.py
rename to test/test_run_modules/test_run_environment.py
index 59bb8535c4dab44e646bd6bc4aa83a8553be4d26..aa385e32673c2bf58db3f5666b2f64076af0193f 100644
--- a/test/test_modules/test_run_environment.py
+++ b/test/test_run_modules/test_run_environment.py
@@ -1,7 +1,7 @@
 import logging
 
-from src.helpers import TimeTracking, PyTestRegex
-from src.run_modules.run_environment import RunEnvironment
+from mlair.helpers import TimeTracking, PyTestRegex
+from mlair.run_modules.run_environment import RunEnvironment
 
 
 class TestRunEnvironment:
diff --git a/test/test_modules/test_training.py b/test/test_run_modules/test_training.py
similarity index 69%
rename from test/test_modules/test_training.py
rename to test/test_run_modules/test_training.py
index d58c1a973ec474b2ec786271dff9d35ce5ca94d9..1fec8f4e56e2925bff0bc4af859dac1fe5fbb2b6 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -9,18 +9,16 @@ import mock
 import pytest
 from keras.callbacks import History
 
-from src.data_handling import DataPrepJoin
-from src.data_handling.data_distributor import Distributor
-from src.data_handling.data_generator import DataGenerator
-from src.helpers import PyTestRegex
-from src.model_modules.flatten import flatten_tail
-from src.model_modules.inception_model import InceptionModelBase
-from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler
-from src.run_modules.run_environment import RunEnvironment
-from src.run_modules.training import Training
-
-
-def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
+from mlair.data_handler import DataCollection, KerasIterator, DefaultDataPreparation
+from mlair.helpers import PyTestRegex
+from mlair.model_modules.flatten import flatten_tail
+from mlair.model_modules.inception_model import InceptionModelBase
+from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler
+from mlair.run_modules.run_environment import RunEnvironment
+from mlair.run_modules.training import Training
+
+
+def my_test_model(activation, window_history_size, channels, output_size, dropout_rate, add_minor_branch=False):
     inception_model = InceptionModelBase()
     conv_settings_dict1 = {
         'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
@@ -29,7 +27,6 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
     X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
     X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
     if add_minor_branch:
-        # out = [flatten_tail(X_in, 'Minor_1', activation=activation)]
         out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
                             output_activation='linear', reduction_filter=64,
                             name='Minor_1', dropout_rate=dropout_rate,
@@ -37,8 +34,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
     else:
         out = []
     X_in = keras.layers.Dropout(dropout_rate)(X_in)
-    # out.append(flatten_tail(X_in, 'Main', activation=activation))
-    out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
+    out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=output_size,
                             output_activation='linear', reduction_filter=64,
                             name='Main', dropout_rate=dropout_rate,
                             ))
@@ -48,7 +44,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
 class TestTraining:
 
     @pytest.fixture
-    def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler, model_path):
+    def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler, model_path, batch_path):
         obj = object.__new__(Training)
         super(Training, obj).__init__()
         obj.model = model
@@ -62,19 +58,23 @@ class TestTraining:
         obj.lr_sc = lr
         obj.hist = hist
         obj.experiment_name = "TestExperiment"
-        obj.data_store.set("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train")
-        obj.data_store.set("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val")
-        obj.data_store.set("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test")
+        obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_train_gen"), "general.train")
+        obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_val_gen"), "general.val")
+        obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_test_gen"), "general.test")
         os.makedirs(path)
         obj.data_store.set("experiment_path", path, "general")
+        os.makedirs(batch_path)
+        obj.data_store.set("batch_path", batch_path, "general")
         os.makedirs(model_path)
         obj.data_store.set("model_path", model_path, "general")
         obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
         obj.data_store.set("experiment_name", "TestExperiment", "general")
+
         path_plot = os.path.join(path, "plots")
         os.makedirs(path_plot)
         obj.data_store.set("plot_path", path_plot, "general")
         obj._trainable = True
+        obj._create_new_model = False
         yield obj
         if os.path.exists(path):
             shutil.rmtree(path)
@@ -108,14 +108,35 @@ class TestTraining:
         return os.path.join(path, "model")
 
     @pytest.fixture
-    def generator(self, path):
-        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), ['DEBW107'], ['o3', 'temp'], 'datetime',
-                             'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
-                             data_preparation=DataPrepJoin)
+    def batch_path(self, path):
+        return os.path.join(path, "batch")
+
+    @pytest.fixture
+    def window_history_size(self):
+        return 7
+
+    @pytest.fixture
+    def window_lead_time(self):
+        return 2
+
+    @pytest.fixture
+    def statistics_per_var(self):
+        return {'o3': 'dma8eu', 'temp': 'maximum'}
 
     @pytest.fixture
-    def model(self):
-        return my_test_model(keras.layers.PReLU, 7, 2, 0.1, False)
+    def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var):
+        data_prep = DefaultDataPreparation.build(['DEBW107'], data_path=os.path.join(os.path.dirname(__file__), 'data'),
+                                                 statistics_per_var=statistics_per_var, station_type="background",
+                                                 network="AIRBASE", sampling="daily", target_dim="variables",
+                                                 target_var="o3", time_dim="datetime",
+                                                 window_history_size=window_history_size,
+                                                 window_lead_time=window_lead_time, name_affix="train")
+        return DataCollection([data_prep])
+
+    @pytest.fixture
+    def model(self, window_history_size, window_lead_time, statistics_per_var):
+        channels = len(list(statistics_per_var.keys()))
+        return my_test_model(keras.layers.PReLU, window_history_size, channels, window_lead_time, 0.1, False)
 
     @pytest.fixture
     def callbacks(self, path):
@@ -129,29 +150,31 @@ class TestTraining:
         return clbk, hist, lr
 
     @pytest.fixture
-    def ready_to_train(self, generator: DataGenerator, init_without_run: Training):
-        init_without_run.train_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
-        init_without_run.val_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
+    def ready_to_train(self, data_collection: DataCollection, init_without_run: Training, batch_path: str):
+        batch_size = init_without_run.batch_size
+        model = init_without_run.model
+        init_without_run.train_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="train")
+        init_without_run.val_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="val")
         init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
         return init_without_run
 
     @pytest.fixture
-    def ready_to_run(self, generator, init_without_run):
+    def ready_to_run(self, data_collection, init_without_run):
         obj = init_without_run
-        obj.data_store.set("generator", generator, "general.train")
-        obj.data_store.set("generator", generator, "general.val")
-        obj.data_store.set("generator", generator, "general.test")
+        obj.data_store.set("data_collection", data_collection, "general.train")
+        obj.data_store.set("data_collection", data_collection, "general.val")
+        obj.data_store.set("data_collection", data_collection, "general.test")
         obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
         return obj
 
     @pytest.fixture
-    def ready_to_init(self, generator, model, callbacks, path, model_path):
+    def ready_to_init(self, data_collection, model, callbacks, path, model_path, batch_path):
         os.makedirs(path)
         os.makedirs(model_path)
         obj = RunEnvironment()
-        obj.data_store.set("generator", generator, "general.train")
-        obj.data_store.set("generator", generator, "general.val")
-        obj.data_store.set("generator", generator, "general.test")
+        obj.data_store.set("data_collection", data_collection, "general.train")
+        obj.data_store.set("data_collection", data_collection, "general.val")
+        obj.data_store.set("data_collection", data_collection, "general.test")
         model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
         obj.data_store.set("model", model, "general.model")
         obj.data_store.set("model_path", model_path, "general")
@@ -166,6 +189,8 @@ class TestTraining:
         obj.data_store.set("experiment_path", path, "general")
         obj.data_store.set("trainable", True, "general")
         obj.data_store.set("create_new_model", True, "general")
+        os.makedirs(batch_path)
+        obj.data_store.set("batch_path", batch_path, "general")
         path_plot = os.path.join(path, "plots")
         os.makedirs(path_plot)
         obj.data_store.set("plot_path", path_plot, "general")
@@ -176,6 +201,13 @@ class TestTraining:
     def test_init(self, ready_to_init):
         assert isinstance(Training(), Training)  # just test, if nothing fails
 
+    def test_no_training(self, ready_to_init, caplog):
+        caplog.set_level(logging.INFO)
+        ready_to_init.data_store.set("trainable", False)
+        Training()
+        message = "No training has started, because trainable parameter was false."
+        assert caplog.record_tuples[-2] == ("root", 20, message)
+
     def test_run(self, ready_to_run):
         assert ready_to_run._run() is None  # just test, if nothing fails
 
@@ -187,8 +219,8 @@ class TestTraining:
     def test_set_gen(self, init_without_run):
         assert init_without_run.train_set is None
         init_without_run._set_gen("train")
-        assert isinstance(init_without_run.train_set, Distributor)
-        assert init_without_run.train_set.generator.return_value == "mock_train_gen"
+        assert isinstance(init_without_run.train_set, KerasIterator)
+        assert init_without_run.train_set._collection.return_value == "mock_train_gen"
 
     def test_set_generators(self, init_without_run):
         sets = ["train", "val", "test"]
@@ -196,7 +228,7 @@ class TestTraining:
         init_without_run.set_generators()
         assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
         assert all(
-            [getattr(init_without_run, f"{obj}_set").generator.return_value == f"mock_{obj}_gen" for obj in sets])
+            [getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets])
 
     def test_train(self, ready_to_train, path):
         assert not hasattr(ready_to_train.model, "history")
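
The training fixtures show the data-flow change behind these tests: the former `DataGenerator`/`Distributor` pair is replaced by a `DataCollection` wrapped in a `KerasIterator`, which additionally needs a directory to cache batches in. A condensed sketch of the pattern from the `ready_to_train` fixture above, where `data_collection`, `model`, `batch_size` and `batch_path` stand in for the corresponding fixtures:

```python
import keras

from mlair.data_handler import KerasIterator

# wrap the collection for training and validation, then compile the model as usual
train_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="train")
val_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="val")
model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
```
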
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 3da7a47871f6d92472de268d165d788c343ce394..d4a72674ae89ecd106ff1861aa6ee26567da3243 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -3,7 +3,7 @@ import pandas as pd
 import pytest
 import xarray as xr
 
-from src.helpers.statistics import standardise, standardise_inverse, standardise_apply, centre, centre_inverse, centre_apply, \
+from mlair.helpers.statistics import standardise, standardise_inverse, standardise_apply, centre, centre_inverse, centre_apply, \
     apply_inverse_transformation
 
 lazy = pytest.lazy_fixture