Skip to content
Snippets Groups Projects
Commit ab80fd42 authored by lukas leufen
Browse files

added docs for some generator methods

parent 202f2311
No related branches found
No related tags found
2 merge requests: !9 "new version v0.2.0", !8 "data generator"
......@@ -5,9 +5,10 @@ import keras
from src import helpers
from src.data_preparation import DataPrep
import os
from typing import Union, List
from typing import Union, List, Tuple
import decimal
import numpy as np
import xarray as xr
class DataGenerator(keras.utils.Sequence):
......@@ -52,14 +53,24 @@ class DataGenerator(keras.utils.Sequence):
"""
return len(self.stations)
def __iter__(self):
self.iterator = 0
def __iter__(self) -> "DataGenerator":
"""
Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute
`_iterator` to 0.
:return:
"""
self._iterator = 0
return self
def __next__(self):
if self.iterator < self.__len__():
def __next__(self) -> Tuple[xr.DataArray, xr.DataArray]:
"""
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
the history and label data of this generator.
:return:
"""
if self._iterator < self.__len__():
data = self.get_data_generator()
self.iterator += 1
self._iterator += 1
if data.history is not None and data.label is not None:
return data.history.transpose("datetime", "window", "Stations", "variables"), \
data.label.squeeze("Stations").transpose("datetime", "window")
......@@ -68,7 +79,12 @@ class DataGenerator(keras.utils.Sequence):
else:
raise StopIteration
def __getitem__(self, item: Union[str, int]):
def __getitem__(self, item: Union[str, int]) -> Tuple[xr.DataArray, xr.DataArray]:
"""
Defines the get item method for this generator. Retrieve data from generator and return history and labels.
:param item: station key to choose the data generator.
:return: The generator's time series of history data and its labels
"""
data = self.get_data_generator(key=item)
return data.history.transpose("datetime", "window", "Stations", "variables"), \
data.label.squeeze("Stations").transpose("datetime", "window")
......@@ -113,7 +129,7 @@ class DataGenerator(keras.utils.Sequence):
raise KeyError(f"More than one key was given: {key}")
# return station name either from key or the recent element from iterator
if key is None:
return self.stations[self.iterator]
return self.stations[self._iterator]
else:
if isinstance(key, int):
if key < self.__len__():
......
......@@ -46,12 +46,12 @@ class TestDataGenerator:
assert hasattr(gen, 'iterator') is False
iter(gen)
assert hasattr(gen, 'iterator')
assert gen.iterator == 0
assert gen._iterator == 0
def test_next(self, gen):
gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
for i, d in enumerate(gen, start=1):
assert i == gen.iterator
assert i == gen._iterator
def test_getitem(self, gen):
gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
......@@ -74,7 +74,7 @@ class TestDataGenerator:
def test_get_key_representation(self, gen):
gen.stations.append("DEBW108")
f = gen.__iter__().get_station_key
f = gen.__iter__.get_station_key
assert f(None) == "DEBW107"
assert f([None]) == "DEBW107"
with pytest.raises(KeyError) as e:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment