From 9d89303ebf3714fbe39220302b11affb140aeac1 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Wed, 18 Mar 2020 11:35:40 +0100 Subject: [PATCH] tests for prep, gen and dist --- src/data_handling/data_distributor.py | 4 +-- src/data_handling/data_preparation.py | 30 +++++++++++-------- .../test_data_distributor.py | 18 +++++++++++ .../test_data_handling/test_data_generator.py | 14 +++++++++ .../test_data_preparation.py | 28 ++++++++++++++--- 5 files changed, 75 insertions(+), 19 deletions(-) diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index c6015ed7..e8c60442 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -57,8 +57,8 @@ class Distributor(keras.utils.Sequence): if self.upsampling: try: s = self.generator.get_data_generator(k) - x_total = np.concatenate([x_total, np.copy(s.extremes_history.copy())], axis=0) - y_total = np.concatenate([y_total, np.copy(s.extremes_labels.copy())], axis=0) + x_total = np.concatenate([x_total, np.copy(s.get_extremes_history())], axis=0) + y_total = np.concatenate([y_total, np.copy(s.get_extremes_label())], axis=0) except AttributeError: # no extremes history / labels available, copy will fail pass # get number of mini batches diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index 4751bd35..490d6611 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -61,7 +61,7 @@ class DataPrep(object): self.label = None self.observation = None self.extremes_history = None - self.extremes_labels = None + self.extremes_label = None self.kwargs = kwargs self.data = None self.meta = None @@ -424,6 +424,12 @@ class DataPrep(object): def get_transposed_label(self): return self.label.squeeze("Stations").transpose("datetime", "window").copy() + def get_extremes_history(self): + return self.extremes_history.transpose("datetime", "window", "Stations", "variables").copy() + + def get_extremes_label(self): + return self.extremes_label.squeeze("Stations").transpose("datetime", "window").copy() + def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False, timedelta: Tuple[int, str] = (1, 'm')): """ @@ -450,7 +456,7 @@ class DataPrep(object): for extr_val in sorted(extreme_values): # check if some extreme values are already extracted - if (self.extremes_labels is None) or (self.extremes_history is None): + if (self.extremes_label is None) or (self.extremes_history is None): # extract extremes based on occurance in labels if extremes_on_right_tail_only: extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1,) @@ -462,23 +468,21 @@ class DataPrep(object): extremes_history = self.history[..., extreme_label_idx, :] extremes_label.datetime.values += np.timedelta64(*timedelta) extremes_history.datetime.values += np.timedelta64(*timedelta) - self.extremes_labels = extremes_label.squeeze('Stations').transpose('datetime', 'window') - self.extremes_history = extremes_history.transpose('datetime', 'window', 'Stations', 'variables') - else: # one extr value iteration is done already: self.extremes_labels is NOT None... + self.extremes_label = extremes_label#.squeeze('Stations').transpose('datetime', 'window') + self.extremes_history = extremes_history#.transpose('datetime', 'window', 'Stations', 'variables') + else: # one extr value iteration is done already: self.extremes_label is NOT None... 
if extremes_on_right_tail_only: - extreme_label_idx = (self.extremes_labels > extr_val).any(axis=1).values.reshape(-1,) + extreme_label_idx = (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, ) else: - extreme_label_idx = np.concatenate(((self.extremes_labels < -extr_val).any(axis=1 - ).values.reshape(-1, 1), - (self.extremes_labels > extr_val).any(axis=1 - ).values.reshape(-1, 1) + extreme_label_idx = np.concatenate(((self.extremes_label < -extr_val).any(axis=0).values.reshape(-1, 1), + (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, 1) ), axis=1).any(axis=1) # check on existing extracted extremes to minimise computational costs for comparison - extremes_label = self.extremes_labels[extreme_label_idx, ...] - extremes_history = self.extremes_history[extreme_label_idx, ...] + extremes_label = self.extremes_label[..., extreme_label_idx] + extremes_history = self.extremes_history[..., extreme_label_idx, :] extremes_label.datetime.values += np.timedelta64(*timedelta) extremes_history.datetime.values += np.timedelta64(*timedelta) - self.extremes_labels = xr.concat([self.extremes_labels, extremes_label], dim='datetime') + self.extremes_label = xr.concat([self.extremes_label, extremes_label], dim='datetime') self.extremes_history = xr.concat([self.extremes_history, extremes_history], dim='datetime') diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index dd0ca99d..15344fd8 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -98,3 +98,21 @@ class TestDistributor: assert np.testing.assert_equal(x, x_perm) is None assert np.testing.assert_equal(y, y_perm) is None + def test_distribute_on_batches_upsampling_no_extremes_given(self, generator, model): + d = Distributor(generator, model, upsampling=True) + gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0] + num_mini_batches = math.ceil(gen_len / d.batch_size) + i = 0 + for i, e in enumerate(d.distribute_on_batches(fit_call=False)): + assert e[0].shape[0] <= d.batch_size + assert i + 1 == num_mini_batches + + def test_distribute_on_batches_upsampling(self, generator, model): + generator.extreme_values = [1] + d = Distributor(generator, model, upsampling=True) + gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0] + extr_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_extremes_label().shape[0] + i = 0 + for i, e in enumerate(d.distribute_on_batches(fit_call=False)): + assert e[0].shape[0] <= d.batch_size + assert i + 1 == math.ceil((gen_len + extr_len) / d.batch_size) diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index 9bf11154..939f93cc 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -238,6 +238,20 @@ class TestDataGenerator: assert data._transform_method == "standardise" assert data.mean is not None + def test_get_data_generator_extremes(self, gen_with_transformation): + gen = gen_with_transformation + gen.kwargs = {"statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}} + gen.extreme_values = [1.] 
+ data = gen.get_data_generator("DEBW107", load_local_tmp_storage=False, save_local_tmp_storage=False) + assert data.extremes_label is not None + assert data.extremes_history is not None + assert data.extremes_label.shape[:2] == data.label.shape[:2] + assert data.extremes_label.shape[2] <= data.label.shape[2] + len_both_tails = data.extremes_label.shape[2] + gen.kwargs["extremes_on_right_tail_only"] = True + data = gen.get_data_generator("DEBW107", load_local_tmp_storage=False, save_local_tmp_storage=False) + assert data.extremes_label.shape[2] <= len_both_tails + def test_save_pickle_data(self, gen): file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}_2010_2014_.pickle") if os.path.exists(file): diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index f202a6ef..71f3a1d6 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -410,7 +410,7 @@ class TestDataPrep: data.make_labels("variables", "o3", "datetime", 2) orig = data.label data.multiply_extremes(1) - upsampled = data.extremes_labels + upsampled = data.extremes_label assert (upsampled > 1).sum() == (orig > 1).sum() assert (upsampled < -1).sum() == (orig < -1).sum() @@ -420,7 +420,7 @@ class TestDataPrep: data.make_labels("variables", "o3", "datetime", 2) orig = data.label data.multiply_extremes([1, 1.5, 2, 3]) - upsampled = data.extremes_labels + upsampled = data.extremes_label def f(d, op, n): return op(d, n).any(dim="window").sum() assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 1.5), f(orig, gt, 2) * 2, f(orig, gt, 3) * 4]) @@ -438,9 +438,29 @@ class TestDataPrep: data.make_labels("variables", "o3", "datetime", 2) orig = data.label data.multiply_extremes([1, 2], extremes_on_right_tail_only=True) - upsampled = data.extremes_labels + upsampled = data.extremes_label def f(d, op, n): return op(d, n).any(dim="window").sum() assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 2)]) - assert len(upsampled) == sum([f(orig, gt, 1), f(orig, gt, 2)]) + assert upsampled.shape[2] == sum([f(orig, gt, 1), f(orig, gt, 2)]) assert f(upsampled, lt, -1) == 0 + + def test_get_extremes_history(self, data): + data.transform("datetime") + data.make_history_window("variables", 3, "datetime") + data.make_labels("variables", "o3", "datetime", 2) + data.make_observation("variables", "o3", "datetime") + data.remove_nan("datetime") + data.multiply_extremes([1, 2], extremes_on_right_tail_only=True) + assert (data.get_extremes_history() == + data.extremes_history.transpose("datetime", "window", "Stations", "variables")).all() + + def test_get_extremes_label(self, data): + data.transform("datetime") + data.make_history_window("variables", 3, "datetime") + data.make_labels("variables", "o3", "datetime", 2) + data.make_observation("variables", "o3", "datetime") + data.remove_nan("datetime") + data.multiply_extremes([1, 2], extremes_on_right_tail_only=True) + assert (data.get_extremes_label() == + data.extremes_label.squeeze("Stations").transpose("datetime", "window")).all() -- GitLab
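
Note for reviewers (not part of the patch): the sketch below is a minimal, self-contained approximation of the extremes upsampling that multiply_extremes performs and that Distributor consumes through the new get_extremes_history / get_extremes_label accessors. It uses toy 2-D data (datetime x window) and a single threshold; the variable names, the toy DataArray, and the one-minute offset mirror the patch's defaults but are illustrative only, not the project's API.

    # Sketch of the upsampling idea, assuming a 2-D label array and one threshold.
    # Real DataPrep objects carry extra Stations/variables dimensions.
    import numpy as np
    import xarray as xr

    rng = np.random.default_rng(0)
    days = np.arange("2010-01-01", "2010-01-11", dtype="datetime64[D]")
    label = xr.DataArray(rng.normal(size=(10, 2)), dims=("datetime", "window"),
                         coords={"datetime": days})

    extr_val = 1.0
    # a sample counts as extreme if any lead time exceeds +/- extr_val
    is_extreme = ((label > extr_val) | (label < -extr_val)).any(dim="window")
    extremes = label.isel(datetime=np.where(is_extreme.values)[0]).copy()
    # shift the duplicated samples by one minute so datetime coordinates stay
    # unique, mirroring the default timedelta=(1, 'm') in multiply_extremes
    extremes = extremes.assign_coords(datetime=extremes.datetime.values + np.timedelta64(1, "m"))

    upsampled = xr.concat([label, extremes], dim="datetime")
    print(f"{label.sizes['datetime']} regular + {extremes.sizes['datetime']} extreme samples")

As the diff suggests, storing extremes_history / extremes_label un-transposed and exposing transposed copies only through the getters lets the later threshold iterations index them with the same axis layout as history / label, while Distributor.distribute_on_batches simply concatenates the getter output onto the regular batch data.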