Skip to content
Snippets Groups Projects
Commit e86daa57 authored by lukas leufen's avatar lukas leufen
Browse files

removed threshold param, because it is not related to this class

parent 6e754bac
No related branches found
No related tags found
2 merge requests!9new version v0.2.0,!8data generator
Pipeline #25753 passed
...@@ -37,7 +37,6 @@ class DataGenerator(keras.utils.Sequence): ...@@ -37,7 +37,6 @@ class DataGenerator(keras.utils.Sequence):
self.window_lead_time = window_lead_time self.window_lead_time = window_lead_time
self.transform_method = transform_method self.transform_method = transform_method
self.kwargs = kwargs self.kwargs = kwargs
self.threshold = self.threshold_setup()
def __repr__(self): def __repr__(self):
""" """
...@@ -89,16 +88,6 @@ class DataGenerator(keras.utils.Sequence): ...@@ -89,16 +88,6 @@ class DataGenerator(keras.utils.Sequence):
return data.history.transpose("datetime", "window", "Stations", "variables"), \ return data.history.transpose("datetime", "window", "Stations", "variables"), \
data.label.squeeze("Stations").transpose("datetime", "window") data.label.squeeze("Stations").transpose("datetime", "window")
def threshold_setup(self) -> List[str]:
"""
set threshold for given min/max and number of steps. defaults are [0, 100] with n=200 steps
:return:
"""
thr_min = self.kwargs.get('thr_min', 0)
thr_max = self.kwargs.get('thr_max', 100)
thr_number_of_steps = self.kwargs.get('thr_number_of_steps', 200)
return [str(decimal.Decimal("%.4f" % e)) for e in np.linspace(thr_min, thr_max, thr_number_of_steps)]
def get_data_generator(self, key: Union[str, int] = None) -> DataPrep: def get_data_generator(self, key: Union[str, int] = None) -> DataPrep:
""" """
Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
......
...@@ -30,7 +30,6 @@ class TestDataGenerator: ...@@ -30,7 +30,6 @@ class TestDataGenerator:
assert gen.window_lead_time == 4 assert gen.window_lead_time == 4
assert gen.transform_method == "standardise" assert gen.transform_method == "standardise"
assert gen.kwargs == {} assert gen.kwargs == {}
assert gen.threshold is not None
def test_repr(self, gen): def test_repr(self, gen):
path = os.path.join(os.path.dirname(__file__), 'data') path = os.path.join(os.path.dirname(__file__), 'data')
...@@ -63,16 +62,6 @@ class TestDataGenerator: ...@@ -63,16 +62,6 @@ class TestDataGenerator:
assert station[1].data.shape[-1] == gen.window_lead_time assert station[1].data.shape[-1] == gen.window_lead_time
assert station[0].data.shape[1] == gen.window_history + 1 assert station[0].data.shape[1] == gen.window_history + 1
def test_threshold_setup(self, gen):
def res(arg, val):
gen.kwargs[arg] = val
return list(map(float, gen.threshold_setup()))
compare = np.testing.assert_array_almost_equal
assert compare(res('', ''), np.linspace(0, 100, 200), decimal=3) is None
assert compare(res('thr_min', 10), np.linspace(10, 100, 200), decimal=3) is None
assert compare(res('thr_max', 40), np.linspace(10, 40, 200), decimal=3) is None
assert compare(res('thr_number_of_steps', 10), np.linspace(10, 40, 10), decimal=3) is None
def test_get_key_representation(self, gen): def test_get_key_representation(self, gen):
gen.stations.append("DEBW108") gen.stations.append("DEBW108")
f = gen.get_station_key f = gen.get_station_key
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment