diff --git a/src/data_generator.py b/src/data_generator.py
index a43d4bf9772ba7d311b4502f53963e72d5ff98e4..b6b469fe95289c4ca7440800948cabe002b09af3 100644
--- a/src/data_generator.py
+++ b/src/data_generator.py
@@ -5,9 +5,10 @@ import keras
 from src import helpers
 from src.data_preparation import DataPrep
 import os
-from typing import Union, List
+from typing import Union, List, Tuple
 import decimal
 import numpy as np
+import xarray as xr
 
 
 class DataGenerator(keras.utils.Sequence):
@@ -52,14 +53,24 @@ class DataGenerator(keras.utils.Sequence):
         """
         return len(self.stations)
 
-    def __iter__(self):
-        self.iterator = 0
+    def __iter__(self) -> "DataGenerator":
+        """
+        Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute
+        `_iterator` to 0.
+        :return:
+        """
+        self._iterator = 0
         return self
 
-    def __next__(self):
-        if self.iterator < self.__len__():
+    def __next__(self) -> Tuple[xr.DataArray, xr.DataArray]:
+        """
+        This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
+        the history and label data of this generator.
+        :return:
+        """
+        if self._iterator < self.__len__():
             data = self.get_data_generator()
-            self.iterator += 1
+            self._iterator += 1
             if data.history is not None and data.label is not None:
                 return data.history.transpose("datetime", "window", "Stations", "variables"), \
                     data.label.squeeze("Stations").transpose("datetime", "window")
@@ -68,7 +79,12 @@ class DataGenerator(keras.utils.Sequence):
         else:
             raise StopIteration
 
-    def __getitem__(self, item: Union[str, int]):
+    def __getitem__(self, item: Union[str, int]) -> Tuple[xr.DataArray, xr.DataArray]:
+        """
+        Defines the get item method for this generator. Retrieve data from generator and return history and labels.
+        :param item: station key to choose the data generator.
+        :return: The generator's time series of history data and its labels
+        """
         data = self.get_data_generator(key=item)
         return data.history.transpose("datetime", "window", "Stations", "variables"), \
             data.label.squeeze("Stations").transpose("datetime", "window")
@@ -113,7 +129,7 @@ class DataGenerator(keras.utils.Sequence):
                 raise KeyError(f"More than one key was given: {key}")
         # return station name either from key or the recent element from iterator
         if key is None:
-            return self.stations[self.iterator]
+            return self.stations[self._iterator]
         else:
             if isinstance(key, int):
                 if key < self.__len__():
diff --git a/test/test_data_generator.py b/test/test_data_generator.py
index b316fa887ec925b25e3a362100218e1f7ddbfe89..0ab8dd2d078e6c5b194b0973132f1f6255008785 100644
--- a/test/test_data_generator.py
+++ b/test/test_data_generator.py
@@ -46,12 +46,12 @@ class TestDataGenerator:
         assert hasattr(gen, 'iterator') is False
         iter(gen)
         assert hasattr(gen, 'iterator')
-        assert gen.iterator == 0
+        assert gen._iterator == 0
 
     def test_next(self, gen):
         gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
         for i, d in enumerate(gen, start=1):
-            assert i == gen.iterator
+            assert i == gen._iterator
 
     def test_getitem(self, gen):
         gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
@@ -74,7 +74,7 @@ class TestDataGenerator:
 
     def test_get_key_representation(self, gen):
         gen.stations.append("DEBW108")
-        f = gen.__iter__().get_station_key
+        f = gen.__iter__.get_station_key
         assert f(None) == "DEBW107"
         assert f([None]) == "DEBW107"
         with pytest.raises(KeyError) as e: