diff --git a/src/configuration/path_config.py b/src/configuration/path_config.py index 7af25875eea58de081012fc6040a76a04f001d54..29dcd24e3626aca2ad3f24612399c24469eb3218 100644 --- a/src/configuration/path_config.py +++ b/src/configuration/path_config.py @@ -31,13 +31,13 @@ def prepare_host(create_new=True, data_path=None, sampling="daily") -> str: elif hostname == "zam347": path = f"/home/{user}/Data/toar_{sampling}/" elif hostname == "linux-aa9b": - path = f"/home/{user}/machinelearningtools/data/toar_{sampling}/" + path = f"/home/{user}/mlair/data/toar_{sampling}/" elif (len(hostname) > 2) and (hostname[:2] == "jr"): path = f"/p/project/cjjsc42/{user}/DATA/toar_{sampling}/" elif (len(hostname) > 2) and (hostname[:2] in ['jw', 'ju'] or hostname[:5] in ['hdfml']): path = f"/p/project/deepacf/intelliaq/{user}/DATA/toar_{sampling}/" elif runner_regex.match(hostname) is not None: - path = f"/home/{user}/machinelearningtools/data/toar_{sampling}/" + path = f"/home/{user}/mlair/data/toar_{sampling}/" else: raise OSError(f"unknown host '{hostname}'") if not os.path.exists(path): diff --git a/src/data_handling/__init__.py b/src/data_handling/__init__.py index cb5aa5db0f29cf51d32ed54e810fa9b363d80cc6..1bd380d35cae73ba7dee2c2a10214483ab0ed62d 100644 --- a/src/data_handling/__init__.py +++ b/src/data_handling/__init__.py @@ -13,3 +13,6 @@ from .bootstraps import BootStraps from .data_preparation_join import DataPrepJoin from .data_generator import DataGenerator from .data_distributor import Distributor +from .iterator import KerasIterator, DataCollection +from .advanced_data_handling import DefaultDataPreparation +from .data_preparation import StationPrep \ No newline at end of file diff --git a/src/data_handling/advanced_data_handling.py b/src/data_handling/advanced_data_handling.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c5214de9fe03f9e6c5de6e2ddcfbfb9987d052 --- /dev/null +++ b/src/data_handling/advanced_data_handling.py @@ -0,0 +1,306 @@ + +__author__ = 'Lukas Leufen' +__date__ = '2020-07-08' + + +from src.helpers import to_list, remove_items +import numpy as np +import xarray as xr +import pickle +import os +import pandas as pd +import datetime as dt +import shutil +import inspect + +from typing import Union, List, Tuple +import logging +from functools import reduce +from src.data_handling.data_preparation import StationPrep + + +number = Union[float, int] +num_or_list = Union[number, List[number]] + + +class DummyDataSingleStation: # pragma: no cover + + def __init__(self, name, number_of_samples=None): + self.name = name + self.number_of_samples = number_of_samples if number_of_samples is not None else np.random.randint(100, 150) + + def get_X(self): + X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5)) # samples, window, variables + datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist() + return xr.DataArray(X1, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist, + "window": range(14), + "variables": range(5)}) + + def get_Y(self): + Y1 = np.round(0.5 * np.random.randn(self.number_of_samples, 5, 1), 1) # samples, window, variables + datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist() + return xr.DataArray(Y1, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist, + "window": range(5), + "variables": range(1)}) + + def __str__(self): + return self.name + + +class AbstractDataPreparation: + + _requirements = [] 
+ + def __init__(self, *args, **kwargs): + pass + + @classmethod + def build(cls, *args, **kwargs): + """Return initialised class.""" + return cls(*args, **kwargs) + + @classmethod + def requirements(cls): + """Return requirements and own arguments without duplicates.""" + return list(set(cls._requirements + cls.own_args())) + + @classmethod + def own_args(cls, *args): + return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args)) + + def get_X(self, upsampling=False, as_numpy=False): + raise NotImplementedError + + def get_Y(self, upsampling=False, as_numpy=False): + raise NotImplementedError + + +class DefaultDataPreparation(AbstractDataPreparation): + + _requirements = remove_items(inspect.getfullargspec(StationPrep).args, ["self", "station"]) + + def __init__(self, id_class, data_path, min_length=0, + extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False): + super().__init__() + self.id_class = id_class + self.interpolate_dim = "datetime" + self.min_length = min_length + self._X = None + self._Y = None + self._X_extreme = None + self._Y_extreme = None + self._save_file = os.path.join(data_path, f"data_preparation_{str(self.id_class)}.pickle") + self._collection = self._create_collection() + self.harmonise_X() + self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolate_dim) + self._store(fresh_store=True) + + @classmethod + def build(cls, station, **kwargs): + sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs} + sp = StationPrep(station, **sp_keys) + dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs} + return cls(sp, **dp_args) + + def _create_collection(self): + return [self.id_class] + + @classmethod + def requirements(cls): + return remove_items(super().requirements(), "id_class") + + def _reset_data(self): + self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None + + def _cleanup(self): + directory = os.path.dirname(self._save_file) + if os.path.exists(directory) is False: + os.makedirs(directory) + if os.path.exists(self._save_file): + shutil.rmtree(self._save_file, ignore_errors=True) + + def _store(self, fresh_store=False): + self._cleanup() if fresh_store is True else None + data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme} + with open(self._save_file, "wb") as f: + pickle.dump(data, f) + logging.debug(f"save pickle data to {self._save_file}") + self._reset_data() + + def _load(self): + try: + with open(self._save_file, "rb") as f: + data = pickle.load(f) + logging.debug(f"load pickle data from {self._save_file}") + self._X, self._Y = data["X"], data["Y"] + self._X_extreme, self._Y_extreme = data["X_extreme"], data["Y_extreme"] + except FileNotFoundError: + pass + + def get_data(self, upsampling=False, as_numpy=True): + self._load() + X = self.get_X(upsampling, as_numpy) + Y = self.get_Y(upsampling, as_numpy) + self._reset_data() + return X, Y + + def __repr__(self): + return ";".join(list(map(lambda x: str(x), self._collection))) + + def get_X_original(self): + X = [] + for data in self._collection: + X.append(data.get_X()) + return X + + def get_Y_original(self): + Y = self._collection[0].get_Y() + return Y + + @staticmethod + def _to_numpy(d): + return list(map(lambda x: np.copy(x), d)) + + def get_X(self, upsampling=False, as_numpy=True): + no_data = (self._X is None) + self._load() if no_data is True else None + X = self._X if upsampling is False else self._X_extreme + self._reset_data() if no_data is 
True else None + return self._to_numpy(X) if as_numpy is True else X + + def get_Y(self, upsampling=False, as_numpy=True): + no_data = (self._Y is None) + self._load() if no_data is True else None + Y = self._Y if upsampling is False else self._Y_extreme + self._reset_data() if no_data is True else None + return self._to_numpy([Y]) if as_numpy is True else Y + + def harmonise_X(self): + X_original, Y_original = self.get_X_original(), self.get_Y_original() + dim = self.interpolate_dim + intersect = reduce(np.intersect1d, map(lambda x: x.coords[dim].values, X_original)) + if len(intersect) < max(self.min_length, 1): + X, Y = None, None + else: + X = list(map(lambda x: x.sel({dim: intersect}), X_original)) + Y = Y_original.sel({dim: intersect}) + self._X, self._Y = X, Y + + def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False, + timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"): + """ + Multiply extremes. + + This method extracts extreme values from self.labels which are defined in the argument extreme_values. One can + also decide only to extract extremes on the right tail of the distribution. When extreme_values is a list of + floats/ints all values larger (and smaller than negative extreme_values; extraction is performed in standardised + space) than are extracted iteratively. If for example extreme_values = [1.,2.] then a value of 1.5 would be + extracted once (for 0th entry in list), while a 2.5 would be extracted twice (once for each entry). Timedelta is + used to mark those extracted values by adding one min to each timestamp. As TOAR Data are hourly one can + identify those "artificial" data points later easily. Extreme inputs and labels are stored in + self.extremes_history and self.extreme_labels, respectively. + + :param extreme_values: user definition of extreme + :param extremes_on_right_tail_only: if False also multiply values which are smaller then -extreme_values, + if True only extract values larger than extreme_values + :param timedelta: used as arguments for np.timedelta in order to mark extreme values on datetime + """ + # check if X or Y is None + if (self._X is None) or (self._Y is None): + logging.debug(f"{str(self.id_class)} has no data for X or Y, skip multiply extremes") + return + if extreme_values is None: + logging.debug(f"No extreme values given, skip multiply extremes") + self._X_extreme, self._Y_extreme = self._X, self._Y + return + + # check type if inputs + extreme_values = to_list(extreme_values) + for i in extreme_values: + if not isinstance(i, number.__args__): + raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element " + f"{i} is type {type(i)}") + + for extr_val in sorted(extreme_values): + # check if some extreme values are already extracted + if (self._X_extreme is None) or (self._Y_extreme is None): + X = self._X + Y = self._Y + else: # one extr value iteration is done already: self.extremes_label is NOT None... 
+ X = self._X_extreme + Y = self._Y_extreme + + # extract extremes based on occurrence in labels + other_dims = remove_items(list(Y.dims), dim) + if extremes_on_right_tail_only: + extreme_idx = (Y > extr_val).any(dim=other_dims) + else: + extreme_idx = xr.concat([(Y < -extr_val).any(dim=other_dims[0]), + (Y > extr_val).any(dim=other_dims[0])], + dim=other_dims[1]).any(dim=other_dims[1]) + + extremes_X = list(map(lambda x: x.sel(**{dim: extreme_idx}), X)) + self._add_timedelta(extremes_X, dim, timedelta) + # extremes_X = list(map(lambda x: x.coords[dim].values + np.timedelta64(*timedelta), extremes_X)) + + extremes_Y = Y.sel(**{dim: extreme_idx}) + extremes_Y.coords[dim].values += np.timedelta64(*timedelta) + + self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim) + self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X)) + + @staticmethod + def _add_timedelta(data, dim, timedelta): + for d in data: + d.coords[dim].values += np.timedelta64(*timedelta) + + +def run_data_prep(): + + data = DummyDataSingleStation("main_class") + data.get_X() + data.get_Y() + + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") + data_prep = DataPreparation(DummyDataSingleStation("main_class"), path, + neighbors=[DummyDataSingleStation("neighbor1"), DummyDataSingleStation("neighbor2")], + extreme_values=[1., 1.2]) + data_prep.get_data(upsampling=False) + + +def create_data_prep(): + + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") + station_type = None + network = 'UBA' + sampling = 'daily' + target_dim = 'variables' + target_var = 'o3' + interpolate_dim = 'datetime' + window_history_size = 7 + window_lead_time = 3 + central_station = StationPrep("DEBW011", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {},station_type, network, sampling, target_dim, + target_var, interpolate_dim, window_history_size, window_lead_time) + neighbor1 = StationPrep("DEBW013", path, {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, {},station_type, network, sampling, target_dim, + target_var, interpolate_dim, window_history_size, window_lead_time) + neighbor2 = StationPrep("DEBW034", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, network, sampling, target_dim, + target_var, interpolate_dim, window_history_size, window_lead_time) + + data_prep = [] + data_prep.append(DataPreparation(central_station, path, neighbors=[neighbor1, neighbor2])) + data_prep.append(DataPreparation(neighbor1, path, neighbors=[central_station, neighbor2])) + data_prep.append(DataPreparation(neighbor2, path, neighbors=[neighbor1, central_station])) + return data_prep + + +if __name__ == "__main__": + from src.data_handling.data_preparation import StationPrep + from src.data_handling.iterator import KerasIterator, DataCollection + data_prep = create_data_prep() + data_collection = DataCollection(data_prep) + for data in data_collection: + print(data) + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata", "keras") + keras_it = KerasIterator(data_collection, 100, path, upsampling=True) + keras_it[2] + diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index bd4958d37510e850f13a5cf6ffef8bf0180a43e4..09c16c68196b09fc7c1fbe5ef4b2639b684205a4 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -68,8 +68,8 @@ class AbstractStationPrep(): class StationPrep(AbstractStationPrep): - def __init__(self, path, station, statistics_per_var, transformation, station_type, network, 
sampling, target_dim, target_var, - interpolate_dim, window_history_size, window_lead_time, **kwargs): + def __init__(self, station, data_path, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var, + interpolate_dim, window_history_size, window_lead_time, overwrite_local_data: bool = False, **kwargs): super().__init__() # path, station, statistics_per_var, transformation, **kwargs) self.station_type = station_type self.network = network @@ -80,12 +80,10 @@ class StationPrep(AbstractStationPrep): self.window_history_size = window_history_size self.window_lead_time = window_lead_time - self.path = os.path.abspath(path) + self.path = os.path.abspath(data_path) self.station = helpers.to_list(station) self.statistics_per_var = statistics_per_var # self.target_dim = 'variable' - self.transformation = self.setup_transformation(transformation) - self.kwargs = kwargs # internal self.data = None @@ -95,8 +93,15 @@ class StationPrep(AbstractStationPrep): self.label = None self.observation = None + self.transformation = None # self.setup_transformation(transformation) + self.kwargs = kwargs + self.kwargs["overwrite_local_data"] = overwrite_local_data + self.make_samples() + def __str__(self): + return self.station[0] + def __repr__(self): return f"StationPrep(path='{self.path}', station={self.station}, statistics_per_var={self.statistics_per_var}, " \ f"transformation={self.transformation}, station_type='{self.station_type}', network='{self.network}', " \ diff --git a/src/data_handling/data_preparation_neighbors.py b/src/data_handling/data_preparation_neighbors.py new file mode 100644 index 0000000000000000000000000000000000000000..f5b5c3c436ef057244544248e6c7deedafbf0c4b --- /dev/null +++ b/src/data_handling/data_preparation_neighbors.py @@ -0,0 +1,67 @@ + +__author__ = 'Lukas Leufen' +__date__ = '2020-07-17' + + +from src.helpers import to_list, remove_items +from src.data_handling.data_preparation import StationPrep +from src.data_handling.advanced_data_handling import AbstractDataPreparation, DefaultDataPreparation +import numpy as np +import xarray as xr +import pickle +import os +import shutil +import inspect + +from typing import Union, List, Tuple +import logging +from functools import reduce + +number = Union[float, int] +num_or_list = Union[number, List[number]] + + +class DataPreparationNeighbors(DefaultDataPreparation): + + def __init__(self, id_class, data_path, neighbors=None, min_length=0, + extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False): + self.neighbors = to_list(neighbors) if neighbors is not None else [] + super().__init__(id_class, data_path, min_length=min_length, extreme_values=extreme_values, + extremes_on_right_tail_only=extremes_on_right_tail_only) + + @classmethod + def build(cls, station, **kwargs): + sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs} + sp = StationPrep(station, **sp_keys) + n_list = [] + for neighbor in kwargs.get("neighbors", []): + n_list.append(StationPrep(neighbor, **sp_keys)) + else: + kwargs["neighbors"] = n_list if len(n_list) > 0 else None + dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs} + return cls(sp, **dp_args) + + def _create_collection(self): + return [self.id_class] + self.neighbors + + +if __name__ == "__main__": + + a = DataPreparationNeighbors + requirements = a.requirements() + + kwargs = {"path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata"), + "station_type": None, + "network": 'UBA', + "sampling": 
'daily', + "target_dim": 'variables', + "target_var": 'o3', + "interpolate_dim": 'datetime', + "window_history_size": 7, + "window_lead_time": 3, + "neighbors": ["DEBW034"], + "data_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata"), + "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}, + "transformation": None,} + a_inst = a.build("DEBW011", **kwargs) + print(a_inst) diff --git a/src/data_handling/iterator.py b/src/data_handling/iterator.py new file mode 100644 index 0000000000000000000000000000000000000000..0eec4326723086c5c36dcbcbb4db37976a2a3b0a --- /dev/null +++ b/src/data_handling/iterator.py @@ -0,0 +1,200 @@ + +__author__ = 'Lukas Leufen' +__date__ = '2020-07-07' + +from collections import Iterator, Iterable +import keras +import numpy as np +import math +import os +import shutil +import pickle +from typing import Tuple, List + + +class StandardIterator(Iterator): + + _position: int = None + + def __init__(self, collection: list): + assert isinstance(collection, list) + self._collection = collection + self._position = 0 + + def __next__(self): + """Return next element or stop iteration.""" + try: + value = self._collection[self._position] + self._position += 1 + except IndexError: + raise StopIteration() + return value + + +class DataCollection(Iterable): + + def __init__(self, collection: list = None): + if collection is None: + collection = [] + assert isinstance(collection, list) + self._collection = collection + + def __len__(self): + return len(self._collection) + + def __iter__(self) -> Iterator: + return StandardIterator(self._collection) + + def __getitem__(self, index): + return self._collection[index] + + def add(self, element): + self._collection.append(element) + + +class KerasIterator(keras.utils.Sequence): + + def __init__(self, collection: DataCollection, batch_size: int, batch_path: str, shuffle_batches: bool = False, + model=None, upsampling=False, name=None): + self._collection = collection + batch_path = os.path.join(batch_path, str(name if name is not None else id(self))) + self._path = os.path.join(batch_path, "%i.pickle") + self.batch_size = batch_size + self.model = model + self.shuffle = shuffle_batches + self.upsampling = upsampling + self.indexes: list = [] + self._cleanup_path(batch_path) + self._prepare_batches() + + def __len__(self) -> int: + return len(self.indexes) + + def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]: + """Get batch for given index.""" + return self.__data_generation(self.indexes[index]) + + def _get_model_rank(self): + if self.model is not None: + mod_out = self.model.output_shape + if isinstance(mod_out, tuple): # only one output branch: (None, ahead) + mod_rank = 1 + elif isinstance(mod_out, list): # multiple output branches, e.g.: [(None, ahead), (None, ahead)] + mod_rank = len(mod_out) + else: # pragma: no cover + raise TypeError("model output shape must either be tuple or list.") + return mod_rank + else: # no model provided, assume to use single output + return 1 + + def __data_generation(self, index: int) -> Tuple[np.ndarray, np.ndarray]: + """Load pickle data from disk.""" + file = self._path % index + with open(file, "rb") as f: + data = pickle.load(f) + return data["X"], data["Y"] + + @staticmethod + def _concatenate(new: List[np.ndarray], old: List[np.ndarray]) -> List[np.ndarray]: + """Concatenate two lists of data along axis=0.""" + return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new)) + + def _get_batch(self, data_list: List[np.ndarray], b: int) -> 
List[np.ndarray]: + """Get batch according to batch size from data list.""" + return list(map(lambda data: data[b * self.batch_size:(b+1) * self.batch_size, ...], data_list)) + + def _permute_data(self, X, Y): + p = np.random.permutation(len(X[0])) # equiv to .shape[0] + X = list(map(lambda x: x[p], X)) + Y = list(map(lambda x: x[p], Y)) + return X, Y + + def _prepare_batches(self) -> None: + """ + Prepare all batches as locally stored files. + + Walk through all elements of collection and split (or merge) data according to the batch size. Too long data + sets are divided into multiple batches. Not fully filled batches are merged with data from the next collection + element. If data is remaining after the last element, it is saved as smaller batch. All batches are enumerated + beginning from 0. A list with all batch numbers is stored in class's parameter indexes. + """ + index = 0 + remaining = None + mod_rank = self._get_model_rank() + for data in self._collection: + X = data.get_X(upsampling=self.upsampling) + Y = [data.get_Y(upsampling=self.upsampling)[0] for _ in range(mod_rank)] + if self.upsampling: + X, Y = self._permute_data(X, Y) + if remaining is not None: + X, Y = self._concatenate(X, remaining[0]), self._concatenate(Y, remaining[1]) + length = X[0].shape[0] + batches = self._get_number_of_mini_batches(length) + for b in range(batches): + batch_X, batch_Y = self._get_batch(X, b), self._get_batch(Y, b) + self._save_to_pickle(X=batch_X, Y=batch_Y, index=index) + index += 1 + if (batches * self.batch_size) < length: # keep remaining to concatenate with next data element + remaining = (self._get_batch(X, batches), self._get_batch(Y, batches)) + else: + remaining = None + if remaining is not None: # add remaining as smaller batch + self._save_to_pickle(X=remaining[0], Y=remaining[1], index=index) + index += 1 + self.indexes = np.arange(0, index).tolist() + + def _save_to_pickle(self, X: List[np.ndarray], Y: List[np.ndarray], index: int) -> None: + """Save data as pickle file with variables X and Y and given index as <index>.pickle .""" + data = {"X": X, "Y": Y} + file = self._path % index + with open(file, "wb") as f: + pickle.dump(data, f) + + def _get_number_of_mini_batches(self, number_of_samples: int) -> int: + """Return number of mini batches as the floored ration of number of samples to batch size.""" + return math.floor(number_of_samples / self.batch_size) + + @staticmethod + def _cleanup_path(path: str, create_new: bool = True) -> None: + """First remove existing path, second create empty path if enabled.""" + if os.path.exists(path): + shutil.rmtree(path) + if create_new is True: + os.makedirs(path) + + def on_epoch_end(self) -> None: + """Randomly shuffle indexes if enabled.""" + if self.shuffle is True: + np.random.shuffle(self.indexes) + + +class DummyData: # pragma: no cover + + def __init__(self, number_of_samples=np.random.randint(100, 150)): + self.number_of_samples = number_of_samples + + def get_X(self): + X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5)) # samples, window, variables + X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2)) # samples, window, variables + X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2)) # samples, window, variables + return [X1, X2, X3] + + def get_Y(self): + Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1)) # samples, window, variables + Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1)) # samples, window, variables + return [Y1, Y2] + + +if __name__ 
== "__main__": + + collection = [] + for _ in range(3): + collection.append(DummyData(50)) + + data_collection = DataCollection(collection=collection) + + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") + iterator = KerasIterator(data_collection, 25, path, shuffle=True) + + for data in data_collection: + print(data) \ No newline at end of file diff --git a/src/model_modules/linear_model.py b/src/model_modules/linear_model.py index e556f0358a2a5e5247f7b6cc7d416af25a8a664d..85c062119d881635bf54aadb7491a3c0298e64f5 100644 --- a/src/model_modules/linear_model.py +++ b/src/model_modules/linear_model.py @@ -42,21 +42,27 @@ class OrdinaryLeastSquaredModel: return self.ordinary_least_squared_model(self.x, self.y) def _set_x_y_from_generator(self): - data_x = None - data_y = None + data_x, data_y = None, None for item in self.generator: - x = self.reshape_xarray_to_numpy(item[0]) - y = item[1].values - data_x = np.concatenate((data_x, x), axis=0) if data_x is not None else x - data_y = np.concatenate((data_y, y), axis=0) if data_y is not None else y - self.x = data_x - self.y = data_y + x, y = item.get_data(as_numpy=True) + x = self.flatten(x) + data_x = self._concatenate(x, data_x) + data_y = self._concatenate(y, data_y) + self.x, self.y = np.concatenate(data_x, axis=1), data_y[0] + + def _concatenate(self, new, old): + return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new)) if old is not None else new def predict(self, data): """Apply OLS model on data.""" data = sm.add_constant(self.reshape_xarray_to_numpy(data), has_constant="add") return np.atleast_2d(self.model.predict(data)) + @staticmethod + def flatten(data): + shapes = list(map(lambda x: x.shape, data)) + return list(map(lambda x, shape: x.reshape(shape[0], -1), data, shapes)) + @staticmethod def reshape_xarray_to_numpy(data): """Reshape xarray data to numpy data and flatten.""" diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index dab2e168c5a9f87d4aee42fc94489fd0fa67772a..ca54840c8b995a4719041e2a8bc9ccd46351a89f 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -133,7 +133,7 @@ class AbstractModelClass(ABC): the corresponding loss function. """ - def __init__(self) -> None: + def __init__(self, shape_inputs, shape_outputs) -> None: """Predefine internal attributes for model and loss.""" self.__model = None self.model_name = self.__class__.__name__ @@ -147,6 +147,8 @@ class AbstractModelClass(ABC): 'target_tensors': None } self.__compile_options = self.__allowed_compile_options + self.shape_inputs = shape_inputs + self.shape_outputs = self.__extract_from_tuple(shape_outputs) def __getattr__(self, name: str) -> Any: """ @@ -267,6 +269,11 @@ class AbstractModelClass(ABC): raise ValueError( f"Got different values or arguments for same argument: self.{allow_k}={new_v_attr.__class__} and '{allow_k}': {new_v_dic.__class__}") + @staticmethod + def __extract_from_tuple(tup): + """Return element of tuple if it contains only a single element.""" + return tup[0] if isinstance(tup, tuple) and len(tup) == 1 else tup + @staticmethod def __compare_keras_optimizers(first, second): if first.__class__ == second.__class__ and first.__module__ == 'keras.optimizers': @@ -334,24 +341,19 @@ class MyLittleModel(AbstractModelClass): Dense layer. """ - def __init__(self, window_history_size, window_lead_time, channels): + def __init__(self, shape_inputs: list, shape_outputs: list): """ Sets model and loss depending on the given arguments. 
- :param activation: activation function - :param window_history_size: number of historical time steps included in the input data - :param channels: number of variables used in input data - :param regularizer: <not used here> - :param dropout_rate: dropout rate used in the model [0, 1) - :param window_lead_time: number of time steps to forecast in the output layer + :param shape_inputs: list of input shapes (expect len=1 with shape=(window_hist, station, variables)) + :param shape_outputs: list of output shapes (expect len=1 with shape=(window_forecast)) """ - super().__init__() + assert len(shape_inputs) == 1 + assert len(shape_outputs) == 1 + super().__init__(shape_inputs[0], shape_outputs[0]) # settings - self.window_history_size = window_history_size - self.window_lead_time = window_lead_time - self.channels = channels self.dropout_rate = 0.1 self.regularizer = keras.regularizers.l2(0.1) self.activation = keras.layers.PReLU @@ -364,17 +366,10 @@ class MyLittleModel(AbstractModelClass): def set_model(self): """ Build the model. - - :param activation: activation function - :param window_history_size: number of historical time steps included in the input data - :param channels: number of variables used in input data - :param dropout_rate: dropout rate used in the model [0, 1) - :param window_lead_time: number of time steps to forecast in the output layer - :return: built keras model """ # add 1 to window_size to include current time step t0 - x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels)) + x_input = keras.layers.Input(shape=self.shape_inputs) x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input) x_in = self.activation(name='{}_conv_act'.format("major"))(x_in) x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in) @@ -385,16 +380,16 @@ class MyLittleModel(AbstractModelClass): x_in = self.activation()(x_in) x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in) x_in = self.activation()(x_in) - x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in) + x_in = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("major"))(x_in) out_main = self.activation()(x_in) self.model = keras.Model(inputs=x_input, outputs=[out_main]) def set_compile_options(self): self.initial_lr = 1e-2 - self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9) + self.optimizer = keras.optimizers.adam(lr=self.initial_lr) self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10) - self.compile_options = {"loss": keras.losses.mean_squared_error, "metrics": ["mse", "mae"]} + self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]} class MyBranchedModel(AbstractModelClass): @@ -406,24 +401,19 @@ class MyBranchedModel(AbstractModelClass): Dense layer. """ - def __init__(self, window_history_size, window_lead_time, channels): + def __init__(self, shape_inputs: list, shape_outputs: list): """ Sets model and loss depending on the given arguments. 
- :param activation: activation function - :param window_history_size: number of historical time steps included in the input data - :param channels: number of variables used in input data - :param regularizer: <not used here> - :param dropout_rate: dropout rate used in the model [0, 1) - :param window_lead_time: number of time steps to forecast in the output layer + :param shape_inputs: list of input shapes (expect len=1 with shape=(window_hist, station, variables)) + :param shape_outputs: list of output shapes (expect len=1 with shape=(window_forecast)) """ - super().__init__() + assert len(shape_inputs) == 1 + assert len(shape_outputs) == 1 + super().__init__(shape_inputs[0], shape_outputs[0]) # settings - self.window_history_size = window_history_size - self.window_lead_time = window_lead_time - self.channels = channels self.dropout_rate = 0.1 self.regularizer = keras.regularizers.l2(0.1) self.activation = keras.layers.PReLU @@ -436,32 +426,25 @@ class MyBranchedModel(AbstractModelClass): def set_model(self): """ Build the model. - - :param activation: activation function - :param window_history_size: number of historical time steps included in the input data - :param channels: number of variables used in input data - :param dropout_rate: dropout rate used in the model [0, 1) - :param window_lead_time: number of time steps to forecast in the output layer - :return: built keras model """ # add 1 to window_size to include current time step t0 - x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels)) + x_input = keras.layers.Input(shape=self.shape_inputs) x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input) x_in = self.activation(name='{}_conv_act'.format("major"))(x_in) x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in) x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in) x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in) x_in = self.activation()(x_in) - out_minor_1 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_1"))(x_in) + out_minor_1 = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("minor_1"))(x_in) out_minor_1 = self.activation(name="minor_1")(out_minor_1) x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in) x_in = self.activation()(x_in) - out_minor_2 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_2"))(x_in) + out_minor_2 = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("minor_2"))(x_in) out_minor_2 = self.activation(name="minor_2")(out_minor_2) x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in) x_in = self.activation()(x_in) - x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in) + x_in = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("major"))(x_in) out_main = self.activation(name="main")(x_in) self.model = keras.Model(inputs=x_input, outputs=[out_minor_1, out_minor_2, out_main]) @@ -476,24 +459,19 @@ class MyBranchedModel(AbstractModelClass): class MyTowerModel(AbstractModelClass): - def __init__(self, window_history_size, window_lead_time, channels): + def __init__(self, shape_inputs: list, shape_outputs: list): """ Sets model and loss depending on the given arguments. 
- :param activation: activation function - :param window_history_size: number of historical time steps included in the input data - :param channels: number of variables used in input data - :param regularizer: <not used here> - :param dropout_rate: dropout rate used in the model [0, 1) - :param window_lead_time: number of time steps to forecast in the output layer + :param shape_inputs: list of input shapes (expect len=1 with shape=(window_hist, station, variables)) + :param shape_outputs: list of output shapes (expect len=1 with shape=(window_forecast)) """ - super().__init__() + assert len(shape_inputs) == 1 + assert len(shape_outputs) == 1 + super().__init__(shape_inputs[0], shape_outputs[0]) # settings - self.window_history_size = window_history_size - self.window_lead_time = window_lead_time - self.channels = channels self.dropout_rate = 1e-2 self.regularizer = keras.regularizers.l2(0.1) self.initial_lr = 1e-2 @@ -509,13 +487,6 @@ class MyTowerModel(AbstractModelClass): def set_model(self): """ Build the model. - - :param activation: activation function - :param window_history_size: number of historical time steps included in the input data - :param channels: number of variables used in input data - :param dropout_rate: dropout rate used in the model [0, 1) - :param window_lead_time: number of time steps to forecast in the output layer - :return: built keras model """ activation = self.activation conv_settings_dict1 = { @@ -549,9 +520,7 @@ class MyTowerModel(AbstractModelClass): ########################################## inception_model = InceptionModelBase() - X_input = keras.layers.Input( - shape=( - self.window_history_size + 1, 1, self.channels)) # add 1 to window_size to include current time step t0 + X_input = keras.layers.Input(shape=self.shape_inputs) X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1, regularizer=self.regularizer, @@ -573,7 +542,7 @@ class MyTowerModel(AbstractModelClass): # out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=True, dropout_rate=self.dropout_rate, # reduction_filter=64, inner_neurons=64, output_neurons=self.window_lead_time) - out_main = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.window_lead_time, + out_main = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.shape_outputs, output_activation='linear', reduction_filter=64, name='Main', bound_weight=True, dropout_rate=self.dropout_rate, kernel_regularizer=self.regularizer @@ -588,24 +557,19 @@ class MyTowerModel(AbstractModelClass): class MyPaperModel(AbstractModelClass): - def __init__(self, window_history_size, window_lead_time, channels): + def __init__(self, shape_inputs: list, shape_outputs: list): """ Sets model and loss depending on the given arguments. 
- :param activation: activation function - :param window_history_size: number of historical time steps included in the input data - :param channels: number of variables used in input data - :param regularizer: <not used here> - :param dropout_rate: dropout rate used in the model [0, 1) - :param window_lead_time: number of time steps to forecast in the output layer + :param shape_inputs: list of input shapes (expect len=1 with shape=(window_hist, station, variables)) + :param shape_outputs: list of output shapes (expect len=1 with shape=(window_forecast)) """ - super().__init__() + assert len(shape_inputs) == 1 + assert len(shape_outputs) == 1 + super().__init__(shape_inputs[0], shape_outputs[0]) # settings - self.window_history_size = window_history_size - self.window_lead_time = window_lead_time - self.channels = channels self.dropout_rate = .3 self.regularizer = keras.regularizers.l2(0.001) self.initial_lr = 1e-3 @@ -670,9 +634,7 @@ class MyPaperModel(AbstractModelClass): ########################################## inception_model = InceptionModelBase() - X_input = keras.layers.Input( - shape=( - self.window_history_size + 1, 1, self.channels)) # add 1 to window_size to include current time step t0 + X_input = keras.layers.Input(shape=self.shape_inputs) pad_size = PadUtils.get_padding_for_same(first_kernel) # X_in = adv_pad.SymmetricPadding2D(padding=pad_size)(X_input) @@ -690,7 +652,7 @@ class MyPaperModel(AbstractModelClass): padding=self.padding) # out_minor1 = flatten_tail(X_in, 'minor_1', False, self.dropout_rate, self.window_lead_time, # self.activation, 32, 64) - out_minor1 = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.window_lead_time, + out_minor1 = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.shape_outputs, output_activation='linear', reduction_filter=32, name='minor_1', bound_weight=False, dropout_rate=self.dropout_rate, kernel_regularizer=self.regularizer @@ -708,7 +670,7 @@ class MyPaperModel(AbstractModelClass): # batch_normalisation=True) ############################################# - out_main = flatten_tail(X_in, inner_neurons=64 * 2, activation=activation, output_neurons=self.window_lead_time, + out_main = flatten_tail(X_in, inner_neurons=64 * 2, activation=activation, output_neurons=self.shape_outputs, output_activation='linear', reduction_filter=64 * 2, name='Main', bound_weight=False, dropout_rate=self.dropout_rate, kernel_regularizer=self.regularizer diff --git a/src/run.py b/src/run.py index 7e262dd769204077697b7df3f3fbaedb4c012257..8a4ade33c0e5b260fafab58e76cf753455077d50 100644 --- a/src/run.py +++ b/src/run.py @@ -29,7 +29,7 @@ def run(stations=None, model=None, batch_size=None, epochs=None, - data_preparation=None): + data_preparation=None,): params = inspect.getfullargspec(DefaultWorkflow).args kwargs = {k: v for k, v in locals().items() if k in params and v is not None} @@ -39,5 +39,4 @@ def run(stations=None, if __name__ == "__main__": - - run() + run(stations=["DEBW013","DEBW025"], statistics_per_var={'o3': 'dma8eu', "temp": "maximum"}, trainable=True, create_new_model=True) diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 1d375c32be06b583abbfb06a20ea482e6775b232..15b5c4c6e9d01284d108284365546f1eac9804c1 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -18,7 +18,7 @@ from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, D DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, 
DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \ DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST -from src.data_handling import DataPrepJoin +from src.data_handling.advanced_data_handling import DefaultDataPreparation from src.run_modules.run_environment import RunEnvironment from src.model_modules.model_class import MyLittleModel as VanillaModel @@ -228,8 +228,8 @@ class ExperimentSetup(RunEnvironment): create_new_model = None, bootstrap_path=None, permute_data_on_training = None, transformation=None, train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None, extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None, - create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None, - batch_size=None, epochs=None, data_preparation=None): + create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None, + hpc_hosts=None, model=None, batch_size=None, epochs=None, data_preparation=None): # create run framework super().__init__() @@ -265,6 +265,9 @@ class ExperimentSetup(RunEnvironment): logging.info(f"Experiment path is: {experiment_path}") path_config.check_path_and_create(self.data_store.get("experiment_path")) + # batch path (temporary) + self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data")) + # set model path self._set_param("model_path", None, os.path.join(experiment_path, "model")) path_config.check_path_and_create(self.data_store.get("model_path")) @@ -297,7 +300,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("sampling", sampling) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) self._set_param("transformation", None, scope="preprocessing") - self._set_param("data_preparation", data_preparation, default=DataPrepJoin) + self._set_param("data_preparation", data_preparation, default=DefaultDataPreparation) # target self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR) @@ -344,6 +347,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("number_of_bootstraps", number_of_bootstraps, default=DEFAULT_NUMBER_OF_BOOTSTRAPS, scope="general.postprocessing") self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") + self._set_param("neighbors", ["DEBW030"]) # TODO: just for testing # check variables, statistics and target variable self._check_target_var() diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index f9683b953d85bacf6e452e0a1922e85dfe946cd1..7de1c7b6adfccf705539dd32765af149da5f6508 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -70,7 +70,7 @@ class ModelSetup(RunEnvironment): def _run(self): # set channels depending on inputs - self._set_channels() + self._set_shapes() # build model graph using settings from my_model_settings() self.build_model() @@ -88,10 +88,12 @@ class ModelSetup(RunEnvironment): # compile model self.compile_model() - def _set_channels(self): - """Set channels as number of variables of train generator.""" - channels = self.data_store.get("generator", "train")[0][0].shape[-1] - self.data_store.set("channels", channels, self.scope) + def _set_shapes(self): + """Set input and output shapes from train collection.""" + shape = list(map(lambda x: x.shape[1:], 
self.data_store.get("data_collection", "train")[0].get_X())) + self.data_store.set("shape_inputs", shape, self.scope) + shape = list(map(lambda y: y.shape[1:], self.data_store.get("data_collection", "train")[0].get_Y())) + self.data_store.set("shape_outputs", shape, self.scope) def compile_model(self): """ @@ -128,8 +130,8 @@ class ModelSetup(RunEnvironment): logging.info('no weights to reload...') def build_model(self): - """Build model using window_history_size, window_lead_time and channels from data store.""" - args_list = ["window_history_size", "window_lead_time", "channels"] + """Build model using input and output shapes from data store.""" + args_list = ["shape_inputs", "shape_outputs"] args = self.data_store.create_args_dict(args_list, self.scope) model = self.data_store.get("model_class") self.model = model(**args) diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index b97d28c1cf71d35526207450d6b0bb386ddefdb7..2512244c8f9516becfb0edec48a3c9f82e5643de 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -13,7 +13,7 @@ import numpy as np import pandas as pd import xarray as xr -from src.data_handling import BootStraps, Distributor, DataGenerator, DataPrepJoin +from src.data_handling import BootStraps, Distributor, DataGenerator, DataPrepJoin, KerasIterator from src.helpers.datastore import NameNotFoundInDataStore from src.helpers import TimeTracking, statistics from src.model_modules.linear_model import OrdinaryLeastSquaredModel @@ -65,11 +65,12 @@ class PostProcessing(RunEnvironment): self.model: keras.Model = self._load_model() self.ols_model = None self.batch_size: int = self.data_store.get_default("batch_size", "model", 64) - self.test_data: DataGenerator = self.data_store.get("generator", "test") - self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size) - self.train_data: DataGenerator = self.data_store.get("generator", "train") - self.val_data: DataGenerator = self.data_store.get("generator", "val") - self.train_val_data: DataGenerator = self.data_store.get("generator", "train_val") + self.test_data = self.data_store.get("data_collection", "test") + batch_path = self.data_store.get("batch_path", scope="test") + self.test_data_distributed = KerasIterator(self.test_data, self.batch_size, model=self.model, name="test", batch_path=batch_path) + self.train_data = self.data_store.get("data_collection", "train") + self.val_data = self.data_store.get("data_collection", "val") + self.train_val_data = self.data_store.get("data_collection", "train_val") self.plot_path: str = self.data_store.get("plot_path") self.target_var = self.data_store.get("target_var") self._sampling = self.data_store.get("sampling") @@ -311,17 +312,17 @@ class PostProcessing(RunEnvironment): be found inside `forecast_path`. 
""" logging.debug("start make_prediction") - for i, _ in enumerate(self.test_data): - data = self.test_data.get_data_generator(i) - input_data = data.get_transposed_history() + for i, data in enumerate(self.test_data): + input_data = data.get_X() + target_data = data.get_Y() # get scaling parameters - mean, std, transformation_method = data.get_transformation_information(variable=self.target_var) + # mean, std, transformation_method = data.get_transformation_information(variable=self.target_var) for normalised in [True, False]: # create empty arrays nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays( - data, count=4) + target_data, count=4) # nn forecast nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, @@ -459,8 +460,8 @@ class PostProcessing(RunEnvironment): return nn_prediction @staticmethod - def _create_empty_prediction_arrays(generator, count=1): - return [generator.label.copy() for _ in range(count)] + def _create_empty_prediction_arrays(target_data, count=1): + return [target_data.copy() for _ in range(count)] @staticmethod def create_fullindex(df: Union[xr.DataArray, pd.DataFrame, pd.DatetimeIndex], freq: str) -> pd.DataFrame: diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index db7fff2ab9e385ce769f86ef95d1565ea783cc95..c6ea67b87fc33a0952a5123754ab3fea62eee488 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -11,6 +11,7 @@ import numpy as np import pandas as pd from src.data_handling import DataGenerator +from src.data_handling import DataCollection from src.helpers import TimeTracking from src.configuration import path_config from src.helpers.join import EmptyQueryResult @@ -59,10 +60,9 @@ class PreProcessing(RunEnvironment): self._run() def _run(self): - args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="preprocessing") - kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="preprocessing") stations = self.data_store.get("stations") - valid_stations = self.check_valid_stations(args, kwargs, stations, load_tmp=False, save_tmp=False, name="all") + data_preparation = self.data_store.get("data_preparation") + _, valid_stations = self.validate_station(data_preparation, stations, "preprocessing", overwrite_local_data=True) self.data_store.set("stations", valid_stations) self.split_train_val_test() self.report_pre_processing() @@ -70,16 +70,14 @@ class PreProcessing(RunEnvironment): def report_pre_processing(self): """Log some metrics on data and create latex report.""" logging.debug(20 * '##') - n_train = len(self.data_store.get('generator', 'train')) - n_val = len(self.data_store.get('generator', 'val')) - n_test = len(self.data_store.get('generator', 'test')) + n_train = len(self.data_store.get('data_collection', 'train')) + n_val = len(self.data_store.get('data_collection', 'val')) + n_test = len(self.data_store.get('data_collection', 'test')) n_total = n_train + n_val + n_test logging.debug(f"Number of all stations: {n_total}") logging.debug(f"Number of training stations: {n_train}") logging.debug(f"Number of val stations: {n_val}") logging.debug(f"Number of test stations: {n_test}") - logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'test')[0][0].shape}" - f"{self.data_store.get('generator', 'test')[0][1].shape}") self.create_latex_report() def create_latex_report(self): @@ -121,11 +119,12 @@ class PreProcessing(RunEnvironment): set_names 
= ["train", "val", "test"] df = pd.DataFrame(columns=meta_data + set_names) for set_name in set_names: - data: DataGenerator = self.data_store.get("generator", set_name) - for station in data.stations: - df.loc[station, set_name] = data.get_data_generator(station).get_transposed_label().shape[0] - if df.loc[station, meta_data].isnull().any(): - df.loc[station, meta_data] = data.get_data_generator(station).meta.loc[meta_data].values.flatten() + data = self.data_store.get("data_collection", set_name) + for station in data: + station_name = str(station.id_class) + df.loc[station_name, set_name] = station.get_Y()[0].shape[0] + if df.loc[station_name, meta_data].isnull().any(): + df.loc[station_name, meta_data] = station.id_class.meta.loc[meta_data].values.flatten() df.loc["# Samples", set_name] = df.loc[:, set_name].sum() df.loc["# Stations", set_name] = df.loc[:, set_name].count() df[meta_round] = df[meta_round].astype(float).round(precision) @@ -147,7 +146,7 @@ class PreProcessing(RunEnvironment): Split data into subsets. Currently: train, val, test and train_val (actually this is only the merge of train and val, but as an separate - generator). IMPORTANT: Do not change to order of the execution of create_set_split. The train subset needs + data_collection). IMPORTANT: Do not change to order of the execution of create_set_split. The train subset needs always to be executed at first, to set a proper transformation. """ fraction_of_training = self.data_store.get("fraction_of_training") @@ -159,7 +158,7 @@ class PreProcessing(RunEnvironment): raise AssertionError(f"Make sure, that the train subset is always at first execution position! Given subset" f"order was: {subset_names}.") for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names): - self.create_set_split(ind, scope) + self.create_set_split_new(ind, scope) @staticmethod def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]: @@ -183,37 +182,19 @@ class PreProcessing(RunEnvironment): train_val_index = slice(0, pos_test_split) return train_index, val_index, test_index, train_val_index - def create_set_split(self, index_list: slice, set_name: str) -> None: - """ - Create subsets and store in data store. - - Create the subset for given split index and stores the DataGenerator with given set name in data store as - `generator`. Check for all valid stations using the default (kw)args for given scope and create the - DataGenerator for all valid stations. Also set all transformation information, if subset is training set. Make - sure, that the train set is executed first, and all other subsets afterwards. - - :param index_list: list of all stations to use for the set. If attribute use_all_stations_on_all_data_sets=True, - this list is ignored. - :param set_name: name to load/save all information from/to data store. 
- """ - args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name) - kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope=set_name) - stations = args["stations"] + def create_set_split_new(self, index_list: slice, set_name: str) -> None: + # get set stations + stations = self.data_store.get("stations", scope=set_name) if self.data_store.get("use_all_stations_on_all_data_sets"): set_stations = stations else: set_stations = stations[index_list] logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") - # validate set - set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False, name=set_name) - self.data_store.set("stations", set_stations, scope=set_name) - # create set generator and store - set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name) - data_set = DataGenerator(**set_args, **kwargs) - self.data_store.set("generator", data_set, scope=set_name) - # extract transformation from train set - if set_name == "train": - self.data_store.set("transformation", data_set.transformation) + # create set data_collection and store + data_preparation = self.data_store.get("data_preparation") + collection, valid_stations = self.validate_station(data_preparation, set_stations, set_name) + self.data_store.set("stations", valid_stations, scope=set_name) + self.data_store.set("data_collection", collection, scope=set_name) @staticmethod def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True, @@ -257,3 +238,36 @@ class PreProcessing(RunEnvironment): logging.info(f"run for {t_outer} to check {len(all_stations)} station(s). Found {len(valid_stations)}/" f"{len(all_stations)} valid stations.") return valid_stations + + def validate_station(self, data_preparation, set_stations, set_name=None, overwrite_local_data=False): + """ + Check if all given stations in `all_stations` are valid. + + Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the + loading time are logged in debug mode. + + :param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`, + `variables`, `interpolate_dim`, `target_dim`, `target_var`). + :param kwargs: positional parameters for the DataGenerator class (e.g. `start`, `interpolate_method`, + `window_lead_time`). + :param all_stations: All stations to check. + :param name: name to display in the logging info message + + :return: Corrected list containing only valid station IDs. + """ + t_outer = TimeTracking() + logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else 'all'}") + collection = DataCollection() + valid_stations = [] + kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope=set_name) + for station in set_stations: + try: + dp = data_preparation.build(station, **kwargs) + collection.add(dp) + valid_stations.append(station) + except (AttributeError, EmptyQueryResult): + continue + logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). 
Found {len(collection)}/" + f"{len(set_stations)} valid stations.") + return collection, valid_stations + diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 1a0d7beb1ec37bb5e59a4129da58572d79a73636..4ca0063cc4c6446e80db91626fe535e613cd7c82 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -11,7 +11,7 @@ from typing import Union import keras from keras.callbacks import Callback, History -from src.data_handling import Distributor +from src.data_handling import KerasIterator from src.model_modules.keras_extensions import CallbackHandler from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate from src.run_modules.run_environment import RunEnvironment @@ -64,9 +64,9 @@ class Training(RunEnvironment): """Set up and run training.""" super().__init__() self.model: keras.Model = self.data_store.get("model", "model") - self.train_set: Union[Distributor, None] = None - self.val_set: Union[Distributor, None] = None - self.test_set: Union[Distributor, None] = None + self.train_set: Union[KerasIterator, None] = None + self.val_set: Union[KerasIterator, None] = None + self.test_set: Union[KerasIterator, None] = None self.batch_size = self.data_store.get("batch_size") self.epochs = self.data_store.get("epochs") self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model") @@ -102,9 +102,9 @@ class Training(RunEnvironment): :param mode: name of set, should be from ["train", "val", "test"] """ - gen = self.data_store.get("generator", mode) - kwargs = self.data_store.create_args_dict(["permute_data", "upsampling"], scope=mode) - setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, **kwargs)) + collection = self.data_store.get("data_collection", mode) + kwargs = self.data_store.create_args_dict(["upsampling", "shuffle_batches", "batch_path"], scope=mode) + setattr(self, f"{mode}_set", KerasIterator(collection, self.batch_size, model=self.model, name=mode, **kwargs)) def set_generators(self) -> None: """ @@ -128,15 +128,15 @@ class Training(RunEnvironment): """ logging.info(f"Train with {len(self.train_set)} mini batches.") logging.info(f"Train with option upsampling={self.train_set.upsampling}.") - logging.info(f"Train with option data_permutation={self.train_set.do_data_permutation}.") + logging.info(f"Train with option shuffle={self.train_set.shuffle}.") checkpoint = self.callbacks.get_checkpoint() if not os.path.exists(checkpoint.filepath) or self._create_new_model: - history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(), + history = self.model.fit_generator(generator=self.train_set, steps_per_epoch=len(self.train_set), epochs=self.epochs, verbose=2, - validation_data=self.val_set.distribute_on_batches(), + validation_data=self.val_set, validation_steps=len(self.val_set), callbacks=self.callbacks.get_callbacks(as_dict=False)) else: @@ -146,11 +146,11 @@ class Training(RunEnvironment): self.model = keras.models.load_model(checkpoint.filepath) hist: History = self.callbacks.get_callback_by_name("hist") initial_epoch = max(hist.epoch) + 1 - _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(), + _ = self.model.fit_generator(generator=self.train_set, steps_per_epoch=len(self.train_set), epochs=self.epochs, verbose=2, - validation_data=self.val_set.distribute_on_batches(), + validation_data=self.val_set, validation_steps=len(self.val_set), callbacks=self.callbacks.get_callbacks(as_dict=False), initial_epoch=initial_epoch) diff --git 
diff --git a/src/workflows/abstract_workflow.py b/src/workflows/abstract_workflow.py
index 5d4e62c8a2e409e865f43412a6757a9cb4e4b1f3..350008eace4598567779228b1302a83c7375fd06 100644
--- a/src/workflows/abstract_workflow.py
+++ b/src/workflows/abstract_workflow.py
@@ -26,4 +26,4 @@ class Workflow:
         """Run workflow embedded in a run environment and according to the stage's ordering."""
         with RunEnvironment():
             for stage, kwargs in self._registry.items():
-                stage(**kwargs)
\ No newline at end of file
+                stage(**kwargs)
diff --git a/test/test_data_handling/test_iterator.py b/test/test_data_handling/test_iterator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4baa3ee45203b2f0b533b38c8f2024419274ed2
--- /dev/null
+++ b/test/test_data_handling/test_iterator.py
@@ -0,0 +1,228 @@
+
+from src.data_handling.iterator import DataCollection, StandardIterator, KerasIterator
+from src.helpers.testing import PyTestAllEqual
+from src.model_modules.model_class import MyLittleModel, MyBranchedModel
+
+import numpy as np
+import pytest
+import mock
+import os
+import shutil
+
+
+class TestStandardIterator:
+
+    @pytest.fixture
+    def collection(self):
+        return list(range(10))
+
+    def test_blank(self):
+        std_iterator = object.__new__(StandardIterator)
+        assert std_iterator._position is None
+
+    def test_init(self, collection):
+        std_iterator = StandardIterator(collection)
+        assert std_iterator._collection == list(range(10))
+        assert std_iterator._position == 0
+
+    def test_next(self, collection):
+        std_iterator = StandardIterator(collection)
+        for i in range(10):
+            assert i == next(std_iterator)
+        with pytest.raises(StopIteration):
+            next(std_iterator)
+        std_iterator = StandardIterator(collection)
+        for e, i in enumerate(iter(std_iterator)):
+            assert i == e
+
+
+class TestDataCollection:
+
+    @pytest.fixture
+    def collection(self):
+        return list(range(10))
+
+    def test_init(self, collection):
+        data_collection = DataCollection(collection)
+        assert data_collection._collection == collection
+
+    def test_iter(self, collection):
+        data_collection = DataCollection(collection)
+        assert isinstance(iter(data_collection), StandardIterator)
+        for e, i in enumerate(data_collection):
+            assert i == e
+
+
+class DummyData:
+
+    def __init__(self, number_of_samples=np.random.randint(100, 150)):
+        self.number_of_samples = number_of_samples
+
+    def get_X(self, upsampling=False, as_numpy=True):
+        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))  # samples, window, variables
+        X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2))  # samples, window, variables
+        X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2))  # samples, window, variables
+        return [X1, X2, X3]
+
+    def get_Y(self, upsampling=False, as_numpy=True):
+        Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1))  # samples, window, variables
+        Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1))  # samples, window, variables
+        return [Y1, Y2]
+
+
+class TestKerasIterator:
+
+    @pytest.fixture
+    def collection(self):
+        coll = []
+        for i in range(3):
+            coll.append(DummyData(50 + i))
+        data_coll = DataCollection(collection=coll)
+        return data_coll
+
+    @pytest.fixture
+    def path(self):
+        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+        shutil.rmtree(p, ignore_errors=True) if os.path.exists(p) else None
+        yield p
+        shutil.rmtree(p, ignore_errors=True)
+
+    def test_init(self, collection, path):
+        iterator = KerasIterator(collection, 25, path)
+        assert isinstance(iterator._collection, DataCollection)
+        assert iterator._path == os.path.join(path, str(id(iterator)), "%i.pickle")
+        assert iterator.batch_size == 25
+        assert iterator.shuffle is False
+
+    def test_cleanup_path(self, path):
+        assert os.path.exists(path) is False
+        iterator = object.__new__(KerasIterator)
+        iterator._cleanup_path(path, create_new=False)
+        assert os.path.exists(path) is False
+        iterator._cleanup_path(path)
+        assert os.path.exists(path) is True
+        iterator._cleanup_path(path, create_new=False)
+        assert os.path.exists(path) is False
+
+    def test_get_number_of_mini_batches(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.batch_size = 36
+        assert iterator._get_number_of_mini_batches(30) == 0
+        assert iterator._get_number_of_mini_batches(40) == 1
+        assert iterator._get_number_of_mini_batches(72) == 2
+
+    def test_len(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.indexes = [0, 1, 2, 3, 4, 5]
+        assert len(iterator) == 6
+
+    def test_concatenate(self):
+        arr1 = DummyData(10).get_X()
+        arr2 = DummyData(50).get_X()
+        iterator = object.__new__(KerasIterator)
+        new_arr = iterator._concatenate(arr2, arr1)
+        test_arr = [np.concatenate((arr1[0], arr2[0]), axis=0),
+                    np.concatenate((arr1[1], arr2[1]), axis=0),
+                    np.concatenate((arr1[2], arr2[2]), axis=0)]
+        for i in range(3):
+            assert PyTestAllEqual([new_arr[i], test_arr[i]])
+
+    def test_get_batch(self):
+        arr = DummyData(20).get_X()
+        iterator = object.__new__(KerasIterator)
+        iterator.batch_size = 19
+        batch1 = iterator._get_batch(arr, 0)
+        assert batch1[0].shape[0] == 19
+        batch2 = iterator._get_batch(arr, 1)
+        assert batch2[0].shape[0] == 1
+
+    def test_save_to_pickle(self, path):
+        os.makedirs(path)
+        d = DummyData(20)
+        X, Y = d.get_X(), d.get_Y()
+        iterator = object.__new__(KerasIterator)
+        iterator._path = os.path.join(path, "%i.pickle")
+        assert os.path.exists(iterator._path % 2) is False
+        iterator._save_to_pickle(X=X, Y=Y, index=2)
+        assert os.path.exists(iterator._path % 2) is True
+
+    def test_prepare_batches(self, collection, path):
+        iterator = object.__new__(KerasIterator)
+        iterator._collection = collection
+        iterator.batch_size = 50
+        iterator.indexes = []
+        iterator.model = None
+        iterator.upsampling = False
+        iterator._path = os.path.join(path, "%i.pickle")
+        os.makedirs(path)
+        iterator._prepare_batches()
+        assert len(os.listdir(path)) == 4
+        assert len(iterator.indexes) == 4
+        assert len(iterator) == 4
+        assert iterator.indexes == [0, 1, 2, 3]
+
+    def test_prepare_batches_no_remaining(self, path):
+        iterator = object.__new__(KerasIterator)
+        iterator._collection = DataCollection([DummyData(50)])
+        iterator.batch_size = 50
+        iterator.indexes = []
+        iterator.model = None
+        iterator.upsampling = False
+        iterator._path = os.path.join(path, "%i.pickle")
+        os.makedirs(path)
+        iterator._prepare_batches()
+        assert len(os.listdir(path)) == 1
+        assert len(iterator.indexes) == 1
+        assert len(iterator) == 1
+        assert iterator.indexes == [0]
+
+    def test_data_generation(self, collection, path):
+        iterator = KerasIterator(collection, 50, path)
+        X, Y = iterator._KerasIterator__data_generation(0)
+        expected = next(iter(collection))
+        assert PyTestAllEqual([X, expected.get_X()])
+        assert PyTestAllEqual([Y, expected.get_Y()])
+
+    def test_getitem(self, collection, path):
+        iterator = KerasIterator(collection, 50, path)
+        X, Y = iterator[0]
+        expected = next(iter(collection))
+        assert PyTestAllEqual([X, expected.get_X()])
+        assert PyTestAllEqual([Y, expected.get_Y()])
+        reversed(iterator.indexes)
+        X, Y = iterator[3]
+        assert PyTestAllEqual([X, expected.get_X()])
+        assert PyTestAllEqual([Y, expected.get_Y()])
+
+    def test_on_epoch_end(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.indexes = [0, 1, 2, 3, 4]
+        iterator.shuffle = False
+        iterator.on_epoch_end()
+        assert iterator.indexes == [0, 1, 2, 3, 4]
+        iterator.shuffle = True
+        while iterator.indexes == sorted(iterator.indexes):
+            iterator.on_epoch_end()
+        assert iterator.indexes != [0, 1, 2, 3, 4]
+        assert sorted(iterator.indexes) == [0, 1, 2, 3, 4]
+
+    def test_get_model_rank_no_model(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.model = None
+        assert iterator._get_model_rank() == 1
+
+    def test_get_model_rank_single_output_branch(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.model = MyLittleModel(shape_inputs=[(14, 1, 2)], shape_outputs=[(3,)])
+        assert iterator._get_model_rank() == 1
+
+    def test_get_model_rank_multiple_output_branch(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.model = MyBranchedModel(shape_inputs=[(14, 1, 2)], shape_outputs=[(3,)])
+        assert iterator._get_model_rank() == 3
+
+    def test_get_model_rank_error(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.model = mock.MagicMock(return_value=1)
+        with pytest.raises(TypeError):
+            iterator._get_model_rank()
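The KerasIterator tests above pin down the batch bookkeeping. A small worked example of that arithmetic, only inferred from the expected values in the tests (the helper name number_of_full_batches is made up):

# sketch of the batch arithmetic encoded in test_get_number_of_mini_batches and test_prepare_batches
import math

def number_of_full_batches(n_samples: int, batch_size: int) -> int:
    # 30 -> 0, 40 -> 1, 72 -> 2 for batch_size 36, which matches plain floor division
    return n_samples // batch_size

assert number_of_full_batches(30, 36) == 0
assert number_of_full_batches(40, 36) == 1
assert number_of_full_batches(72, 36) == 2

# test_prepare_batches: three stations with 50, 51 and 52 samples and batch_size 50 yield
# 4 pickled batches, i.e. three full batches plus one remainder batch of 3 samples
samples, batch_size = [50, 51, 52], 50
assert math.ceil(sum(samples) / batch_size) == 4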
diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py
index 28a8bf6e421d62d58d76e7a32906f8a594f16ed7..49051e1017826dbe8b61053799f87d69595f441d 100644
--- a/test/test_helpers/test_helpers.py
+++ b/test/test_helpers/test_helpers.py
@@ -237,7 +237,7 @@ class TestLogger:
     def test_setup_logging_path_none(self):
         log_file = Logger.setup_logging_path(None)
         assert PyTestRegex(
-            ".*machinelearningtools/logging/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_file
+            ".*mlair/logging/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_file
 
     @mock.patch("os.makedirs", side_effect=None)
     def test_setup_logging_path_given(self, mock_makedirs):
diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py
index 0ee2eb7e5d439c76888f1f05e238bb5507db6a7a..78adb92f025b9e757733e4d6e014f253799447e2 100644
--- a/test/test_model_modules/test_model_class.py
+++ b/test/test_model_modules/test_model_class.py
@@ -12,7 +12,7 @@ class Paddings:
 class AbstractModelSubClass(AbstractModelClass):
 
     def __init__(self):
-        super().__init__()
+        super().__init__(shape_inputs=(12, 1, 1), shape_outputs=3)
         self.test_attr = "testAttr"
 
 
@@ -20,7 +20,7 @@ class TestAbstractModelClass:
 
     @pytest.fixture
    def amc(self):
-        return AbstractModelClass()
+        return AbstractModelClass(shape_inputs=(14, 1, 2), shape_outputs=(3,))
 
     @pytest.fixture
     def amsc(self):
@@ -31,6 +31,8 @@ class TestAbstractModelClass:
         # assert amc.loss is None
         assert amc.model_name == "AbstractModelClass"
         assert amc.custom_objects == {}
+        assert amc.shape_inputs == (14, 1, 2)
+        assert amc.shape_outputs == 3
 
     def test_model_property(self, amc):
         amc.model = keras.Model()
@@ -179,8 +181,10 @@ class TestAbstractModelClass:
         assert amc.compile == amc.model.compile
 
     def test_get_settings(self, amc, amsc):
-        assert amc.get_settings() == {"model_name": "AbstractModelClass"}
-        assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass"}
+        assert amc.get_settings() == {"model_name": "AbstractModelClass", "shape_inputs": (14, 1, 2),
+                                      "shape_outputs": 3}
+        assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass",
+                                       "shape_inputs": (12, 1, 2), "shape_outputs": 3}
 
     def test_custom_objects(self, amc):
         amc.custom_objects = {"Test": 123}
@@ -200,7 +204,7 @@ class TestMyPaperModel:
 
     @pytest.fixture
     def mpm(self):
-        return MyPaperModel(window_history_size=6, window_lead_time=4, channels=9)
+        return MyPaperModel(shape_inputs=[(7, 1, 9)], shape_outputs=[(4,)])
 
     def test_init(self, mpm):
         # check if loss number of loss functions fit to model outputs
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
index 6de61b2dbe88e24eb3caccf6de575d6340129b91..5150eadee55906002b3ac453b855999c2a9336a2 100644
--- a/test/test_modules/test_model_setup.py
+++ b/test/test_modules/test_model_setup.py
@@ -1,9 +1,11 @@
 import os
+import numpy as np
+import shutil
 
 import pytest
 
-from src.data_handling import DataPrepJoin
-from src.data_handling.data_generator import DataGenerator
+from src.data_handling import KerasIterator
+from src.data_handling import DataCollection
 from src.helpers.datastore import EmptyScope
 from src.model_modules.keras_extensions import CallbackHandler
 from src.model_modules.model_class import AbstractModelClass, MyLittleModel
@@ -29,29 +31,40 @@ class TestModelSetup:
         RunEnvironment().__del__()
 
     @pytest.fixture
-    def gen(self):
-        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'],
-                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
-                             data_preparation=DataPrepJoin)
+    def path(self):
+        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+        shutil.rmtree(p, ignore_errors=True) if os.path.exists(p) else None
+        yield p
+        shutil.rmtree(p, ignore_errors=True)
 
     @pytest.fixture
-    def setup_with_gen(self, setup, gen):
-        setup.data_store.set("generator", gen, "general.train")
-        setup.data_store.set("window_history_size", gen.window_history_size, "general")
-        setup.data_store.set("window_lead_time", gen.window_lead_time, "general")
-        setup.data_store.set("channels", 2, "general")
+    def keras_iterator(self, path):
+        coll = []
+        for i in range(3):
+            coll.append(DummyData(50 + i))
+        data_coll = DataCollection(collection=coll)
+        KerasIterator(data_coll, 25, path)
+        return data_coll
+
+    @pytest.fixture
+    def setup_with_gen(self, setup, keras_iterator):
+        setup.data_store.set("data_collection", keras_iterator, "train")
+        shape_inputs = [keras_iterator[0].get_X()[0].shape[1:]]
+        setup.data_store.set("shape_inputs", shape_inputs, "model")
+        shape_outputs = [keras_iterator[0].get_Y()[0].shape[1:]]
+        setup.data_store.set("shape_outputs", shape_outputs, "model")
         yield setup
         RunEnvironment().__del__()
 
     @pytest.fixture
-    def setup_with_gen_tiny(self, setup, gen):
-        setup.data_store.set("generator", gen, "general.train")
+    def setup_with_gen_tiny(self, setup, keras_iterator):
+        setup.data_store.set("data_collection", keras_iterator, "train")
         yield setup
         RunEnvironment().__del__()
 
     @pytest.fixture
     def setup_with_model(self, setup):
-        setup.model = AbstractModelClass()
+        setup.model = AbstractModelClass(shape_inputs=(12, 1), shape_outputs=2)
         setup.model.test_param = "42"
         yield setup
         RunEnvironment().__del__()
@@ -89,14 +102,17 @@ class TestModelSetup:
         assert setup_with_gen.model is None
         setup_with_gen.build_model()
         assert isinstance(setup_with_gen.model, AbstractModelClass)
-        expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr",
-                    "optimizer", "activation"}
+        expected = {"lr_decay", "model_name", "dropout_rate", "regularizer", "initial_lr", "optimizer", "activation",
+                    "shape_inputs", "shape_outputs"}
         assert expected <= self.current_scope_as_set(setup_with_gen)
 
-    def test_set_channels(self, setup_with_gen_tiny):
-        assert len(setup_with_gen_tiny.data_store.search_name("channels")) == 0
-        setup_with_gen_tiny._set_channels()
-        assert setup_with_gen_tiny.data_store.get("channels", setup_with_gen_tiny.scope) == 2
+    def test_set_shapes(self, setup_with_gen_tiny):
+        assert len(setup_with_gen_tiny.data_store.search_name("shape_inputs")) == 0
+        assert len(setup_with_gen_tiny.data_store.search_name("shape_outputs")) == 0
+        setup_with_gen_tiny._set_shapes()
+        assert setup_with_gen_tiny.data_store.get("shape_inputs", setup_with_gen_tiny.scope) == [(14, 1, 5), (10, 1, 2),
+                                                                                                 (1, 1, 2)]
+        assert setup_with_gen_tiny.data_store.get("shape_outputs", setup_with_gen_tiny.scope) == [(5,), (3,)]
 
     def test_load_weights(self):
         pass
@@ -109,3 +125,20 @@ class TestModelSetup:
 
     def test_init(self):
         pass
+
+
+class DummyData:
+
+    def __init__(self, number_of_samples=np.random.randint(100, 150)):
+        self.number_of_samples = number_of_samples
+
+    def get_X(self, upsampling=False, as_numpy=True):
+        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 1, 5))  # samples, window, variables
+        X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 1, 2))  # samples, window, variables
+        X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 1, 2))  # samples, window, variables
+        return [X1, X2, X3]
+
+    def get_Y(self, upsampling=False, as_numpy=True):
+        Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5))  # samples, window
+        Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 3))  # samples, window
+        return [Y1, Y2]
\ No newline at end of file
"shape_inputs", "shape_outputs"} assert expected <= self.current_scope_as_set(setup_with_gen) - def test_set_channels(self, setup_with_gen_tiny): - assert len(setup_with_gen_tiny.data_store.search_name("channels")) == 0 - setup_with_gen_tiny._set_channels() - assert setup_with_gen_tiny.data_store.get("channels", setup_with_gen_tiny.scope) == 2 + def test_set_shapes(self, setup_with_gen_tiny): + assert len(setup_with_gen_tiny.data_store.search_name("shape_inputs")) == 0 + assert len(setup_with_gen_tiny.data_store.search_name("shape_outputs")) == 0 + setup_with_gen_tiny._set_shapes() + assert setup_with_gen_tiny.data_store.get("shape_inputs", setup_with_gen_tiny.scope) == [(14, 1, 5), (10, 1, 2), + (1, 1, 2)] + assert setup_with_gen_tiny.data_store.get("shape_outputs", setup_with_gen_tiny.scope) == [(5,), (3,)] def test_load_weights(self): pass @@ -109,3 +125,20 @@ class TestModelSetup: def test_init(self): pass + + +class DummyData: + + def __init__(self, number_of_samples=np.random.randint(100, 150)): + self.number_of_samples = number_of_samples + + def get_X(self, upsampling=False, as_numpy=True): + X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 1, 5)) # samples, window, variables + X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 1, 2)) # samples, window, variables + X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 1, 2)) # samples, window, variables + return [X1, X2, X3] + + def get_Y(self, upsampling=False, as_numpy=True): + Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5)) # samples, window + Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 3)) # samples, window + return [Y1, Y2] \ No newline at end of file diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index 0b439e9e9ad54ca3aef70e27b2017482706383c0..4bf23e6bb15309c0c5cf33c42637a8bfc752c802 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -42,7 +42,7 @@ class TestPreProcessing: caplog.set_level(logging.INFO) with PreProcessing(): assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') - assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started (all)') + assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started (preprocessing)') assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 ' r'station\(s\). Found 5/5 valid stations.')) RunEnvironment().__del__()