diff --git a/src/data_generator.py b/src/data_generator.py
index b6b469fe95289c4ca7440800948cabe002b09af3..f36137d9784b38c6526bdf3a0c5748a71307b12e 100644
--- a/src/data_generator.py
+++ b/src/data_generator.py
@@ -103,7 +103,7 @@ class DataGenerator(keras.utils.Sequence):
"""
Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
remove nans.
- :param key:
+ :param key: station key to choose the data generator.
:return: preprocessed data as a DataPrep instance
"""
station = self.get_station_key(key)
@@ -115,11 +115,11 @@ class DataGenerator(keras.utils.Sequence):
data.history_label_nan_remove(self.interpolate_dim)
return data
- def get_station_key(self, key: Union[str, int, List[Union[str, int]]]) -> str:
+ def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str:
"""
Return a valid station key or raise KeyError if this wasn't possible
- :param key:
- :return:
+ :param key: station key to choose the data generator.
+ :return: station key (id from database)
"""
# extract value if given as list
if isinstance(key, list):
diff --git a/test/test_data_generator.py b/test/test_data_generator.py
index 0ab8dd2d078e6c5b194b0973132f1f6255008785..12162c7d6ebdb12262ab987b62eab6e4cf879ccd 100644
--- a/test/test_data_generator.py
+++ b/test/test_data_generator.py
@@ -74,9 +74,9 @@ class TestDataGenerator:
def test_get_key_representation(self, gen):
gen.stations.append("DEBW108")
- f = gen.__iter__.get_station_key
+ f = gen.get_station_key
+ iter(gen)
assert f(None) == "DEBW107"
- assert f([None]) == "DEBW107"
with pytest.raises(KeyError) as e:
f([None, None])
assert "More than one key was given: [None, None]" in e.value.args[0]