diff --git a/src/data_handling/__init__.py b/src/data_handling/__init__.py
index cb5aa5db0f29cf51d32ed54e810fa9b363d80cc6..dc14146417bd4a6bf30ccbb79537e7920313b077 100644
--- a/src/data_handling/__init__.py
+++ b/src/data_handling/__init__.py
@@ -13,3 +13,6 @@ from .bootstraps import BootStraps
 from .data_preparation_join import DataPrepJoin
 from .data_generator import DataGenerator
 from .data_distributor import Distributor
+from .iterator import KerasIterator, DataCollection
+from .advanced_data_handling import DataPreparation
+from .data_preparation import StationPrep
\ No newline at end of file
diff --git a/src/data_handling/advanced_data_handling.py b/src/data_handling/advanced_data_handling.py
index e36e0c75fc9107431a69482d46755acbdf5334bd..d4c4e363dfd1d27ad7ecb2ef34a619b61c20e9fb 100644
--- a/src/data_handling/advanced_data_handling.py
+++ b/src/data_handling/advanced_data_handling.py
@@ -4,6 +4,7 @@ __date__ = '2020-07-08'
 
 
 from src.helpers import to_list, remove_items
+from src.data_handling.data_preparation import StationPrep
 import numpy as np
 import xarray as xr
 import pickle
@@ -46,8 +47,8 @@ class DummyDataSingleStation:  # pragma: no cover
 
 class DataPreparation:
 
-    def __init__(self, id_class, interpolate_dim: str, store_path, neighbors=None, min_length=0,
-                 extreme_values: num_or_list = None,extremes_on_right_tail_only: bool = False,):
+    def __init__(self, id_class, interpolate_dim: str, data_path, neighbors=None, min_length=0,
+                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False):
         self.id_class = id_class
         self.neighbors = to_list(neighbors) if neighbors is not None else []
         self.interpolate_dim = interpolate_dim
@@ -56,7 +57,7 @@ class DataPreparation:
         self._Y = None
         self._X_extreme = None
         self._Y_extreme = None
-        self._save_file = os.path.join(store_path, f"data_preparation_{str(self.id_class)}.pickle")
+        self._save_file = os.path.join(data_path, f"data_preparation_{str(self.id_class)}.pickle")
         self._collection = []
         self._create_collection()
         self.harmonise_X()
@@ -119,17 +120,17 @@ class DataPreparation:
     def _to_numpy(d):
         return list(map(lambda x: np.copy(x), d))
 
-    def get_X(self, upsamling=False, as_numpy=True):
+    def get_X(self, upsampling=False, as_numpy=True):
         no_data = (self._X is None)
         self._load() if no_data is True else None
-        X = self._X if upsamling is False else self._X_extreme
+        X = self._X if upsampling is False else self._X_extreme
         self._reset_data() if no_data is True else None
         return self._to_numpy(X) if as_numpy is True else X
 
-    def get_Y(self, upsamling=False, as_numpy=True):
+    def get_Y(self, upsampling=False, as_numpy=True):
         no_data = (self._Y is None)
         self._load() if no_data is True else None
-        Y = self._Y if upsamling is False else self._Y_extreme
+        Y = self._Y if upsampling is False else self._Y_extreme
         self._reset_data() if no_data is True else None
         return self._to_numpy([Y]) if as_numpy is True else Y
 
@@ -250,6 +251,41 @@ def create_data_prep():
     data_prep.append(DataPreparation(neighbor2, interpolate_dim, path, neighbors=[neighbor1, central_station]))
     return data_prep
 
+
+class AbstractDataClass:
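+    """Factory interface: subclasses list their constructor arguments in `requirements` and build a data preparation object on call."""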
+
+    def __init__(self):
+        self._requires = []
+
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @property
+    def requirements(self):
+        return self._requires
+
+    @requirements.setter
+    def requirements(self, value):
+        self._requires = value
+
+
+class CustomDataClass(AbstractDataClass):
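+    """Combine StationPrep and DataPreparation: the requirements are the merged constructor arguments of both classes (without `station` and `id_class`)."""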
+
+    def __init__(self):
+        import inspect
+        super().__init__()
+        self.sp_keys = remove_items(inspect.getfullargspec(StationPrep).args, ["self", "station"])
+        self.dp_keys = remove_items(inspect.getfullargspec(DataPreparation).args, ["self", "id_class"])
+        self.requirements = self.sp_keys + self.dp_keys
+
+    def __call__(self, station, **kwargs):
+        sp_keys = {k: kwargs[k] for k in self.sp_keys if k in kwargs}
+        sp_keys["station"] = station
+        sp = StationPrep(**sp_keys)
+        dp_args = {k: kwargs[k] for k in self.dp_keys if k in kwargs}
+        return DataPreparation(sp, **dp_args)
+
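+# Usage sketch (all argument values below are placeholders, not project defaults):
+#   factory = CustomDataClass()
+#   dp = factory("DEBW107", data_path="testdata", statistics_per_var={'o3': 'dma8eu'}, transformation={},
+#                station_type=None, network=None, sampling="daily", target_dim="variables", target_var="o3",
+#                interpolate_dim="datetime", window_history_size=7, window_lead_time=3)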
+
 if __name__ == "__main__":
     from src.data_handling.data_preparation import StationPrep
     from src.data_handling.iterator import KerasIterator, DataCollection
@@ -258,6 +294,6 @@ if __name__ == "__main__":
     for data in data_collection:
         print(data)
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata", "keras")
-    keras_it = KerasIterator(data_collection, 100, path)
+    keras_it = KerasIterator(data_collection, 100, path, upsampling=True)
     keras_it[2]
 
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index d5933f193018efb1529db2c026981e8c4d7936d2..dadda2c58979ddb2678d366470c3b6d3f0584ee4 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -68,8 +68,8 @@ class AbstractStationPrep():
 
 class StationPrep(AbstractStationPrep):
 
-    def __init__(self, path, station, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var,
-                 interpolate_dim, window_history_size, window_lead_time, **kwargs):
+    def __init__(self, data_path, station, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var,
+                 interpolate_dim, window_history_size, window_lead_time, overwrite_local_data: bool = False, **kwargs):
         super().__init__()  # path, station, statistics_per_var, transformation, **kwargs)
         self.station_type = station_type
         self.network = network
@@ -80,12 +80,10 @@ class StationPrep(AbstractStationPrep):
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
 
-        self.path = os.path.abspath(path)
+        self.path = os.path.abspath(data_path)
         self.station = helpers.to_list(station)
         self.statistics_per_var = statistics_per_var
         # self.target_dim = 'variable'
-        self.transformation = self.setup_transformation(transformation)
-        self.kwargs = kwargs
 
         # internal
         self.data = None
@@ -95,17 +93,15 @@ class StationPrep(AbstractStationPrep):
         self.label = None
         self.observation = None
 
-    def __str__(self):
-        return self.station[0]
+        self.transformation = self.setup_transformation(transformation)
+        self.kwargs = kwargs
+        self.kwargs["overwrite_local_data"] = overwrite_local_data
 
-    def load_data(self):
-        try:
-            self.read_data_from_disk()
-        except FileNotFoundError:
-            self.download_data()
-            self.load_data()
         self.make_samples()
 
+    def __str__(self):
+        return self.station[0]
+
     def __repr__(self):
         return f"StationPrep(path='{self.path}', station={self.station}, statistics_per_var={self.statistics_per_var}, " \
                f"transformation={self.transformation}, station_type='{self.station_type}', network='{self.network}', " \
diff --git a/src/data_handling/iterator.py b/src/data_handling/iterator.py
index 14d71a9afc23d3a0d80bacf60bbaa928fb34407a..d2ef9eb8df6373934e30ef9ca98c5de3fefed6c9 100644
--- a/src/data_handling/iterator.py
+++ b/src/data_handling/iterator.py
@@ -33,23 +33,37 @@ class StandardIterator(Iterator):
 
 class DataCollection(Iterable):
 
-    def __init__(self, collection: list):
+    def __init__(self, collection: list = None):
+        if collection is None:
+            collection = []
         assert isinstance(collection, list)
         self._collection = collection
 
+    def __len__(self):
+        return len(self._collection)
+
     def __iter__(self) -> Iterator:
         return StandardIterator(self._collection)
 
+    def __getitem__(self, index):
+        return self._collection[index]
+
+    def add(self, element):
+        self._collection.append(element)
+
 
 class KerasIterator(keras.utils.Sequence):
 
-    def __init__(self, collection: DataCollection, batch_size: int, path: str, shuffle: bool = False):
+    def __init__(self, collection: DataCollection, batch_size: int, batch_path: str, shuffle_batches: bool = False,
+                 model=None, upsampling=False):
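+        # model is only used to infer the number of output branches; upsampling=True requests the extreme-value data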
         self._collection = collection
-        self._path = os.path.join(path, "%i.pickle")
+        self._path = os.path.join(batch_path, "%i.pickle")
         self.batch_size = batch_size
-        self.shuffle = shuffle
+        self.model = model
+        self.shuffle = shuffle_batches
+        self.upsampling = upsampling
         self.indexes: list = []
-        self._cleanup_path(path)
+        self._cleanup_path(batch_path)
         self._prepare_batches()
 
     def __len__(self) -> int:
@@ -59,6 +73,19 @@ class KerasIterator(keras.utils.Sequence):
         """Get batch for given index."""
         return self.__data_generation(self.indexes[index])
 
+    def _get_model_rank(self):
+        if self.model is not None:
+            mod_out = self.model.output_shape
+            if isinstance(mod_out, tuple):  # only one output branch: (None, ahead)
+                mod_rank = 1
+            elif isinstance(mod_out, list):  # multiple output branches, e.g.: [(None, ahead), (None, ahead)]
+                mod_rank = len(mod_out)
+            else:  # pragma: no cover
+                raise TypeError("model output shape must either be tuple or list.")
+            return mod_rank
+        else:  # no model provided, assume to use single output
+            return 1
+
     def __data_generation(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
         """Load pickle data from disk."""
         file = self._path % index
@@ -75,6 +102,12 @@ class KerasIterator(keras.utils.Sequence):
         """Get batch according to batch size from data list."""
         return list(map(lambda data: data[b * self.batch_size:(b+1) * self.batch_size, ...], data_list))
 
+    def _permute_data(self, X, Y):
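+        """Draw a single random permutation and apply it to all inputs and targets along the sample dimension."""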
+        p = np.random.permutation(len(X[0]))  # equiv to .shape[0]
+        X = list(map(lambda x: x[p], X))
+        Y = list(map(lambda x: x[p], Y))
+        return X, Y
+
     def _prepare_batches(self) -> None:
         """
         Prepare all batches as locally stored files.
@@ -86,8 +119,12 @@ class KerasIterator(keras.utils.Sequence):
         """
         index = 0
         remaining = None
+        mod_rank = self._get_model_rank()
         for data in self._collection:
-            X, Y = data.get_X(), data.get_Y()
+            X = data.get_X(upsampling=self.upsampling)
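+            # replicate the single target once per output branch so multi-output models receive matching label lists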
+            Y = [data.get_Y(upsampling=self.upsampling)[0] for _ in range(mod_rank)]
+            if self.upsampling:
+                X, Y = self._permute_data(X, Y)
             if remaining is not None:
                 X, Y = self._concatenate(X, remaining[0]), self._concatenate(Y, remaining[1])
             length = X[0].shape[0]
diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index dab2e168c5a9f87d4aee42fc94489fd0fa67772a..6b3b9972bc0af4c968f2831963cc18446ff09162 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -351,7 +351,7 @@ class MyLittleModel(AbstractModelClass):
         # settings
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
-        self.channels = channels
+        self.channels = channels[0]
         self.dropout_rate = 0.1
         self.regularizer = keras.regularizers.l2(0.1)
         self.activation = keras.layers.PReLU
@@ -387,7 +387,7 @@ class MyLittleModel(AbstractModelClass):
         x_in = self.activation()(x_in)
         x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in)
         out_main = self.activation()(x_in)
-        self.model = keras.Model(inputs=x_input, outputs=[out_main])
+        self.model = keras.Model(inputs=[x_input], outputs=[out_main])
 
     def set_compile_options(self):
         self.initial_lr = 1e-2
@@ -423,7 +423,7 @@ class MyBranchedModel(AbstractModelClass):
         # settings
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
-        self.channels = channels
+        self.channels = channels[0]
         self.dropout_rate = 0.1
         self.regularizer = keras.regularizers.l2(0.1)
         self.activation = keras.layers.PReLU
@@ -493,7 +493,7 @@ class MyTowerModel(AbstractModelClass):
         # settings
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
-        self.channels = channels
+        self.channels = channels[0]
         self.dropout_rate = 1e-2
         self.regularizer = keras.regularizers.l2(0.1)
         self.initial_lr = 1e-2
@@ -605,7 +605,7 @@ class MyPaperModel(AbstractModelClass):
         # settings
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
-        self.channels = channels
+        self.channels = channels[0]
         self.dropout_rate = .3
         self.regularizer = keras.regularizers.l2(0.001)
         self.initial_lr = 1e-3
diff --git a/src/run.py b/src/run.py
index 7e262dd769204077697b7df3f3fbaedb4c012257..4033d52303035ede583529169e93548ab7a205e1 100644
--- a/src/run.py
+++ b/src/run.py
@@ -39,5 +39,5 @@ def run(stations=None,
 
 
 if __name__ == "__main__":
-
-    run()
+    from src.data_handling.advanced_data_handling import CustomDataClass
+    run(data_preparation=CustomDataClass, statistics_per_var={'o3': 'dma8eu'}, transformation={})
diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 1d375c32be06b583abbfb06a20ea482e6775b232..3e471dda7934fc53c990ede6e459c41f3ef6229b 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -19,6 +19,7 @@ from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, D
     DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
     DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST
 from src.data_handling import DataPrepJoin
+from src.data_handling.advanced_data_handling import CustomDataClass
 from src.run_modules.run_environment import RunEnvironment
 from src.model_modules.model_class import MyLittleModel as VanillaModel
 
@@ -228,8 +229,8 @@ class ExperimentSetup(RunEnvironment):
                  create_new_model = None, bootstrap_path=None, permute_data_on_training = None, transformation=None,
                  train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
                  extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None,
-                 create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None,
-                 batch_size=None, epochs=None, data_preparation=None):
+                 create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
+                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_preparation=None):
 
         # create run framework
         super().__init__()
@@ -265,6 +266,9 @@ class ExperimentSetup(RunEnvironment):
         logging.info(f"Experiment path is: {experiment_path}")
         path_config.check_path_and_create(self.data_store.get("experiment_path"))
 
+        # batch path (temporary)
+        self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data"))
+
         # set model path
         self._set_param("model_path", None, os.path.join(experiment_path, "model"))
         path_config.check_path_and_create(self.data_store.get("model_path"))
@@ -297,7 +301,8 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("sampling", sampling)
         self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
         self._set_param("transformation", None, scope="preprocessing")
-        self._set_param("data_preparation", data_preparation, default=DataPrepJoin)
+        self._set_param("data_preparation", data_preparation(), default=CustomDataClass())
+        assert isinstance(getattr(self.data_store.get("data_preparation"), "requirements"), property) is False
 
         # target
         self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR)
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index f9683b953d85bacf6e452e0a1922e85dfe946cd1..dc537eb1cd3e5cf04fbdddee3017e4ace7f7bfca 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -90,7 +90,7 @@ class ModelSetup(RunEnvironment):
 
     def _set_channels(self):
         """Set channels as number of variables of train generator."""
-        channels = self.data_store.get("generator", "train")[0][0].shape[-1]
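+        # one channel count per input branch; the current models use channels[0] as their single-input channel number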
+        channels = list(map(lambda x: x[0].shape[-1], self.data_store.get("data_collection", "train")[0].get_X()))
         self.data_store.set("channels", channels, self.scope)
 
     def compile_model(self):
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index db7fff2ab9e385ce769f86ef95d1565ea783cc95..72493c1fbad42a7aa9fec1e32292c0727a7dfb38 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -11,6 +11,8 @@ import numpy as np
 import pandas as pd
 
 from src.data_handling import DataGenerator
+from src.data_handling import DataCollection, DataPreparation, StationPrep
+from src.data_handling.advanced_data_handling import CustomDataClass
 from src.helpers import TimeTracking
 from src.configuration import path_config
 from src.helpers.join import EmptyQueryResult
@@ -59,10 +61,9 @@ class PreProcessing(RunEnvironment):
         self._run()
 
     def _run(self):
-        args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="preprocessing")
-        kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="preprocessing")
         stations = self.data_store.get("stations")
-        valid_stations = self.check_valid_stations(args, kwargs, stations, load_tmp=False, save_tmp=False, name="all")
+        data_preparation = self.data_store.get("data_preparation")
+        _, valid_stations = self.validate_station(data_preparation, stations, "preprocessing", overwrite_local_data=True)
         self.data_store.set("stations", valid_stations)
         self.split_train_val_test()
         self.report_pre_processing()
@@ -70,16 +71,14 @@ class PreProcessing(RunEnvironment):
     def report_pre_processing(self):
         """Log some metrics on data and create latex report."""
         logging.debug(20 * '##')
-        n_train = len(self.data_store.get('generator', 'train'))
-        n_val = len(self.data_store.get('generator', 'val'))
-        n_test = len(self.data_store.get('generator', 'test'))
+        n_train = len(self.data_store.get('data_collection', 'train'))
+        n_val = len(self.data_store.get('data_collection', 'val'))
+        n_test = len(self.data_store.get('data_collection', 'test'))
         n_total = n_train + n_val + n_test
         logging.debug(f"Number of all stations: {n_total}")
         logging.debug(f"Number of training stations: {n_train}")
         logging.debug(f"Number of val stations: {n_val}")
         logging.debug(f"Number of test stations: {n_test}")
-        logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'test')[0][0].shape}"
-                      f"{self.data_store.get('generator', 'test')[0][1].shape}")
         self.create_latex_report()
 
     def create_latex_report(self):
@@ -121,11 +120,12 @@ class PreProcessing(RunEnvironment):
         set_names = ["train", "val", "test"]
         df = pd.DataFrame(columns=meta_data + set_names)
         for set_name in set_names:
-            data: DataGenerator = self.data_store.get("generator", set_name)
-            for station in data.stations:
-                df.loc[station, set_name] = data.get_data_generator(station).get_transposed_label().shape[0]
-                if df.loc[station, meta_data].isnull().any():
-                    df.loc[station, meta_data] = data.get_data_generator(station).meta.loc[meta_data].values.flatten()
+            data = self.data_store.get("data_collection", set_name)
+            for station in data:
+                station_name = str(station.id_class)
+                df.loc[station_name, set_name] = station.get_Y()[0].shape[0]
+                if df.loc[station_name, meta_data].isnull().any():
+                    df.loc[station_name, meta_data] = station.id_class.meta.loc[meta_data].values.flatten()
             df.loc["# Samples", set_name] = df.loc[:, set_name].sum()
             df.loc["# Stations", set_name] = df.loc[:, set_name].count()
         df[meta_round] = df[meta_round].astype(float).round(precision)
@@ -147,7 +147,7 @@ class PreProcessing(RunEnvironment):
         Split data into subsets.
 
         Currently: train, val, test and train_val (actually this is only the merge of train and val, but as an separate
-        generator). IMPORTANT: Do not change to order of the execution of create_set_split. The train subset needs
+        data_collection). IMPORTANT: Do not change the order of the execution of create_set_split. The train subset needs
         always to be executed at first, to set a proper transformation.
         """
         fraction_of_training = self.data_store.get("fraction_of_training")
@@ -159,7 +159,7 @@ class PreProcessing(RunEnvironment):
             raise AssertionError(f"Make sure, that the train subset is always at first execution position! Given subset"
                                  f"order was: {subset_names}.")
         for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names):
-            self.create_set_split(ind, scope)
+            self.create_set_split_new(ind, scope)
 
     @staticmethod
     def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]:
@@ -183,13 +183,27 @@ class PreProcessing(RunEnvironment):
         train_val_index = slice(0, pos_test_split)
         return train_index, val_index, test_index, train_val_index
 
+    def create_set_split_new(self, index_list: slice, set_name: str) -> None:
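+        """Create the subset for the given split index: build one data preparation per station with the configured
+        data_preparation factory and store the resulting DataCollection as `data_collection` in the data store."""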
+        # get set stations
+        stations = self.data_store.get("stations", scope=set_name)
+        if self.data_store.get("use_all_stations_on_all_data_sets"):
+            set_stations = stations
+        else:
+            set_stations = stations[index_list]
+        logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
+        # create set data_collection and store
+        data_preparation = self.data_store.get("data_preparation")
+        collection, valid_stations = self.validate_station(data_preparation, set_stations, set_name)
+        self.data_store.set("stations", valid_stations, scope=set_name)
+        self.data_store.set("data_collection", collection, scope=set_name)
+
     def create_set_split(self, index_list: slice, set_name: str) -> None:
         """
         Create subsets and store in data store.
 
-        Create the subset for given split index and stores the DataGenerator with given set name in data store as
-        `generator`. Check for all valid stations using the default (kw)args for given scope and create the
-        DataGenerator for all valid stations. Also set all transformation information, if subset is training set. Make
+        Create the subset for the given split index and store the data_collection with the given set name in the data
+        store as `data_collection`. Check for all valid stations using the default (kw)args for the given scope and
+        create the data_collection for all valid stations. Make
         sure, that the train set is executed first, and all other subsets afterwards.
 
         :param index_list: list of all stations to use for the set. If attribute use_all_stations_on_all_data_sets=True,
@@ -207,13 +221,18 @@ class PreProcessing(RunEnvironment):
         # validate set
         set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False, name=set_name)
         self.data_store.set("stations", set_stations, scope=set_name)
-        # create set generator and store
+        # create set data_collection and store
         set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name)
-        data_set = DataGenerator(**set_args, **kwargs)
-        self.data_store.set("generator", data_set, scope=set_name)
-        # extract transformation from train set
-        if set_name == "train":
-            self.data_store.set("transformation", data_set.transformation)
+        data_prep_kwargs = self.data_store.create_args_dict(
+            ["interpolate_dim", "data_path", "min_length", "extreme_values", "extremes_on_right_tail_only"],
+            scope=set_name)
+        collection = DataCollection()
+        for station in set_stations:
+            set_args["station"] = station
+            # build the per-station StationPrep first, then wrap it into a DataPreparation and collect it
+            collection.add(DataPreparation(StationPrep(**set_args, **kwargs), **data_prep_kwargs))
+        self.data_store.set("data_collection", collection, scope=set_name)
 
     @staticmethod
     def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True,
@@ -257,3 +276,36 @@ class PreProcessing(RunEnvironment):
         logging.info(f"run for {t_outer} to check {len(all_stations)} station(s). Found {len(valid_stations)}/"
                      f"{len(all_stations)} valid stations.")
         return valid_stations
+
+    def validate_station(self, data_preparation, set_stations, set_name=None, overwrite_local_data=False):
+        """
+        Check if all given stations in `set_stations` are valid.
+
+        Valid means that a data preparation object can be created for the station (i.e. data is available for the
+        requested time range). The total checking time is logged in info mode.
+
+        :param data_preparation: data preparation factory used to build one preparation object per station
+        :param set_stations: all stations to check
+        :param set_name: name (and data store scope) of the subset, also displayed in the logging message
+        :param overwrite_local_data: if True, locally stored data is overwritten by a fresh download
+
+        :return: tuple of a DataCollection holding all valid preparations and the corrected list of valid station IDs
+        """
+        t_outer = TimeTracking()
+        logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else 'all'}")
+        collection = DataCollection()
+        valid_stations = []
+        kwargs = self.data_store.create_args_dict(data_preparation.requirements, scope=set_name)
+        kwargs["overwrite_local_data"] = overwrite_local_data  # forward the flag so StationPrep can trigger a fresh download
+        for station in set_stations:
+            try:
+                dp = data_preparation(station, **kwargs)
+                collection.add(dp)
+                valid_stations.append(station)
+            except (AttributeError, EmptyQueryResult):
+                continue
+        logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/"
+                     f"{len(set_stations)} valid stations.")
+        return collection, valid_stations
+
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 1a0d7beb1ec37bb5e59a4129da58572d79a73636..a92fd56fda5599489992b1bccaca3a715dd622d7 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -11,7 +11,7 @@ from typing import Union
 import keras
 from keras.callbacks import Callback, History
 
-from src.data_handling import Distributor
+from src.data_handling import KerasIterator
 from src.model_modules.keras_extensions import CallbackHandler
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from src.run_modules.run_environment import RunEnvironment
@@ -64,9 +64,9 @@ class Training(RunEnvironment):
         """Set up and run training."""
         super().__init__()
         self.model: keras.Model = self.data_store.get("model", "model")
-        self.train_set: Union[Distributor, None] = None
-        self.val_set: Union[Distributor, None] = None
-        self.test_set: Union[Distributor, None] = None
+        self.train_set: Union[KerasIterator, None] = None
+        self.val_set: Union[KerasIterator, None] = None
+        self.test_set: Union[KerasIterator, None] = None
         self.batch_size = self.data_store.get("batch_size")
         self.epochs = self.data_store.get("epochs")
         self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model")
@@ -102,9 +102,9 @@ class Training(RunEnvironment):
 
         :param mode: name of set, should be from ["train", "val", "test"]
         """
-        gen = self.data_store.get("generator", mode)
-        kwargs = self.data_store.create_args_dict(["permute_data", "upsampling"], scope=mode)
-        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, **kwargs))
+        collection = self.data_store.get("data_collection", mode)
+        kwargs = self.data_store.create_args_dict(["upsampling", "shuffle_batches", "batch_path"], scope=mode)
+        setattr(self, f"{mode}_set", KerasIterator(collection, self.batch_size, model=self.model, **kwargs))
 
     def set_generators(self) -> None:
         """
@@ -128,15 +128,15 @@ class Training(RunEnvironment):
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
         logging.info(f"Train with option upsampling={self.train_set.upsampling}.")
-        logging.info(f"Train with option data_permutation={self.train_set.do_data_permutation}.")
+        logging.info(f"Train with option shuffle={self.train_set.shuffle}.")
 
         checkpoint = self.callbacks.get_checkpoint()
         if not os.path.exists(checkpoint.filepath) or self._create_new_model:
-            history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+            history = self.model.fit_generator(generator=self.train_set,
                                                steps_per_epoch=len(self.train_set),
                                                epochs=self.epochs,
                                                verbose=2,
-                                               validation_data=self.val_set.distribute_on_batches(),
+                                               validation_data=self.val_set,
                                                validation_steps=len(self.val_set),
                                                callbacks=self.callbacks.get_callbacks(as_dict=False))
         else:
@@ -146,11 +146,11 @@ class Training(RunEnvironment):
             self.model = keras.models.load_model(checkpoint.filepath)
             hist: History = self.callbacks.get_callback_by_name("hist")
             initial_epoch = max(hist.epoch) + 1
-            _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+            _ = self.model.fit_generator(generator=self.train_set,
                                          steps_per_epoch=len(self.train_set),
                                          epochs=self.epochs,
                                          verbose=2,
-                                         validation_data=self.val_set.distribute_on_batches(),
+                                         validation_data=self.val_set,
                                          validation_steps=len(self.val_set),
                                          callbacks=self.callbacks.get_callbacks(as_dict=False),
                                          initial_epoch=initial_epoch)