diff --git a/src/data_generator.py b/src/data_generator.py index f36137d9784b38c6526bdf3a0c5748a71307b12e..d067e1e9e0d3225e869ca6a2944b7c749198834b 100644 --- a/src/data_generator.py +++ b/src/data_generator.py @@ -37,7 +37,6 @@ class DataGenerator(keras.utils.Sequence): self.window_lead_time = window_lead_time self.transform_method = transform_method self.kwargs = kwargs - self.threshold = self.threshold_setup() def __repr__(self): """ @@ -89,16 +88,6 @@ class DataGenerator(keras.utils.Sequence): return data.history.transpose("datetime", "window", "Stations", "variables"), \ data.label.squeeze("Stations").transpose("datetime", "window") - def threshold_setup(self) -> List[str]: - """ - set threshold for given min/max and number of steps. defaults are [0, 100] with n=200 steps - :return: - """ - thr_min = self.kwargs.get('thr_min', 0) - thr_max = self.kwargs.get('thr_max', 100) - thr_number_of_steps = self.kwargs.get('thr_number_of_steps', 200) - return [str(decimal.Decimal("%.4f" % e)) for e in np.linspace(thr_min, thr_max, thr_number_of_steps)] - def get_data_generator(self, key: Union[str, int] = None) -> DataPrep: """ Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and diff --git a/test/test_data_generator.py b/test/test_data_generator.py index 4da70783bf1718c033f50969a78e03391bec6bbf..97122a265b911524a06372f7e5a99f02fcf7dd3f 100644 --- a/test/test_data_generator.py +++ b/test/test_data_generator.py @@ -30,7 +30,6 @@ class TestDataGenerator: assert gen.window_lead_time == 4 assert gen.transform_method == "standardise" assert gen.kwargs == {} - assert gen.threshold is not None def test_repr(self, gen): path = os.path.join(os.path.dirname(__file__), 'data') @@ -63,16 +62,6 @@ class TestDataGenerator: assert station[1].data.shape[-1] == gen.window_lead_time assert station[0].data.shape[1] == gen.window_history + 1 - def test_threshold_setup(self, gen): - def res(arg, val): - gen.kwargs[arg] = val - return list(map(float, gen.threshold_setup())) - compare = np.testing.assert_array_almost_equal - assert compare(res('', ''), np.linspace(0, 100, 200), decimal=3) is None - assert compare(res('thr_min', 10), np.linspace(10, 100, 200), decimal=3) is None - assert compare(res('thr_max', 40), np.linspace(10, 40, 200), decimal=3) is None - assert compare(res('thr_number_of_steps', 10), np.linspace(10, 40, 10), decimal=3) is None - def test_get_key_representation(self, gen): gen.stations.append("DEBW108") f = gen.get_station_key