diff --git a/src/data_handling/iterator.py b/src/data_handling/iterator.py
index 4cfa459ac4238af3365a509928925299fd36b357..e6831e44f53bb545f177fff23e63a05f78cbeb93 100644
--- a/src/data_handling/iterator.py
+++ b/src/data_handling/iterator.py
@@ -1,4 +1,7 @@
 
+__author__ = 'Lukas Leufen'
+__date__ = '2020-07-07'
+
-from collections import Iterator, Iterable
+from collections.abc import Iterator, Iterable
 import keras
 import numpy as np
@@ -6,17 +9,20 @@ import math
 import os
 import shutil
 import pickle
+from typing import Tuple, List
 
 
 class StandardIterator(Iterator):
 
     _position: int = None
 
-    def __init__(self, collection):
+    def __init__(self, collection: list):
+        assert isinstance(collection, list)
         self._collection = collection
         self._position = 0
 
     def __next__(self):
+        """Return next element or stop iteration."""
         try:
             value = self._collection[self._position]
             self._position += 1
@@ -27,97 +33,107 @@ class StandardIterator(Iterator):
 
 class DataCollection(Iterable):
 
-    def __init__(self, collection):
+    def __init__(self, collection: list):
+        assert isinstance(collection, list)
         self._collection = collection
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator:
         return StandardIterator(self._collection)
 
 
 class KerasIterator(keras.utils.Sequence):
 
-    def __init__(self, collection, batch_size, path, shuffle=False):
-        self._collection: DataCollection = collection
+    def __init__(self, collection: DataCollection, batch_size: int, path: str, shuffle: bool = False):
+        self._collection = collection
         self._path = os.path.join(path, "%i.pickle")
         self.batch_size = batch_size
         self.shuffle = shuffle
-        self.indexes = []
+        self.indexes: list = []
         self._cleanup_path(path)
         self._prepare_batches()
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.indexes)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
+        """Get batch for given index."""
         return self.__data_generation(self.indexes[index])
 
-    def __data_generation(self, index):
+    def __data_generation(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
+        """Load pickle data from disk."""
         file = self._path % index
         with open(file, "rb") as f:
             data = pickle.load(f)
         return data["X"], data["Y"]
 
     @staticmethod
-    def _concatenate(new, old):
+    def _concatenate(new: List[np.ndarray], old: List[np.ndarray]) -> List[np.ndarray]:
+        """Concatenate two lists of data along axis=0."""
         return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new))
 
-    def _get_batch(self, data_list, b):
+    def _get_batch(self, data_list: List[np.ndarray], b: int) -> List[np.ndarray]:
+        """Get batch according to batch size from data list."""
         return list(map(lambda data: data[b * self.batch_size:(b+1) * self.batch_size, ...], data_list))
 
-    def _prepare_batches(self):
+    def _prepare_batches(self) -> None:
+        """
+        Prepare all batches as locally stored files.
+
+        Walk through all elements of collection and split (or merge) data according to the batch size. Too long data
+        sets are divided into multiple batches. Not fully filled batches are merged with data from the next collection
+        element. If data is remaining after the last element, it is saved as smaller batch. All batches are enumerated
+        beginning from 0. A list with all batch numbers is stored in class's parameter indexes.
+        """
         index = 0
         remaining = None
         for data in self._collection:
             X, Y = data.get_X(), data.get_Y()
             if remaining is not None:
-                # X = np.concatenate((remaining[0], X), axis=0)
-                # Y = np.concatenate((remaining[1], Y), axis=0)
-                X = self._concatenate(X, remaining[0])
-                Y = self._concatenate(Y, remaining[1])
-            # check shape
+                X, Y = self._concatenate(X, remaining[0]), self._concatenate(Y, remaining[1])
             length = X[0].shape[0]
             batches = self._get_number_of_mini_batches(length)
             for b in range(batches):
-                # batch_X = X[b * self.batch_size:(b+1) * self.batch_size, ...]
-                # batch_Y = Y[b * self.batch_size:(b+1) * self.batch_size, ...]
-                batch_X = self._get_batch(X, b)
-                batch_Y = self._get_batch(Y, b)
+                batch_X, batch_Y = self._get_batch(X, b), self._get_batch(Y, b)
                 self._save_to_pickle(X=batch_X, Y=batch_Y, index=index)
                 index += 1
-            if (batches * self.batch_size) < length:
+            if (batches * self.batch_size) < length:  # keep remaining to concatenate with next data element
                 remaining = (self._get_batch(X, batches), self._get_batch(Y, batches))
             else:
                 remaining = None
-        if remaining is not None:
+        if remaining is not None:  # add remaining as smaller batch
             self._save_to_pickle(X=remaining[0], Y=remaining[1], index=index)
             index += 1
         self.indexes = np.arange(0, index).tolist()
 
-    def _save_to_pickle(self, X, Y, index):
+    def _save_to_pickle(self, X: List[np.ndarray], Y: List[np.ndarray], index: int) -> None:
+        """Save data as pickle file with variables X and Y and given index as <index>.pickle ."""
         data = {"X": X, "Y": Y}
         file = self._path % index
         with open(file, "wb") as f:
             pickle.dump(data, f)
 
-    def _get_number_of_mini_batches(self, number_of_samples):
+    def _get_number_of_mini_batches(self, number_of_samples: int) -> int:
+        """Return number of mini batches as the floored ration of number of samples to batch size."""
         return math.floor(number_of_samples / self.batch_size)
 
     @staticmethod
-    def _cleanup_path(path, create_new=True):
+    def _cleanup_path(path: str, create_new: bool = True) -> None:
+        """First remove existing path, second create empty path if enabled."""
         if os.path.exists(path):
             shutil.rmtree(path)
         if create_new is True:
             os.makedirs(path)
 
-    def on_epoch_end(self):
+    def on_epoch_end(self) -> None:
+        """Randomly shuffle indexes if enabled."""
         if self.shuffle is True:
             np.random.shuffle(self.indexes)
 
 
 class DummyData:
 
-    def __init__(self):
-        self.number_of_samples = np.random.randint(100, 150)
+    def __init__(self, number_of_samples=None):
+        # a default of np.random.randint(...) would be evaluated only once at function
+        # definition time; draw the random sample count per instance instead
+        self.number_of_samples = number_of_samples if number_of_samples is not None else np.random.randint(100, 150)
 
     def get_X(self):
         X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5)) # samples, window, variables
@@ -135,12 +151,12 @@ if __name__ == "__main__":
 
     collection = []
     for _ in range(3):
-        collection.append(DummyData())
+        collection.append(DummyData(50))
 
     data_collection = DataCollection(collection=collection)
 
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
-    iterator = KerasIterator(data_collection, 1000, path, shuffle=True)
+    iterator = KerasIterator(data_collection, 25, path, shuffle=True)
 
     for data in data_collection:
         print(data)
\ No newline at end of file
diff --git a/test/test_data_handling/test_iterator.py b/test/test_data_handling/test_iterator.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f1cf683d627495cf958b6c2376a5c42a4c6e61f
--- /dev/null
+++ b/test/test_data_handling/test_iterator.py
@@ -0,0 +1,201 @@
+
+from src.data_handling.iterator import DataCollection, StandardIterator, KerasIterator
+from src.helpers.testing import PyTestAllEqual
+
+import numpy as np
+import pytest
+import os
+import shutil
+
+
+class TestStandardIterator:
+
+    @pytest.fixture
+    def collection(self):
+        return list(range(10))
+
+    def test_blank(self):
+        std_iterator = object.__new__(StandardIterator)
+        assert std_iterator._position is None
+
+    def test_init(self, collection):
+        std_iterator = StandardIterator(collection)
+        assert std_iterator._collection == list(range(10))
+        assert std_iterator._position == 0
+
+    def test_next(self, collection):
+        std_iterator = StandardIterator(collection)
+        for i in range(10):
+            assert i == next(std_iterator)
+        with pytest.raises(StopIteration):
+            next(std_iterator)
+        std_iterator = StandardIterator(collection)
+        for e, i in enumerate(iter(std_iterator)):
+            assert i == e
+
+
+class TestDataCollection:
+
+    @pytest.fixture
+    def collection(self):
+        return list(range(10))
+
+    def test_init(self, collection):
+        data_collection = DataCollection(collection)
+        assert data_collection._collection == collection
+
+    def test_iter(self, collection):
+        data_collection = DataCollection(collection)
+        assert isinstance(iter(data_collection), StandardIterator)
+        for e, i in enumerate(data_collection):
+            assert i == e
+
+
+class DummyData:
+
+    def __init__(self, number_of_samples=None):
+        # draw the random default per instance; a default argument is evaluated only once
+        self.number_of_samples = number_of_samples if number_of_samples is not None else np.random.randint(100, 150)
+
+    def get_X(self):
+        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))  # samples, window, variables
+        X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2))  # samples, window, variables
+        X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2))  # samples, window, variables
+        return [X1, X2, X3]
+
+    def get_Y(self):
+        Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1))  # samples, window, variables
+        Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1))  # samples, window, variables
+        return [Y1, Y2]
+
+
+class TestKerasIterator:
+
+    @pytest.fixture
+    def collection(self):
+        coll = []
+        for i in range(3):
+            coll.append(DummyData(50 + i))
+        data_coll = DataCollection(collection=coll)
+        return data_coll
+
+    @pytest.fixture
+    def path(self):
+        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+        shutil.rmtree(p, ignore_errors=True)  # ignore_errors already covers a missing path
+        yield p
+        shutil.rmtree(p, ignore_errors=True)
+
+    def test_init(self, collection, path):
+        iterator = KerasIterator(collection, 25, path)
+        assert isinstance(iterator._collection, DataCollection)
+        assert iterator._path == os.path.join(path, "%i.pickle")
+        assert iterator.batch_size == 25
+        assert iterator.shuffle is False
+
+    def test_cleanup_path(self, path):
+        assert os.path.exists(path) is False
+        iterator = object.__new__(KerasIterator)
+        iterator._cleanup_path(path, create_new=False)
+        assert os.path.exists(path) is False
+        iterator._cleanup_path(path)
+        assert os.path.exists(path) is True
+        iterator._cleanup_path(path, create_new=False)
+        assert os.path.exists(path) is False
+
+    def test_get_number_of_mini_batches(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.batch_size = 36
+        assert iterator._get_number_of_mini_batches(30) == 0
+        assert iterator._get_number_of_mini_batches(40) == 1
+        assert iterator._get_number_of_mini_batches(72) == 2
+
+    def test_len(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.indexes = [0, 1, 2, 3, 4, 5]
+        assert len(iterator) == 6
+
+    def test_concatenate(self):
+        arr1 = DummyData(10).get_X()
+        arr2 = DummyData(50).get_X()
+        iterator = object.__new__(KerasIterator)
+        new_arr = iterator._concatenate(arr2, arr1)
+        test_arr = [np.concatenate((arr1[0], arr2[0]), axis=0),
+                    np.concatenate((arr1[1], arr2[1]), axis=0),
+                    np.concatenate((arr1[2], arr2[2]), axis=0)]
+        for i in range(3):
+            assert PyTestAllEqual([new_arr[i], test_arr[i]])
+
+    def test_get_batch(self):
+        arr = DummyData(20).get_X()
+        iterator = object.__new__(KerasIterator)
+        iterator.batch_size = 19
+        batch1 = iterator._get_batch(arr, 0)
+        assert batch1[0].shape[0] == 19
+        batch2 = iterator._get_batch(arr, 1)
+        assert batch2[0].shape[0] == 1
+
+    def test_save_to_pickle(self, path):
+        os.makedirs(path)
+        d = DummyData(20)
+        X, Y = d.get_X(), d.get_Y()
+        iterator = object.__new__(KerasIterator)
+        iterator._path = os.path.join(path, "%i.pickle")
+        assert os.path.exists(iterator._path % 2) is False
+        iterator._save_to_pickle(X=X, Y=Y, index=2)
+        assert os.path.exists(iterator._path % 2) is True
+
+    def test_prepare_batches(self, collection, path):
+        iterator = object.__new__(KerasIterator)
+        iterator._collection = collection
+        iterator.batch_size = 50
+        iterator.indexes = []
+        iterator._path = os.path.join(path, "%i.pickle")
+        os.makedirs(path)
+        iterator._prepare_batches()
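+        # element sizes are 50, 51 and 52: element 1 fills batch 0 exactly, element 2
+        # fills batch 1 with 1 sample remaining, the remainder plus element 3 fill
+        # batch 2 with 3 samples left over, which are saved as the smaller batch 3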
+        assert len(os.listdir(path)) == 4
+        assert len(iterator.indexes) == 4
+        assert len(iterator) == 4
+        assert iterator.indexes == [0, 1, 2, 3]
+
+    def test_prepare_batches_no_remaining(self, path):
+        iterator = object.__new__(KerasIterator)
+        iterator._collection = DataCollection([DummyData(50)])
+        iterator.batch_size = 50
+        iterator.indexes = []
+        iterator._path = os.path.join(path, "%i.pickle")
+        os.makedirs(path)
+        iterator._prepare_batches()
+        assert len(os.listdir(path)) == 1
+        assert len(iterator.indexes) == 1
+        assert len(iterator) == 1
+        assert iterator.indexes == [0]
+
+    def test_data_generation(self, collection, path):
+        iterator = KerasIterator(collection, 50, path)
+        X, Y = iterator._KerasIterator__data_generation(0)
+        expected = next(iter(collection))
+        assert PyTestAllEqual([X, expected.get_X()])
+        assert PyTestAllEqual([Y, expected.get_Y()])
+
+    def test_getitem(self, collection, path):
+        iterator = KerasIterator(collection, 50, path)
+        X, Y = iterator[0]
+        expected = next(iter(collection))
+        assert PyTestAllEqual([X, expected.get_X()])
+        assert PyTestAllEqual([Y, expected.get_Y()])
+        iterator.indexes.reverse()  # reverse in place; a bare reversed() call would have no effect
+        X, Y = iterator[3]
+        assert PyTestAllEqual([X, expected.get_X()])
+        assert PyTestAllEqual([Y, expected.get_Y()])
+
+    def test_on_epoch_end(self):
+        iterator = object.__new__(KerasIterator)
+        iterator.indexes = [0, 1, 2, 3, 4]
+        iterator.shuffle = False
+        iterator.on_epoch_end()
+        assert iterator.indexes == [0, 1, 2, 3, 4]
+        iterator.shuffle = True
+        while iterator.indexes == sorted(iterator.indexes):
+            iterator.on_epoch_end()
+        assert iterator.indexes != [0, 1, 2, 3, 4]
+        assert sorted(iterator.indexes) == [0, 1, 2, 3, 4]
\ No newline at end of file