Skip to content
Snippets Groups Projects
Commit ab80fd42 authored by lukas leufen's avatar lukas leufen
Browse files

added docs for some generator methods

parent 202f2311
Branches
Tags
2 merge requests!9new version v0.2.0,!8data generator
...@@ -5,9 +5,10 @@ import keras ...@@ -5,9 +5,10 @@ import keras
from src import helpers from src import helpers
from src.data_preparation import DataPrep from src.data_preparation import DataPrep
import os import os
from typing import Union, List from typing import Union, List, Tuple
import decimal import decimal
import numpy as np import numpy as np
import xarray as xr
class DataGenerator(keras.utils.Sequence): class DataGenerator(keras.utils.Sequence):
...@@ -52,14 +53,24 @@ class DataGenerator(keras.utils.Sequence): ...@@ -52,14 +53,24 @@ class DataGenerator(keras.utils.Sequence):
""" """
return len(self.stations) return len(self.stations)
def __iter__(self): def __iter__(self) -> "DataGenerator":
self.iterator = 0 """
Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute
`_iterator` to 0.
:return:
"""
self._iterator = 0
return self return self
def __next__(self): def __next__(self) -> Tuple[xr.DataArray, xr.DataArray]:
if self.iterator < self.__len__(): """
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
the history and label data of this generator.
:return:
"""
if self._iterator < self.__len__():
data = self.get_data_generator() data = self.get_data_generator()
self.iterator += 1 self._iterator += 1
if data.history is not None and data.label is not None: if data.history is not None and data.label is not None:
return data.history.transpose("datetime", "window", "Stations", "variables"), \ return data.history.transpose("datetime", "window", "Stations", "variables"), \
data.label.squeeze("Stations").transpose("datetime", "window") data.label.squeeze("Stations").transpose("datetime", "window")
...@@ -68,7 +79,12 @@ class DataGenerator(keras.utils.Sequence): ...@@ -68,7 +79,12 @@ class DataGenerator(keras.utils.Sequence):
else: else:
raise StopIteration raise StopIteration
def __getitem__(self, item: Union[str, int]): def __getitem__(self, item: Union[str, int]) -> Tuple[xr.DataArray, xr.DataArray]:
"""
Defines the get item method for this generator. Retrieve data from generator and return history and labels.
:param item: station key to choose the data generator.
:return: The generator's time series of history data and its labels
"""
data = self.get_data_generator(key=item) data = self.get_data_generator(key=item)
return data.history.transpose("datetime", "window", "Stations", "variables"), \ return data.history.transpose("datetime", "window", "Stations", "variables"), \
data.label.squeeze("Stations").transpose("datetime", "window") data.label.squeeze("Stations").transpose("datetime", "window")
...@@ -113,7 +129,7 @@ class DataGenerator(keras.utils.Sequence): ...@@ -113,7 +129,7 @@ class DataGenerator(keras.utils.Sequence):
raise KeyError(f"More than one key was given: {key}") raise KeyError(f"More than one key was given: {key}")
# return station name either from key or the recent element from iterator # return station name either from key or the recent element from iterator
if key is None: if key is None:
return self.stations[self.iterator] return self.stations[self._iterator]
else: else:
if isinstance(key, int): if isinstance(key, int):
if key < self.__len__(): if key < self.__len__():
......
...@@ -46,12 +46,12 @@ class TestDataGenerator: ...@@ -46,12 +46,12 @@ class TestDataGenerator:
assert hasattr(gen, 'iterator') is False assert hasattr(gen, 'iterator') is False
iter(gen) iter(gen)
assert hasattr(gen, 'iterator') assert hasattr(gen, 'iterator')
assert gen.iterator == 0 assert gen._iterator == 0
def test_next(self, gen): def test_next(self, gen):
gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}} gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
for i, d in enumerate(gen, start=1): for i, d in enumerate(gen, start=1):
assert i == gen.iterator assert i == gen._iterator
def test_getitem(self, gen): def test_getitem(self, gen):
gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}} gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
...@@ -74,7 +74,7 @@ class TestDataGenerator: ...@@ -74,7 +74,7 @@ class TestDataGenerator:
def test_get_key_representation(self, gen): def test_get_key_representation(self, gen):
gen.stations.append("DEBW108") gen.stations.append("DEBW108")
f = gen.__iter__().get_station_key f = gen.__iter__.get_station_key
assert f(None) == "DEBW107" assert f(None) == "DEBW107"
assert f([None]) == "DEBW107" assert f([None]) == "DEBW107"
with pytest.raises(KeyError) as e: with pytest.raises(KeyError) as e:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment