diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index c6015ed78956fca539a951d6b0072ca8924a7a9c..e8c6044280799ded080ab4bff3627aeb9ffde2db 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -57,8 +57,8 @@ class Distributor(keras.utils.Sequence): if self.upsampling: try: s = self.generator.get_data_generator(k) - x_total = np.concatenate([x_total, np.copy(s.extremes_history.copy())], axis=0) - y_total = np.concatenate([y_total, np.copy(s.extremes_labels.copy())], axis=0) + x_total = np.concatenate([x_total, np.copy(s.get_extremes_history())], axis=0) + y_total = np.concatenate([y_total, np.copy(s.get_extremes_label())], axis=0) except AttributeError: # no extremes history / labels available, copy will fail pass # get number of mini batches diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index 4751bd35d0852a7044260e5b08a1e203234e1e1b..490d661195aa017113f705da7b2e1e896e55fdc1 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -61,7 +61,7 @@ class DataPrep(object): self.label = None self.observation = None self.extremes_history = None - self.extremes_labels = None + self.extremes_label = None self.kwargs = kwargs self.data = None self.meta = None @@ -424,6 +424,12 @@ class DataPrep(object): def get_transposed_label(self): return self.label.squeeze("Stations").transpose("datetime", "window").copy() + def get_extremes_history(self): + return self.extremes_history.transpose("datetime", "window", "Stations", "variables").copy() + + def get_extremes_label(self): + return self.extremes_label.squeeze("Stations").transpose("datetime", "window").copy() + def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False, timedelta: Tuple[int, str] = (1, 'm')): """ @@ -450,7 +456,7 @@ class DataPrep(object): for extr_val in sorted(extreme_values): # check if some extreme values are already extracted - if (self.extremes_labels is None) or (self.extremes_history is None): + if (self.extremes_label is None) or (self.extremes_history is None): # extract extremes based on occurance in labels if extremes_on_right_tail_only: extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1,) @@ -462,23 +468,21 @@ class DataPrep(object): extremes_history = self.history[..., extreme_label_idx, :] extremes_label.datetime.values += np.timedelta64(*timedelta) extremes_history.datetime.values += np.timedelta64(*timedelta) - self.extremes_labels = extremes_label.squeeze('Stations').transpose('datetime', 'window') - self.extremes_history = extremes_history.transpose('datetime', 'window', 'Stations', 'variables') - else: # one extr value iteration is done already: self.extremes_labels is NOT None... + self.extremes_label = extremes_label#.squeeze('Stations').transpose('datetime', 'window') + self.extremes_history = extremes_history#.transpose('datetime', 'window', 'Stations', 'variables') + else: # one extr value iteration is done already: self.extremes_label is NOT None... if extremes_on_right_tail_only: - extreme_label_idx = (self.extremes_labels > extr_val).any(axis=1).values.reshape(-1,) + extreme_label_idx = (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, ) else: - extreme_label_idx = np.concatenate(((self.extremes_labels < -extr_val).any(axis=1 - ).values.reshape(-1, 1), - (self.extremes_labels > extr_val).any(axis=1 - ).values.reshape(-1, 1) + extreme_label_idx = np.concatenate(((self.extremes_label < -extr_val).any(axis=0).values.reshape(-1, 1), + (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, 1) ), axis=1).any(axis=1) # check on existing extracted extremes to minimise computational costs for comparison - extremes_label = self.extremes_labels[extreme_label_idx, ...] - extremes_history = self.extremes_history[extreme_label_idx, ...] + extremes_label = self.extremes_label[..., extreme_label_idx] + extremes_history = self.extremes_history[..., extreme_label_idx, :] extremes_label.datetime.values += np.timedelta64(*timedelta) extremes_history.datetime.values += np.timedelta64(*timedelta) - self.extremes_labels = xr.concat([self.extremes_labels, extremes_label], dim='datetime') + self.extremes_label = xr.concat([self.extremes_label, extremes_label], dim='datetime') self.extremes_history = xr.concat([self.extremes_history, extremes_history], dim='datetime') diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index dd0ca99dde0ed38bfbbd392f65eeddca78bbb075..15344fd808a4aa9ee5774ad8ba647bf5ce06d015 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -98,3 +98,21 @@ class TestDistributor: assert np.testing.assert_equal(x, x_perm) is None assert np.testing.assert_equal(y, y_perm) is None + def test_distribute_on_batches_upsampling_no_extremes_given(self, generator, model): + d = Distributor(generator, model, upsampling=True) + gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0] + num_mini_batches = math.ceil(gen_len / d.batch_size) + i = 0 + for i, e in enumerate(d.distribute_on_batches(fit_call=False)): + assert e[0].shape[0] <= d.batch_size + assert i + 1 == num_mini_batches + + def test_distribute_on_batches_upsampling(self, generator, model): + generator.extreme_values = [1] + d = Distributor(generator, model, upsampling=True) + gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0] + extr_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_extremes_label().shape[0] + i = 0 + for i, e in enumerate(d.distribute_on_batches(fit_call=False)): + assert e[0].shape[0] <= d.batch_size + assert i + 1 == math.ceil((gen_len + extr_len) / d.batch_size) diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index 9bf11154609afa9ada2b488455f7a341a41d21ae..939f93cc9ee01c76a282e755aca14b39c6fc4ac9 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -238,6 +238,20 @@ class TestDataGenerator: assert data._transform_method == "standardise" assert data.mean is not None + def test_get_data_generator_extremes(self, gen_with_transformation): + gen = gen_with_transformation + gen.kwargs = {"statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}} + gen.extreme_values = [1.] + data = gen.get_data_generator("DEBW107", load_local_tmp_storage=False, save_local_tmp_storage=False) + assert data.extremes_label is not None + assert data.extremes_history is not None + assert data.extremes_label.shape[:2] == data.label.shape[:2] + assert data.extremes_label.shape[2] <= data.label.shape[2] + len_both_tails = data.extremes_label.shape[2] + gen.kwargs["extremes_on_right_tail_only"] = True + data = gen.get_data_generator("DEBW107", load_local_tmp_storage=False, save_local_tmp_storage=False) + assert data.extremes_label.shape[2] <= len_both_tails + def test_save_pickle_data(self, gen): file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}_2010_2014_.pickle") if os.path.exists(file): diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index f202a6efc09502c635ad548471b588e259ebc7e1..71f3a1d6a0a675a155b517901aef1f3c359b104b 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -410,7 +410,7 @@ class TestDataPrep: data.make_labels("variables", "o3", "datetime", 2) orig = data.label data.multiply_extremes(1) - upsampled = data.extremes_labels + upsampled = data.extremes_label assert (upsampled > 1).sum() == (orig > 1).sum() assert (upsampled < -1).sum() == (orig < -1).sum() @@ -420,7 +420,7 @@ class TestDataPrep: data.make_labels("variables", "o3", "datetime", 2) orig = data.label data.multiply_extremes([1, 1.5, 2, 3]) - upsampled = data.extremes_labels + upsampled = data.extremes_label def f(d, op, n): return op(d, n).any(dim="window").sum() assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 1.5), f(orig, gt, 2) * 2, f(orig, gt, 3) * 4]) @@ -438,9 +438,29 @@ class TestDataPrep: data.make_labels("variables", "o3", "datetime", 2) orig = data.label data.multiply_extremes([1, 2], extremes_on_right_tail_only=True) - upsampled = data.extremes_labels + upsampled = data.extremes_label def f(d, op, n): return op(d, n).any(dim="window").sum() assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 2)]) - assert len(upsampled) == sum([f(orig, gt, 1), f(orig, gt, 2)]) + assert upsampled.shape[2] == sum([f(orig, gt, 1), f(orig, gt, 2)]) assert f(upsampled, lt, -1) == 0 + + def test_get_extremes_history(self, data): + data.transform("datetime") + data.make_history_window("variables", 3, "datetime") + data.make_labels("variables", "o3", "datetime", 2) + data.make_observation("variables", "o3", "datetime") + data.remove_nan("datetime") + data.multiply_extremes([1, 2], extremes_on_right_tail_only=True) + assert (data.get_extremes_history() == + data.extremes_history.transpose("datetime", "window", "Stations", "variables")).all() + + def test_get_extremes_label(self, data): + data.transform("datetime") + data.make_history_window("variables", 3, "datetime") + data.make_labels("variables", "o3", "datetime", 2) + data.make_observation("variables", "o3", "datetime") + data.remove_nan("datetime") + data.multiply_extremes([1, 2], extremes_on_right_tail_only=True) + assert (data.get_extremes_label() == + data.extremes_label.squeeze("Stations").transpose("datetime", "window")).all()