diff --git a/src/data_generator.py b/src/data_generator.py
index a43d4bf9772ba7d311b4502f53963e72d5ff98e4..b6b469fe95289c4ca7440800948cabe002b09af3 100644
--- a/src/data_generator.py
+++ b/src/data_generator.py
@@ -5,9 +5,10 @@ import keras
 from src import helpers
 from src.data_preparation import DataPrep
 import os
-from typing import Union, List
+from typing import Union, List, Tuple
 import decimal
 import numpy as np
+import xarray as xr
 
 
 class DataGenerator(keras.utils.Sequence):
@@ -52,14 +53,24 @@ class DataGenerator(keras.utils.Sequence):
         """
         return len(self.stations)
 
-    def __iter__(self):
-        self.iterator = 0
+    def __iter__(self) -> "DataGenerator":
+        """
+        Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute
+        `_iterator` to 0.
+        :return: this generator itself, which acts as its own iterator
+        """
+        self._iterator = 0
         return self
 
-    def __next__(self):
-        if self.iterator < self.__len__():
+    def __next__(self) -> Tuple[xr.DataArray, xr.DataArray]:
+        """
+        This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
+        the history and label data of this generator.
+        :return: history and label data of the station at the current iterator position
+        """
+        if self._iterator < self.__len__():
             data = self.get_data_generator()
-            self.iterator += 1
+            self._iterator += 1
             if data.history is not None and data.label is not None:
                 return data.history.transpose("datetime", "window", "Stations", "variables"), \
                        data.label.squeeze("Stations").transpose("datetime", "window")
@@ -68,7 +79,12 @@ class DataGenerator(keras.utils.Sequence):
         else:
             raise StopIteration
 
-    def __getitem__(self, item: Union[str, int]):
+    def __getitem__(self, item: Union[str, int]) -> Tuple[xr.DataArray, xr.DataArray]:
+        """
+        Defines the get item method for this generator. Retrieve data from generator and return history and labels.
+        :param item: station key to choose the data generator.
+        :return: The generator's time series of history data and its labels
+        """
         data = self.get_data_generator(key=item)
         return data.history.transpose("datetime", "window", "Stations", "variables"), \
                data.label.squeeze("Stations").transpose("datetime", "window")
@@ -113,7 +129,7 @@ class DataGenerator(keras.utils.Sequence):
                 raise KeyError(f"More than one key was given: {key}")
         # return station name either from key or the recent element from iterator
         if key is None:
-            return self.stations[self.iterator]
+            return self.stations[self._iterator]
         else:
             if isinstance(key, int):
                 if key < self.__len__():
diff --git a/test/test_data_generator.py b/test/test_data_generator.py
index b316fa887ec925b25e3a362100218e1f7ddbfe89..0ab8dd2d078e6c5b194b0973132f1f6255008785 100644
--- a/test/test_data_generator.py
+++ b/test/test_data_generator.py
@@ -46,12 +46,12 @@ class TestDataGenerator:
-        assert hasattr(gen, 'iterator') is False
+        assert hasattr(gen, '_iterator') is False
         iter(gen)
-        assert hasattr(gen, 'iterator')
-        assert gen.iterator == 0
+        assert hasattr(gen, '_iterator')
+        assert gen._iterator == 0
 
     def test_next(self, gen):
         gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
         for i, d in enumerate(gen, start=1):
-            assert i == gen.iterator
+            assert i == gen._iterator
 
     def test_getitem(self, gen):
         gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
@@ -74,7 +74,7 @@ class TestDataGenerator:
 
     def test_get_key_representation(self, gen):
         gen.stations.append("DEBW108")
-        f = gen.__iter__().get_station_key
+        f = iter(gen).get_station_key
         assert f(None) == "DEBW107"
         assert f([None]) == "DEBW107"
         with pytest.raises(KeyError) as e: