diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index 033859fa9297c31e13b30c87ef3bcbc90c1402be..4751bd35d0852a7044260e5b08a1e203234e1e1b 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -443,15 +443,14 @@ class DataPrep(object): """ # check type if inputs extreme_values = helpers.to_list(extreme_values) - extreme_values.sort() for i in extreme_values: if not isinstance(i, number.__args__): raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element " f"{i} is type {type(i)}") - for extr_val in extreme_values: + for extr_val in sorted(extreme_values): # check if some extreme values are already extracted - if not all([self.extremes_labels, self.extremes_history]): + if (self.extremes_labels is None) or (self.extremes_history is None): # extract extremes based on occurance in labels if extremes_on_right_tail_only: extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1,) diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index 85c4420609a466ff5f3eeb3d46cb6bb07fe9c30a..f202a6efc09502c635ad548471b588e259ebc7e1 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -1,6 +1,6 @@ import datetime as dt import os -from operator import itemgetter +from operator import itemgetter, lt, gt import logging import numpy as np @@ -403,3 +403,44 @@ class TestDataPrep: data.make_labels("variables", "o3", "datetime", 2) transposed = data.get_transposed_label() assert transposed.coords.dims == ("datetime", "window") + + def test_multiply_extremes(self, data): + data.transform("datetime") + data.make_history_window("variables", 3, "datetime") + data.make_labels("variables", "o3", "datetime", 2) + orig = data.label + data.multiply_extremes(1) + upsampled = data.extremes_labels + assert (upsampled > 1).sum() == (orig > 1).sum() + assert (upsampled < -1).sum() == (orig < -1).sum() + + def test_multiply_extremes_from_list(self, data): + data.transform("datetime") + data.make_history_window("variables", 3, "datetime") + data.make_labels("variables", "o3", "datetime", 2) + orig = data.label + data.multiply_extremes([1, 1.5, 2, 3]) + upsampled = data.extremes_labels + def f(d, op, n): + return op(d, n).any(dim="window").sum() + assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 1.5), f(orig, gt, 2) * 2, f(orig, gt, 3) * 4]) + assert f(upsampled, lt, -1) == sum([f(orig, lt, -1), f(orig, lt, -1.5), f(orig, lt, -2) * 2, f(orig, lt, -3) * 4]) + + def test_multiply_extremes_wrong_extremes(self, data): + with pytest.raises(TypeError) as e: + data.multiply_extremes([1, "1.5", 2]) + assert "Elements of list extreme_values have to be (<class 'float'>, <class 'int'>), but at least element 1.5" \ + " is type <class 'str'>" in e.value.args[0] + + def test_multiply_extremes_right_tail(self, data): + data.transform("datetime") + data.make_history_window("variables", 3, "datetime") + data.make_labels("variables", "o3", "datetime", 2) + orig = data.label + data.multiply_extremes([1, 2], extremes_on_right_tail_only=True) + upsampled = data.extremes_labels + def f(d, op, n): + return op(d, n).any(dim="window").sum() + assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 2)]) + assert len(upsampled) == sum([f(orig, gt, 1), f(orig, gt, 2)]) + assert f(upsampled, lt, -1) == 0 diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index c3f13e1ac1d7bdb0bdf17f81d3385472eaa46640..3637410126362471120df9e5be334190634a3bca 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -54,7 +54,8 @@ class TestPreProcessing: assert obj_with_exp_setup.data_store.search_name("generator") == [] obj_with_exp_setup.split_train_val_test() data_store = obj_with_exp_setup.data_store - expected_params = ["generator", "start", "end", "stations", "permute_data", "min_length"] + expected_params = ["generator", "start", "end", "stations", "permute_data", "min_length", "extreme_values", + "extremes_on_right_tail_only", "upsampling"] assert data_store.search_scope("general.train") == sorted(expected_params) assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test", "general.train_val"])