From 6bd57ea0fb834d3f30a8b23fe2753ae0f0447b0a Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 1 Apr 2020 16:42:47 +0200
Subject: [PATCH] little bugfix: replaced np.repeat by np.tile for correct
 array repetition, first tests for the BootStraps class

---
 src/data_handling/bootstraps.py            |   6 +-
 test/test_data_handling/test_bootstraps.py | 160 +++++++++++----------
 2 files changed, 90 insertions(+), 76 deletions(-)

diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py
index 3e86ed96..3fdabe4c 100644
--- a/src/data_handling/bootstraps.py
+++ b/src/data_handling/bootstraps.py
@@ -165,12 +165,12 @@ class BootStraps:
 
     def get_labels(self, key: Union[str, int]):
         """
-        Reepats labels for given key by the number of boots and yield it one by one.
+        Repeats labels for given key by the number of boots and returns them as a single tiled array.
         :param key: key of station (either station name as string or the position in generator as integer)
         :return: yields labels for length of boots
         """
         labels = self.data[key][1]
-        return labels.data.repeat(self.number_of_bootstraps, axis=0)
+        return np.tile(labels.data, (self.number_of_bootstraps, 1))
 
     def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN"):
         """
@@ -182,7 +182,7 @@ class BootStraps:
         """
         file = os.path.join(path, file_name)
         prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
-        vals = prediction.data.repeat(self.number_of_bootstraps, axis=0)
+        vals = np.tile(prediction.data, (self.number_of_bootstraps, 1))
         return vals[~np.isnan(vals).any(axis=1), :]
 
     def _load_shuffled_data(self, station: str, variables: List[str]) -> xr.DataArray:
diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py
index 46f563bb..45d8c86f 100644
--- a/test/test_data_handling/test_bootstraps.py
+++ b/test/test_data_handling/test_bootstraps.py
@@ -17,7 +17,7 @@ import xarray as xr
 ### old ###
 
 
-class TestBootStraps:
+class TestBootStrapsOld:
 
     @pytest.fixture
     def boot_gen(self, orig_generator):
@@ -29,55 +29,6 @@ class TestBootStraps:
         dummy_content.to_netcdf(os.path.join(path, "forecasts_norm_DEBW107_test.nc"))
         return BootStraps(orig_generator, path, 20)
 
-    @pytest.fixture
-    def boot_gen_real(self, orig_generator):
-        path = os.path.join(os.path.dirname(__file__), 'data')
-        for hist, _ in orig_generator:
-            hist = hist.expand_dims({"boots": [1]})
-            station = orig_generator.get_station_key(orig_generator._iterator-1)
-            hist.to_netcdf(os.path.join(path, f"{station}_o3_temp_hist7_nboots20_shuffled.nc"))
-        return BootStrapGenerator(orig_generator, 20, path)
-
-    def test_init(self, orig_generator):
-        gen = BootStraps(orig_generator, os.path.join(os.path.dirname(__file__), 'data'), 20)
-        # assert gen.stations == ["DEBW107", "DEBW013"]
-        # assert gen.variables == ["o3", "temp"]
-        assert isinstance(gen.data, DataGenerator)
-        assert gen.number_of_bootstraps == 20
-        assert gen.bootstrap_path == os.path.join(os.path.dirname(__file__), 'data')
-
-    def test_get_generator(self, boot_gen_real, orig_generator):
-        res = boot_gen_real.get_generator("DEBW107", "o3")
-        hist = orig_generator.get_data_generator("DEBW107").get_transposed_history()
-        assert xr.testing.assert_equal(res[0], hist) is None
-        label = orig_generator.get_data_generator("DEBW107").get_transposed_label()
-        assert xr.testing.assert_equal(res[1], label) is None
-        assert isinstance(res[2], typing.Callable)
-        assert res[3] == 20
-
-    def test_get_generator_station_var_wise(self, boot_gen, orig_generator):
-        res = boot_gen.get_generator("DEBW107", "o3")
-        hist = orig_generator.get_data_generator("DEBW107").get_transposed_history()
-        assert xr.testing.assert_equal(res[0], hist) is None
-        label = orig_generator.get_data_generator("DEBW107").get_transposed_label()
-        assert xr.testing.assert_equal(res[1], label) is None
-        assert isinstance(res[2], typing.Callable)
-        assert res[3] == 20
-
-    def test_get_bootstrap_station_var_wise_meta(self, boot_gen):
-        meta = boot_gen.get_bootstrap_meta_station_var_wise("DEBW107", "o3")
-        labels = boot_gen.orig_generator.get_data_generator("DEBW107").get_transposed_label().shape[0]
-        assert np.shape(meta) == (labels * boot_gen.number_of_boots, 2)
-        assert np.testing.assert_array_equal(np.unique(meta), ["DEBW107", "o3"]) is None
-
-    def test_get_labels(self, boot_gen):
-        res = []
-        for label in boot_gen.get_labels("DEBW107"):
-            res.append(label)
-        assert len(res) == boot_gen.number_of_boots
-        assert xr.testing.assert_equal(res[0], res[-1]) is None
-        assert PyTestAllEqual(res).is_true()
-
     def test_get_orig_prediction(self, boot_gen):
         path = boot_gen.orig_generator.data_path
         res = []
