diff --git a/conftest.py b/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..92d2159c3b3a3efd7d0c0bfb5bf6bb058697d79c
--- /dev/null
+++ b/conftest.py
@@ -0,0 +1,25 @@
+import os
+import shutil
+
+
+def pytest_runtest_teardown(item, nextitem):
+    """
+    Teardown method to clean up folder creations during testing. This method is called after each test, but performs
+    deletions only after an entire test class was executed.
+    :param item: tested item
+    :param nextitem: next item (could be None, if no following test is available)
+    """
+    if nextitem is None or item.cls != nextitem.cls:
+        # clean up all TestExperiment and data folders that have been created during testing
+        rel_path = os.path.relpath(item.fspath.dirname, os.path.abspath(__file__))
+        path = os.path.dirname(__file__)
+        for stage in filter(None, rel_path.replace("..", ".").split("/")):
+            path = os.path.abspath(os.path.join(path, stage))
+            list_dir = os.listdir(path)
+            if "data" in list_dir and path != os.path.dirname(__file__):  # do not delete data folder in src
+                shutil.rmtree(os.path.join(path, "data"), ignore_errors=True)
+            if "TestExperiment" in list_dir:
+                shutil.rmtree(os.path.join(path, "TestExperiment"), ignore_errors=True)
+    else:
+        pass  # nothing to do if next test is from same test class
+
diff --git a/src/data_handling/__init__.py b/src/data_handling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
new file mode 100644
index 0000000000000000000000000000000000000000..77f83536db5eaed3545d609e1d33a042c7ad23dd
--- /dev/null
+++ b/src/data_handling/data_distributor.py
@@ -0,0 +1,61 @@
+from __future__ import generator_stop
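+# PEP 479 (generator_stop): a StopIteration raised inside a generator is converted into a RuntimeError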
+import math
+
+import keras
+import numpy as np
+
+
+class Distributor(keras.utils.Sequence):
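+    """Distribute data from a generator on mini batches of size batch_size: each (x, y) pair yielded by the
+    generator is split into batches, and y is duplicated once per output branch of the model."""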
+
+    def __init__(self, generator: keras.utils.Sequence, model: keras.models.Model, batch_size: int = 256,
+                 fit_call: bool = True):
+        self.generator = generator
+        self.model = model
+        self.batch_size = batch_size
+        self.fit_call = fit_call
+
+    def _get_model_rank(self):
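+        """Return the number of output branches of the model."""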
+        mod_out = self.model.output_shape
+        if isinstance(mod_out, tuple):
+            # only one output branch: (None, ahead)
+            mod_rank = 1
+        elif isinstance(mod_out, list):
+            # multiple output branches, e.g.: [(None, ahead), (None, ahead)]
+            mod_rank = len(mod_out)
+        else:  # pragma: no cover
+            raise TypeError("model output shape must be either a tuple or a list.")
+        return mod_rank
+
+    def _get_number_of_mini_batches(self, values):
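+        """Return the number of mini batches required to cover the first dimension of values[0]."""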
+        return math.ceil(values[0].shape[0] / self.batch_size)
+
+    def distribute_on_batches(self, fit_call=True):
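+        """Endlessly yield mini batches of (x, y) from the wrapped generator, as required by keras' fit_generator.
+        With fit_call=False, the generator stops after a single complete pass instead."""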
+        while True:
+            for k, v in enumerate(self.generator):
+                # get rank of output
+                mod_rank = self._get_model_rank()
+                # get number of mini batches
+                num_mini_batches = self._get_number_of_mini_batches(v)
+                x_total = np.copy(v[0])
+                y_total = np.copy(v[1])
+                for prev, curr in enumerate(range(1, num_mini_batches+1)):
+                    x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
+                    y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
+                    if x is not None:
+                        yield (x, y)
+                        if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call:
+                            return
+
+    def __len__(self):
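+        # count the mini batches by exhausting distribute_on_batches once (fit_call=False makes it finite)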
+        num_batch = 0
+        for _ in self.distribute_on_batches(fit_call=False):
+            num_batch += 1
+        return num_batch
diff --git a/src/data_generator.py b/src/data_handling/data_generator.py
similarity index 99%
rename from src/data_generator.py
rename to src/data_handling/data_generator.py
index 4e7dda9363c226c5fa92d03f1dbae6470e48d496..1de0ab2092b46dfca6281963b260b1fc6bc65387 100644
--- a/src/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -3,7 +3,7 @@ __date__ = '2019-11-07'
 
 import keras
 from src import helpers
-from src.data_preparation import DataPrep
+from src.data_handling.data_preparation import DataPrep
 import os
 from typing import Union, List, Tuple
 import xarray as xr
diff --git a/src/data_preparation.py b/src/data_handling/data_preparation.py
similarity index 100%
rename from src/data_preparation.py
rename to src/data_handling/data_preparation.py
diff --git a/src/modules/model_setup.py b/src/modules/model_setup.py
index 925d83d733b8cb5715b515856e09e773737778a9..8a29fd3d1a204affed70517d3ebb807c820b9160 100644
--- a/src/modules/model_setup.py
+++ b/src/modules/model_setup.py
@@ -3,7 +3,7 @@ __date__ = '2019-12-02'
 
 
 import keras
-from keras import losses, layers
+from keras import losses
 from keras.callbacks import ModelCheckpoint
 from keras.regularizers import l2
 from keras.optimizers import Adam, SGD
@@ -104,7 +104,7 @@ class ModelSetup(RunEnvironment):
         self.data_store.put("batch_size", int(256), self.scope)
 
         # activation
-        activation = layers.PReLU  # ELU #LeakyReLU  keras.activations.tanh #
+        activation = keras.layers.PReLU  # alternatives: ELU, LeakyReLU, keras.activations.tanh
         self.data_store.put("activation", activation, self.scope)
 
         # set loss
@@ -151,7 +151,7 @@ def my_model(activation, window_history_size, channels, regularizer, dropout_rat
     ##########################################
     inception_model = InceptionModelBase()
 
-    X_input = layers.Input(shape=(window_history_size + 1, 1, channels))  # add 1 to window_size to include current time step t0
+    X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))  # add 1 to window_size to include current time step t0
 
     X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1, regularizer=regularizer,
                                            batch_normalisation=True)
@@ -159,12 +159,12 @@ def my_model(activation, window_history_size, channels, regularizer, dropout_rat
     out_minor = flatten_tail(X_in, 'Minor_1', bound_weight=True, activation=activation, dropout_rate=dropout_rate,
                              reduction_filter=4, first_dense=32, window_lead_time=window_lead_time)
 
-    X_in = layers.Dropout(dropout_rate)(X_in)
+    X_in = keras.layers.Dropout(dropout_rate)(X_in)
 
     X_in = inception_model.inception_block(X_in, conv_settings_dict2, pool_settings_dict2, regularizer=regularizer,
                                            batch_normalisation=True)
 
-    X_in = layers.Dropout(dropout_rate)(X_in)
+    X_in = keras.layers.Dropout(dropout_rate)(X_in)
 
     X_in = inception_model.inception_block(X_in, conv_settings_dict3, pool_settings_dict3, regularizer=regularizer,
                                            batch_normalisation=True)
diff --git a/src/modules/pre_processing.py b/src/modules/pre_processing.py
index 764613ea4558bfdef4cbada668f16608f81d5f95..8fad9d1bf756baf830c236f16102878ba83515c2 100644
--- a/src/modules/pre_processing.py
+++ b/src/modules/pre_processing.py
@@ -3,12 +3,11 @@ __date__ = '2019-11-25'
 
 
 import logging
-from typing import Any, Tuple, Dict, List
+from typing import Tuple, Dict, List
 
-from src.data_generator import DataGenerator
+from src.data_handling.data_generator import DataGenerator
 from src.helpers import TimeTracking
 from src.modules.run_environment import RunEnvironment
-from src.datastore import NameNotFoundInDataStore, NameNotFoundInScope
 from src.join import EmptyQueryResult
 
 
diff --git a/src/modules/training.py b/src/modules/training.py
index 8ef3138a83f117f558e11022e2d5053a21666364..866e9405acec35d4602a1ca6b079fdc53a05b71f 100644
--- a/src/modules/training.py
+++ b/src/modules/training.py
@@ -1,8 +1,11 @@
-__author__ = "Lukas Leufen"
+
+__author__ = "Lukas Leufen, Felix Kleinert"
 __date__ = '2019-12-05'
 
+import logging
 
 from src.modules.run_environment import RunEnvironment
+from src.data_handling.data_distributor import Distributor
 
 
 class Training(RunEnvironment):
@@ -10,10 +13,36 @@ class Training(RunEnvironment):
     def __init__(self):
         super().__init__()
         self.model = self.data_store.get("model", "general.model")
+        self.train_generator = None
+        self.val_generator = None
+        self.test_generator = None
+        self.batch_size = self.data_store.get("batch_size", "general.model")
+        self.epochs = self.data_store.get("epochs", "general.model")
+        self.checkpoint = self.data_store.get("checkpoint", "general.model")
+        self.lr_sc = self.data_store.get("lr_sc", "general.model")  # learning rate schedule; key name assumed
 
     def make_predict_function(self):
         # create the predict function before distributing. This is necessary, because tf will compile the predict
         # function just in the moment it is used the first time. This can cause problems, if the model is distributed
         # on different workers. To prevent this, the function is pre-compiled. See discussion @
         # https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252
-        self.model._make_predict_function()
\ No newline at end of file
+        self.model._make_predict_function()
+
+    def _set_gen(self, mode):
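+        """Wrap the data store generator for the given mode ("train", "val" or "test") in a Distributor."""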
+        gen = self.data_store.get("generator", f"general.{mode}")
+        setattr(self, f"{mode}_generator", Distributor(gen, self.model, self.batch_size))
+
+    def set_generators(self):
+        for mode in ["train", "val", "test"]:
+            self._set_gen(mode)
+
+    def train(self):
+        logging.info(f"Train with {len(self.train_generator)} mini batches.")
+        history = self.model.fit_generator(generator=self.train_generator.distribute_on_batches(),
+                                           steps_per_epoch=len(self.train_generator),
+                                           epochs=self.epochs,
+                                           verbose=2,
+                                           validation_data=self.val_generator.distribute_on_batches(),
+                                           validation_steps=len(self.val_generator),
+                                           callbacks=[self.checkpoint, self.lr_sc])
diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb51f20c8771ec49116731f02c7b462a62405394
--- /dev/null
+++ b/test/test_data_handling/test_data_distributor.py
@@ -0,0 +1,74 @@
+import math
+import os
+
+import keras
+import numpy as np
+import pytest
+
+from src.data_handling.data_distributor import Distributor
+from src.data_handling.data_generator import DataGenerator
+from test.test_modules.test_training import my_test_model
+
+
+class TestDistributor:
+
+    @pytest.fixture
+    def generator(self):
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
+                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+
+    @pytest.fixture
+    def generator_two_stations(self):
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'],
+                             ['o3', 'temp'], 'datetime', 'variables', 'o3',
+                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+
+    @pytest.fixture
+    def model(self):
+        return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False)
+
+    @pytest.fixture
+    def model_with_minor_branch(self):
+        return my_test_model(keras.layers.PReLU, 5, 3, 0.1, True)
+
+    @pytest.fixture
+    def distributor(self, generator, model):
+        return Distributor(generator, model)
+
+    def test_init_defaults(self, distributor):
+        assert distributor.batch_size == 256
+        assert distributor.fit_call is True
+
+    def test_get_model_rank(self, distributor, model_with_minor_branch):
+        assert distributor._get_model_rank() == 1
+        distributor.model = model_with_minor_branch
+        assert distributor._get_model_rank() == 2
+
+    def test_get_number_of_mini_batches(self, distributor):
+        values = np.zeros((2, 2311, 19))
+        assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size)
+
+    def test_distribute_on_batches_single_loop(self, generator_two_stations, model):
+        d = Distributor(generator_two_stations, model)
+        for e in d.distribute_on_batches(fit_call=False):
+            assert e[0].shape[0] <= d.batch_size
+
+    def test_distribute_on_batches_infinite_loop(self, generator_two_stations, model):
+        d = Distributor(generator_two_stations, model)
+        elements = []
+        for i, e in enumerate(d.distribute_on_batches()):
+            if i < len(d):
+                elements.append(e[0])
+            elif i < 2 * len(d):  # check that the second pass repeats all elements
+                assert np.testing.assert_array_equal(e[0], elements[i - len(d)]) is None
+            else:  # break when the 3rd pass starts (distribute_on_batches loops infinitely)
+                break
+
+    def test_len(self, distributor):
+        assert len(distributor) == math.ceil(len(distributor.generator[0][0]) / 256)
+
+    def test_len_two_stations(self, generator_two_stations, model):
+        gen = generator_two_stations
+        d = Distributor(gen, model)
+        expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
+        assert len(d) == expected
diff --git a/test/test_data_generator.py b/test/test_data_handling/test_data_generator.py
similarity index 90%
rename from test/test_data_generator.py
rename to test/test_data_handling/test_data_generator.py
index 6801e064be804da0368520ab0ef81bdba8bca2d3..879436afddb8da8d11d6cc585da7c703aa12ef8a 100644
--- a/test/test_data_generator.py
+++ b/test/test_data_handling/test_data_generator.py
@@ -1,10 +1,10 @@
 import pytest
 import os
-from src.data_generator import DataGenerator
+from src.data_handling.data_generator import DataGenerator
 
 
 class TestDataGenerator:
 
     @pytest.fixture
     def gen(self):
         return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
diff --git a/test/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py
similarity index 98%
rename from test/test_data_preparation.py
rename to test/test_data_handling/test_data_preparation.py
index 30f93e6d734885252d2c7a438d6065aa680f32f8..12b619d9e31990f6cc24216ff84ad9d030265e36 100644
--- a/test/test_data_preparation.py
+++ b/test/test_data_handling/test_data_preparation.py
@@ -1,8 +1,7 @@
 import pytest
 import os
-from src.data_preparation import DataPrep
+from src.data_handling.data_preparation import DataPrep
 from src.join import EmptyQueryResult
-import logging
 import numpy as np
 import xarray as xr
 import datetime as dt
@@ -45,7 +44,7 @@ class TestDataPrep:
 
     def test_set_file_name_and_meta(self):
         d = object.__new__(DataPrep)
-        d.path = os.path.abspath('test/data/')
+        d.path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "data")
         d.station = 'TESTSTATION'
         d.variables = ['a', 'bc']
         assert d._set_file_name() == os.path.join(os.path.abspath(os.path.dirname(__file__)),
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
index 7f8c7a051542ff7fc317c0c92454c28f1d0d70b5..5c8223eefebf303733488f08a519208627a2bd91 100644
--- a/test/test_modules/test_model_setup.py
+++ b/test/test_modules/test_model_setup.py
@@ -1,11 +1,10 @@
-import logging
 import pytest
 import os
 import keras
 
 from src.modules.model_setup import ModelSetup
 from src.modules.run_environment import RunEnvironment
-from src.data_generator import DataGenerator
+from src.data_handling.data_generator import DataGenerator
 
 
 class TestModelSetup:
diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py
index 13abe62a2b9199ad8d92528ff5363bd54f1be221..c333322a911732470fc25f413c10f2db14514515 100644
--- a/test/test_modules/test_pre_processing.py
+++ b/test/test_modules/test_pre_processing.py
@@ -1,22 +1,16 @@
 import logging
 import pytest
-import time
 
 from src.helpers import PyTestRegex
 from src.modules.experiment_setup import ExperimentSetup
 from src.modules.pre_processing import PreProcessing, DEFAULT_ARGS_LIST, DEFAULT_KWARGS_LIST
-from src.data_generator import DataGenerator
+from src.data_handling.data_generator import DataGenerator
 from src.datastore import NameNotFoundInScope
 from src.modules.run_environment import RunEnvironment
 
 
 class TestPreProcessing:
 
-    @pytest.fixture
-    def obj_no_init(self):
-        yield object.__new__(PreProcessing)
-        RunEnvironment().__del__()
-
     @pytest.fixture
     def obj_super_init(self):
         obj = object.__new__(PreProcessing)
@@ -42,11 +36,11 @@ class TestPreProcessing:
         ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
                         var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
         caplog.set_level(logging.INFO)
-        PreProcessing()
-        assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
-        assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
-        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). Found '
-                                                                    r'5/5 valid stations.'))
+        with PreProcessing():
+            assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
+            assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
+            assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). '
+                                                                        r'Found 5/5 valid stations.'))
         RunEnvironment().__del__()
 
     def test_run(self, obj_with_exp_setup):
@@ -97,9 +91,9 @@ class TestPreProcessing:
         assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 6 station\(s\). Found '
                                                                     r'5/6 valid stations.'))
 
-    def test_split_set_indices(self, obj_no_init):
+    def test_split_set_indices(self, obj_super_init):
         dummy_list = list(range(0, 15))
-        train, val, test = obj_no_init.split_set_indices(len(dummy_list), 0.9)
+        train, val, test = obj_super_init.split_set_indices(len(dummy_list), 0.9)
         assert dummy_list[train] == list(range(0, 10))
         assert dummy_list[val] == list(range(10, 13))
         assert dummy_list[test] == list(range(13, 15))
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..590bb38018835243163d1efb1e0634fd4d6e8b2e
--- /dev/null
+++ b/test/test_modules/test_training.py
@@ -0,0 +1,23 @@
+import keras
+
+from src.inception_model import InceptionModelBase
+from src.flatten import flatten_tail
+
+
+def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
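+    """Build a small inception model for testing. With add_minor_branch=True the model gets a second output
+    branch, so that both single- and multi-output models (rank 1 and 2) can be tested."""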
+    inception_model = InceptionModelBase()
+    conv_settings_dict1 = {
+        'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
+        'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation}, }
+    pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
+    X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
+    X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
+    if add_minor_branch:
+        out = [flatten_tail(X_in, 'Minor_1', activation=activation)]
+    else:
+        out = []
+    X_in = keras.layers.Dropout(dropout_rate)(X_in)
+    out.append(flatten_tail(X_in, 'Main', activation=activation))
+    return keras.Model(inputs=X_input, outputs=out)