From e86daa579e5a8f0859d91044820b5a4f7fed00f5 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 12 Nov 2019 16:02:38 +0100
Subject: [PATCH] removed threshold param, because it is not related to this
 class

---
 src/data_generator.py       | 11 -----------
 test/test_data_generator.py | 11 -----------
 2 files changed, 22 deletions(-)

diff --git a/src/data_generator.py b/src/data_generator.py
index f36137d9..d067e1e9 100644
--- a/src/data_generator.py
+++ b/src/data_generator.py
@@ -37,7 +37,6 @@ class DataGenerator(keras.utils.Sequence):
         self.window_lead_time = window_lead_time
         self.transform_method = transform_method
         self.kwargs = kwargs
-        self.threshold = self.threshold_setup()
 
     def __repr__(self):
         """
@@ -89,16 +88,6 @@ class DataGenerator(keras.utils.Sequence):
         return data.history.transpose("datetime", "window", "Stations", "variables"), \
             data.label.squeeze("Stations").transpose("datetime", "window")
 
-    def threshold_setup(self) -> List[str]:
-        """
-        set threshold for given min/max and number of steps. defaults are [0, 100] with n=200 steps
-        :return:
-        """
-        thr_min = self.kwargs.get('thr_min', 0)
-        thr_max = self.kwargs.get('thr_max', 100)
-        thr_number_of_steps = self.kwargs.get('thr_number_of_steps', 200)
-        return [str(decimal.Decimal("%.4f" % e)) for e in np.linspace(thr_min, thr_max, thr_number_of_steps)]
-
     def get_data_generator(self, key: Union[str, int] = None) -> DataPrep:
         """
         Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
diff --git a/test/test_data_generator.py b/test/test_data_generator.py
index 4da70783..97122a26 100644
--- a/test/test_data_generator.py
+++ b/test/test_data_generator.py
@@ -30,7 +30,6 @@ class TestDataGenerator:
         assert gen.window_lead_time == 4
         assert gen.transform_method == "standardise"
         assert gen.kwargs == {}
-        assert gen.threshold is not None
 
     def test_repr(self, gen):
         path = os.path.join(os.path.dirname(__file__), 'data')
@@ -63,16 +62,6 @@ class TestDataGenerator:
         assert station[1].data.shape[-1] == gen.window_lead_time
         assert station[0].data.shape[1] == gen.window_history + 1
 
-    def test_threshold_setup(self, gen):
-        def res(arg, val):
-            gen.kwargs[arg] = val
-            return list(map(float, gen.threshold_setup()))
-        compare = np.testing.assert_array_almost_equal
-        assert compare(res('', ''), np.linspace(0, 100, 200), decimal=3) is None
-        assert compare(res('thr_min', 10), np.linspace(10, 100, 200), decimal=3) is None
-        assert compare(res('thr_max', 40), np.linspace(10, 40, 200), decimal=3) is None
-        assert compare(res('thr_number_of_steps', 10), np.linspace(10, 40, 10), decimal=3) is None
-
     def test_get_key_representation(self, gen):
         gen.stations.append("DEBW108")
         f = gen.get_station_key
-- 
GitLab