@@ -91,29 +42,6 @@ class TestBootStraps:
         assert isinstance(shuffled_data, xr.DataArray)
         assert all(shuffled_data.compute().values == [1, 2, 3])
 
-    def test_get_shuffled_data_file(self, boot_gen):
-        file_name = boot_gen._get_shuffled_data_file("DEBW107", ["o3"])
-        assert file_name == os.path.join(boot_gen.bootstrap_path, "DEBW107_o3_temp_hist7_nboots20_shuffled.nc")
-
-    def test_get_shuffled_data_file_not_found(self, boot_gen):
-        boot_gen.number_of_boots = 100
-        with pytest.raises(FileNotFoundError) as e:
-            boot_gen._get_shuffled_data_file("DEBW107", ["o3"])
-        assert "Could not find a file to match pattern" in e.value.args[0]
-
-    def test_create_file_regex(self, boot_gen):
-        regex = boot_gen._create_file_regex("DEBW108", ["o3", "temp", "h2o"])
-        test_list = ["DEBW108_o3_test23_test_shuffled.nc",
-                     "DEBW107_o3_test23_test_shuffled.nc",
-                     "DEBW108_o3_test23_test.nc",
-                     "DEBW108_h2o_o3_temp_test_shuffled.nc",
-                     "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc",
-                     "DEBW108_o3_temp_hist9_nboots20_shuffled.nc",
-                     "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"]
-        assert boot_gen._filter_files(regex, test_list, 10, 10) is None
-        assert boot_gen._filter_files(regex, test_list, 9, 10) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"
-        assert boot_gen._filter_files(regex, test_list, 9, 20) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"
-
 
 ### new ###
 
@@ -271,3 +199,89 @@ class TestCreateShuffledData:
         assert dummy.max() >= res.max()
         assert dummy.min() <= res.min()
         assert set(np.unique(res)).issubset({1, 2, 3})
