diff --git a/src/modules/model_setup.py b/src/modules/model_setup.py
index 925d83d733b8cb5715b515856e09e773737778a9..8a29fd3d1a204affed70517d3ebb807c820b9160 100644
--- a/src/modules/model_setup.py
+++ b/src/modules/model_setup.py
@@ -3,7 +3,7 @@ __date__ = '2019-12-02'
 
 
 import keras
-from keras import losses, layers
+from keras import losses
 from keras.callbacks import ModelCheckpoint
 from keras.regularizers import l2
 from keras.optimizers import Adam, SGD
@@ -104,7 +104,7 @@ class ModelSetup(RunEnvironment):
         self.data_store.put("batch_size", int(256), self.scope)
 
         # activation
-        activation = layers.PReLU  # ELU #LeakyReLU  keras.activations.tanh #
+        activation = keras.layers.PReLU  # alternatives: ELU, LeakyReLU, keras.activations.tanh
         self.data_store.put("activation", activation, self.scope)
 
         # set loss
@@ -151,7 +151,7 @@ def my_model(activation, window_history_size, channels, regularizer, dropout_rat
     ##########################################
     inception_model = InceptionModelBase()
 
-    X_input = layers.Input(shape=(window_history_size + 1, 1, channels))  # add 1 to window_size to include current time step t0
+    X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))  # add 1 to window_size to include current time step t0
 
     X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1, regularizer=regularizer,
                                            batch_normalisation=True)
@@ -159,12 +159,12 @@ def my_model(activation, window_history_size, channels, regularizer, dropout_rat
     out_minor = flatten_tail(X_in, 'Minor_1', bound_weight=True, activation=activation, dropout_rate=dropout_rate,
                              reduction_filter=4, first_dense=32, window_lead_time=window_lead_time)
 
-    X_in = layers.Dropout(dropout_rate)(X_in)
+    X_in = keras.layers.Dropout(dropout_rate)(X_in)
 
     X_in = inception_model.inception_block(X_in, conv_settings_dict2, pool_settings_dict2, regularizer=regularizer,
                                            batch_normalisation=True)
 
-    X_in = layers.Dropout(dropout_rate)(X_in)
+    X_in = keras.layers.Dropout(dropout_rate)(X_in)
 
     X_in = inception_model.inception_block(X_in, conv_settings_dict3, pool_settings_dict3, regularizer=regularizer,
                                            batch_normalisation=True)
diff --git a/src/modules/training.py b/src/modules/training.py
index 8ef3138a83f117f558e11022e2d5053a21666364..2f61e35d536b18f5511b9bf457c5a78f74784a21 100644
--- a/src/modules/training.py
+++ b/src/modules/training.py
@@ -1,7 +1,12 @@
-__author__ = "Lukas Leufen"
+__author__ = "Lukas Leufen, Felix Kleinert"
 __date__ = '2019-12-05'
 
 
+import keras
+import logging
+import numpy as np
+import math
+
 from src.modules.run_environment import RunEnvironment
 
 
@@ -10,10 +15,85 @@ class Training(RunEnvironment):
     def __init__(self):
         super().__init__()
         self.model = self.data_store.get("model", "general.model")
+        self.train_generator = None
+        self.val_generator = None
+        self.test_generator = None
+        self.batch_size = self.data_store.get("batch_size", "general.model")
+        self.epochs = self.data_store.get("epochs", "general.model")
+        self.checkpoint = self.data_store.get("checkpoint", "general.model")
+        self.lr_sc = self.data_store.get("lr_decay", "general.model")  # learning rate scheduler callback
 
     def make_predict_function(self):
         # create the predict function before distributing. This is necessary, because tf will compile the predict
         # function just in the moment it is used the first time. This can cause problems, if the model is distributed
         # on different workers. To prevent this, the function is pre-compiled. See discussion @
         # https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252
-        self.model._make_predict_function()
\ No newline at end of file
+        self.model._make_predict_function()
+
+    def _set_gen(self, mode):
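+        # fetch the plain data generator for this subset from the data store and wrap it in a Distributor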
+        gen = self.data_store.get("generator", f"general.{mode}")
+        setattr(self, f"{mode}_generator", Distributor(gen, self.model, self.batch_size))
+
+    def set_generators(self):
+        for mode in ["train", "val", "test"]:
+            self._set_gen(mode)
+
+    def train(self):
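+        # fit on the mini batches yielded by the train Distributor and validate on the val Distributor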
+        logging.info(f"Train with {len(self.train_generator)} mini batches.")
+        history = self.model.fit_generator(generator=self.train_generator.distribute_on_batches(),
+                                           steps_per_epoch=len(self.train_generator),
+                                           epochs=self.epochs,
+                                           verbose=2,
+                                           validation_data=self.val_generator.distribute_on_batches(),
+                                           validation_steps=len(self.val_generator),
+                                           callbacks=[self.checkpoint, self.lr_sc])
+
+
+class Distributor(keras.utils.Sequence):
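+    # wraps a DataGenerator and yields its station-wise (x, y) data as mini batches of at most batch_size
+    # samples, repeating the targets once for every output branch of the model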
+
+    def __init__(self, generator: keras.utils.Sequence, model: keras.models.Model, batch_size: int = 256,
+                 fit_call: bool = True):
+        self.generator = generator
+        self.model = model
+        self.batch_size = batch_size
+        self.fit_call = fit_call
+
+    def _get_model_rank(self):
+        mod_out = self.model.output_shape
+        if isinstance(mod_out, tuple):
+            # only one output branch: (None, ahead)
+            mod_rank = 1
+        elif isinstance(mod_out, list):
+            # multiple output branches, e.g.: [(None, ahead), (None, ahead)]
+            mod_rank = len(mod_out)
+        else:  # pragma: no branch
+            raise TypeError("model output shape must either be tuple or list.")
+        return mod_rank
+
+    def _get_number_of_mini_batches(self, values):
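+        # values is the (x, y) data of a single station; its sample dimension is split into chunks of batch_size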
+        return math.ceil(values[0].shape[0] / self.batch_size)
+
+    def distribute_on_batches(self, fit_call=True):
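+        # endless loop over all stations as required by fit_generator; with fit_call=False the generator stops after one pass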
+        while True:
+            for k, v in enumerate(self.generator):
+                # get rank of output
+                mod_rank = self._get_model_rank()
+                # get number of mini batches
+                num_mini_batches = self._get_number_of_mini_batches(v)
+                x_total = np.copy(v[0])
+                y_total = np.copy(v[1])
+                for prev, curr in enumerate(range(1, num_mini_batches+1)):
+                    x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
+                    y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
+                    if x is not None:
+                        yield (x, y)
+                        if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call:
+                            return  # ends the generator; raising StopIteration would become a RuntimeError (PEP 479)
+
+    def __len__(self):
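+        # total number of mini batches over all stations of the underlying generator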
+        if self.batch_size > 1:
+            num_batch = 0
+            for _ in self.distribute_on_batches(fit_call=False):
+                num_batch += 1
+        else:
+            num_batch = len(self.generator)
+        return num_batch
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..d37d3e466003f93a0e9497dabcf76e4e74c19797
--- /dev/null
+++ b/test/test_modules/test_training.py
@@ -0,0 +1,86 @@
+import pytest
+import os
+import keras
+import math
+import numpy as np
+
+from src.modules.training import Distributor
+from src.data_generator import DataGenerator
+from src.inception_model import InceptionModelBase
+from src.flatten import flatten_tail
+
+
+def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
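+    # builds a minimal inception model with one output branch (two if add_minor_branch) for the Distributor tests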
+    inception_model = InceptionModelBase()
+    conv_settings_dict1 = {
+        'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
+        'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation}, }
+    pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
+    X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
+    X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
+    if add_minor_branch:
+        out = [flatten_tail(X_in, 'Minor_1', activation=activation)]
+    else:
+        out = []
+    X_in = keras.layers.Dropout(dropout_rate)(X_in)
+    out.append(flatten_tail(X_in, 'Main', activation=activation))
+    return keras.Model(inputs=X_input, outputs=out)
+
+
+class TestDistributor:
+
+    @pytest.fixture
+    def generator(self):
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
+                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+
+    @pytest.fixture
+    def generator_two_stations(self):
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'],
+                             ['o3', 'temp'], 'datetime', 'variables', 'o3',
+                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+
+    @pytest.fixture
+    def model(self):
+        return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False)
+
+    @pytest.fixture
+    def distributor(self, generator, model):
+        return Distributor(generator, model)
+
+    def test_init_defaults(self, distributor):
+        assert distributor.batch_size == 256
+        assert distributor.fit_call is True
+
+    def test_get_model_rank(self, distributor):
+        assert distributor._get_model_rank() == 1
+        distributor.model = my_test_model(keras.layers.PReLU, 5, 3, 0.1, True)
+        assert distributor._get_model_rank() == 2
+        # an output shape that is neither tuple nor list raises a TypeError
+        distributor.model = type("FakeModel", (), {"output_shape": 1})()
+        with pytest.raises(TypeError):
+            distributor._get_model_rank()
+
+    def test_get_number_of_mini_batches(self, distributor):
+        values = np.zeros((2, 2311, 19))
+        assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size)
+
+    def test_distribute_on_batches(self, generator_two_stations, model):
+        d = Distributor(generator_two_stations, model)
+        for e in d.distribute_on_batches(fit_call=False):
+            assert e[0].shape[0] <= d.batch_size
+        elements = []
+        for i, e in enumerate(d.distribute_on_batches()):
+            if i < len(d):
+                elements.append(e[0])
+            elif i < 2 * len(d):  # check if all elements are repeated
+                assert np.testing.assert_array_equal(e[0], elements[i - len(d)]) is None
+            else:  # break when 3rd iteration starts (is called as infinite loop)
+                break
+
+    def test_len(self, distributor):
+        assert len(distributor) == math.ceil(len(distributor.generator[0][0]) / 256)
+
+    def test_len_two_stations(self, generator_two_stations, model):
+        gen = generator_two_stations
+        d = Distributor(gen, model)
+        expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
+        assert len(d) == expected
+