From ab80fd422e7c9127106130030dc496ff6faa6b65 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 11 Nov 2019 09:16:38 +0100
Subject: [PATCH] added docs for some generator methods

---
 src/data_generator.py       | 32 ++++++++++++++++++++++++--------
 test/test_data_generator.py |  6 +++---
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/src/data_generator.py b/src/data_generator.py
index a43d4bf9..b6b469fe 100644
--- a/src/data_generator.py
+++ b/src/data_generator.py
@@ -5,9 +5,10 @@ import keras
 from src import helpers
 from src.data_preparation import DataPrep
 import os
-from typing import Union, List
+from typing import Union, List, Tuple
 import decimal
 import numpy as np
+import xarray as xr
 
 
 class DataGenerator(keras.utils.Sequence):
@@ -52,14 +53,24 @@ class DataGenerator(keras.utils.Sequence):
         """
         return len(self.stations)
 
-    def __iter__(self):
-        self.iterator = 0
+    def __iter__(self) -> "DataGenerator":
+        """
+        Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute
+        `_iterator` to 0.
+        :return: this generator itself, which serves as its own iterator
+        """
+        self._iterator = 0
         return self
 
-    def __next__(self):
-        if self.iterator < self.__len__():
+    def __next__(self) -> Tuple[xr.DataArray, xr.DataArray]:
+        """
+        This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
+        the history and label data of this generator.
+        :return: the history and label data of the current data generator
+        """
+        if self._iterator < self.__len__():
             data = self.get_data_generator()
-            self.iterator += 1
+            self._iterator += 1
             if data.history is not None and data.label is not None:
                 return data.history.transpose("datetime", "window", "Stations", "variables"), \
                     data.label.squeeze("Stations").transpose("datetime", "window")
@@ -68,7 +79,12 @@ class DataGenerator(keras.utils.Sequence):
         else:
             raise StopIteration
 
-    def __getitem__(self, item: Union[str, int]):
+    def __getitem__(self, item: Union[str, int]) -> Tuple[xr.DataArray, xr.DataArray]:
+        """
+        Defines the get item method for this generator. Retrieve data from generator and return history and labels.
+        :param item: station key to choose the data generator.
+        :return: The generator's time series of history data and its labels
+        """
         data = self.get_data_generator(key=item)
         return data.history.transpose("datetime", "window", "Stations", "variables"), \
             data.label.squeeze("Stations").transpose("datetime", "window")
@@ -113,7 +129,7 @@ class DataGenerator(keras.utils.Sequence):
                 raise KeyError(f"More than one key was given: {key}")
         # return station name either from key or the recent element from iterator
         if key is None:
-            return self.stations[self.iterator]
+            return self.stations[self._iterator]
         else:
             if isinstance(key, int):
                 if key < self.__len__():
diff --git a/test/test_data_generator.py b/test/test_data_generator.py
index b316fa88..0ab8dd2d 100644
--- a/test/test_data_generator.py
+++ b/test/test_data_generator.py
@@ -46,12 +46,12 @@ class TestDataGenerator:
-        assert hasattr(gen, 'iterator') is False
+        assert hasattr(gen, '_iterator') is False
         iter(gen)
-        assert hasattr(gen, 'iterator')
-        assert gen.iterator == 0
+        assert hasattr(gen, '_iterator')
+        assert gen._iterator == 0
 
     def test_next(self, gen):
         gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
         for i, d in enumerate(gen, start=1):
-            assert i == gen.iterator
+            assert i == gen._iterator
 
     def test_getitem(self, gen):
         gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
@@ -74,7 +74,8 @@ class TestDataGenerator:
 
     def test_get_key_representation(self, gen):
         gen.stations.append("DEBW108")
-        f = gen.__iter__().get_station_key
+        # __iter__() must be called (not just referenced): it sets `_iterator`, which the lookup uses for key=None
+        f = gen.__iter__().get_station_key
         assert f(None) == "DEBW107"
         assert f([None]) == "DEBW107"
         with pytest.raises(KeyError) as e:
-- 
GitLab