+
+
+class TestBootStraps:
+
+    @pytest.fixture
+    def bootstrap(self, orig_generator, data_path):
+        return BootStraps(orig_generator, data_path, 20)
+
+    @pytest.fixture
+    @mock.patch("src.data_handling.bootstraps.CreateShuffledData", return_value=None)
+    def bootstrap_no_shuffling(self, mock_create_shuffle_data, orig_generator, data_path):
+        shutil.rmtree(data_path)
+        return BootStraps(orig_generator, data_path, 20)
+
+    def test_init_no_shuffling(self, bootstrap_no_shuffling, data_path):
+        assert isinstance(bootstrap_no_shuffling, BootStraps)
+        assert bootstrap_no_shuffling.number_of_bootstraps == 20
+        assert bootstrap_no_shuffling.bootstrap_path == data_path
+
+    def test_init_with_shuffling(self, orig_generator, data_path, caplog):
+        caplog.set_level(logging.INFO)
+        BootStraps(orig_generator, data_path, 20)
+        assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data")
+
+    def test_stations(self, bootstrap_no_shuffling, orig_generator):
+        assert bootstrap_no_shuffling.stations == orig_generator.stations
+
+    def test_variables(self, bootstrap_no_shuffling, orig_generator):
+        assert bootstrap_no_shuffling.variables == orig_generator.variables
+
+    def test_window_history_size(self, bootstrap_no_shuffling, orig_generator):
+        assert bootstrap_no_shuffling.window_history_size == orig_generator.window_history_size
+
+    def test_get_generator(self, bootstrap, orig_generator):
+        station = bootstrap.stations[0]
+        var = bootstrap.variables[0]
+        var_others = bootstrap.variables[1:]
+        gen = bootstrap.get_generator(station, var)
+        assert isinstance(gen, BootStrapGenerator)
+        assert gen.number_of_boots == bootstrap.number_of_bootstraps
+        assert gen.variables == bootstrap.variables
+        expected = orig_generator.get_data_generator(station).get_transposed_history()
+        assert xr.testing.assert_equal(gen.history_orig, expected) is None
+        assert xr.testing.assert_equal(gen.history, expected.sel(variables=var_others)) is None
+        assert gen.shuffled.variables == "o3"
+
+    def test_get_labels(self, bootstrap, orig_generator):
+        station = bootstrap.stations[0]
+        labels = bootstrap.get_labels(station)
+        labels_orig = orig_generator.get_data_generator(station).get_transposed_label()
+        assert labels.shape == (labels_orig.shape[0] * bootstrap.number_of_bootstraps, *labels_orig.shape[1:])
+        assert np.testing.assert_array_equal(labels[:labels_orig.shape[0], :], labels_orig.values) is None
+
+    def test_get_orig_prediction(self):
+        pass
+
+    def test_load_shuffled_data(self):
+        pass
+
+    def test_get_shuffled_data_file(self, bootstrap):
+        file_name = bootstrap._get_shuffled_data_file("DEBW107", ["o3"])
+        assert file_name == os.path.join(bootstrap.bootstrap_path, "DEBW107_o3_temp_hist7_nboots20_shuffled.nc")
+
+    def test_get_shuffled_data_file_not_found(self, bootstrap_no_shuffling, data_path):
+        bootstrap_no_shuffling.number_of_boots = 100
+        os.makedirs(data_path)
+        with pytest.raises(FileNotFoundError) as e:
+            bootstrap_no_shuffling._get_shuffled_data_file("DEBW107", ["o3"])
+        assert "Could not find a file to match pattern" in e.value.args[0]
+
+    def test_create_file_regex(self, bootstrap_no_shuffling):
+        regex = bootstrap_no_shuffling._create_file_regex("DEBW108", ["o3", "temp", "h2o"])
+        test_list = ["DEBW108_o3_test23_test_shuffled.nc",
+                     "DEBW107_o3_test23_test_shuffled.nc",
+                     "DEBW108_o3_test23_test.nc",
+                     "DEBW108_h2o_o3_temp_test_shuffled.nc",
+                     "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc",
+                     "DEBW108_o3_temp_hist9_nboots20_shuffled.nc",
+                     "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"]
+        f = bootstrap_no_shuffling._filter_files
+        assert f(regex, test_list, 10, 10) is None
+        assert f(regex, test_list, 9, 10) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"
+        assert f(regex, test_list, 9, 20) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"
+
+    def test_filter_files(self):
+        pass
-- 
GitLab