diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index b1624410e746ab779b20a60d6a7d19b4ae3b1267..e8c6044280799ded080ab4bff3627aeb9ffde2db 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -8,15 +8,18 @@
 import math

 import keras
 import numpy as np

+from src.data_handling.data_generator import DataGenerator
+

 class Distributor(keras.utils.Sequence):

-    def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
-                 permute_data: bool = False):
+    def __init__(self, generator: DataGenerator, model: keras.models, batch_size: int = 256,
+                 permute_data: bool = False, upsampling: bool = False):
         self.generator = generator
         self.model = model
         self.batch_size = batch_size
         self.do_data_permutation = permute_data
+        self.upsampling = upsampling

     def _get_model_rank(self):
         mod_out = self.model.output_shape
@@ -31,7 +34,7 @@ class Distributor(keras.utils.Sequence):
         return mod_rank

     def _get_number_of_mini_batches(self, values):
-        return math.ceil(values[0].shape[0] / self.batch_size)
+        return math.ceil(values.shape[0] / self.batch_size)

     def _permute_data(self, x, y):
         """
@@ -48,10 +51,18 @@ class Distributor(keras.utils.Sequence):
         for k, v in enumerate(self.generator):
             # get rank of output
             mod_rank = self._get_model_rank()
-            # get number of mini batches
-            num_mini_batches = self._get_number_of_mini_batches(v)
+            # get data
             x_total = np.copy(v[0])
             y_total = np.copy(v[1])
+            if self.upsampling:
+                try:
+                    s = self.generator.get_data_generator(k)
+                    x_total = np.concatenate([x_total, np.copy(s.get_extremes_history())], axis=0)
+                    y_total = np.concatenate([y_total, np.copy(s.get_extremes_label())], axis=0)
+                except AttributeError:  # no extremes history / labels available, copy will fail
+                    pass
+            # get number of mini batches
+            num_mini_batches = self._get_number_of_mini_batches(x_total)
             # permute order for mini-batches
             x_total, y_total = self._permute_data(x_total, y_total)
             for prev, curr in enumerate(range(1, num_mini_batches+1)):
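
Note on the upsampling branch above: the extremes are appended after the regular samples, so the subsequent `_permute_data` call is what keeps them from piling up in the last mini batches. A minimal sketch of the same flow, assuming plain numpy arrays with samples on axis 0 (`sketch_batches` and all shapes are illustrative, not part of the patch):

```python
import math
import numpy as np

def sketch_batches(x, y, x_extremes, y_extremes, batch_size=256, permute=True):
    # append the pre-extracted extremes to the regular data (upsampling)
    x_total = np.concatenate([x, x_extremes], axis=0)
    y_total = np.concatenate([y, y_extremes], axis=0)
    if permute:  # without shuffling, all extremes would land in the final batches
        order = np.random.permutation(x_total.shape[0])
        x_total, y_total = x_total[order], y_total[order]
    # the number of mini batches is computed on the enlarged array, as in the patch
    for b in range(math.ceil(x_total.shape[0] / batch_size)):
        yield (x_total[b * batch_size:(b + 1) * batch_size],
               y_total[b * batch_size:(b + 1) * batch_size])
```
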
diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 24c9ada65b4bfd71de12785b2714cc5de94dc21f..0bf0bc35344ecf0f040f5563ddbdbe291b64404d 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -14,6 +14,9 @@
 from src import helpers
 from src.data_handling.data_preparation import DataPrep
 from src.join import EmptyQueryResult

+number = Union[float, int]
+num_or_list = Union[number, List[number]]
+

 class DataGenerator(keras.utils.Sequence):
     """
@@ -27,7 +30,7 @@
     def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
                  interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
                  interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
-                 window_lead_time: int = 4, transformation: Dict = None, **kwargs):
+                 window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, **kwargs):
         self.data_path = os.path.abspath(data_path)
         self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
         if not os.path.exists(self.data_path_tmp):
@@ -43,6 +46,7 @@
         self.limit_nan_fill = limit_nan_fill
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
+        self.extreme_values = extreme_values
         self.kwargs = kwargs
         self.transformation = self.setup_transformation(transformation)
@@ -188,6 +192,9 @@
         data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
         data.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
         data.remove_nan(self.interpolate_dim)
+        if self.extreme_values:
+            kwargs = {"extremes_on_right_tail_only": self.kwargs.get("extremes_on_right_tail_only", False)}
+            data.multiply_extremes(self.extreme_values, **kwargs)
         if save_local_tmp_storage:
             self._save_pickle_data(data)
         return data
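
With the new keyword argument in place, `extreme_values` is stored on the generator and handed to `DataPrep.multiply_extremes` inside `get_data_generator`, while `extremes_on_right_tail_only` travels via `**kwargs`. A hypothetical call for illustration (path, network, and station values are placeholders, not taken from the patch):

```python
# Hypothetical instantiation; all argument values below are placeholders.
gen = DataGenerator(data_path="data/", network="AIRBASE", stations=["DEBW107"],
                    variables=["o3", "temp"], interpolate_dim="datetime",
                    target_dim="variables", target_var="o3",
                    statistics_per_var={"o3": "dma8eu", "temp": "maximum"},
                    extreme_values=[1., 2.],            # thresholds in standardised space
                    extremes_on_right_tail_only=False)  # picked up from self.kwargs
data = gen.get_data_generator("DEBW107")  # internally calls data.multiply_extremes([1., 2.], ...)
```
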
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index e3186778b94375ba1d39fa87ba7d2980c785581e..490d661195aa017113f705da7b2e1e896e55fdc1 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -5,7 +5,7 @@
 import datetime as dt
 from functools import reduce
 import logging
 import os
-from typing import Union, List, Iterable
+from typing import Union, List, Iterable, Tuple

 import numpy as np
 import pandas as pd
@@ -17,6 +17,8 @@ from src import statistics
 # define a more general date type for type hinting
 date = Union[dt.date, dt.datetime]
 str_or_list = Union[str, List[str]]
+number = Union[float, int]
+num_or_list = Union[number, List[number]]


 class DataPrep(object):
@@ -58,6 +60,8 @@
         self.history = None
         self.label = None
         self.observation = None
+        self.extremes_history = None
+        self.extremes_label = None
         self.kwargs = kwargs
         self.data = None
         self.meta = None
@@ -420,6 +424,67 @@
     def get_transposed_label(self):
         return self.label.squeeze("Stations").transpose("datetime", "window").copy()

+    def get_extremes_history(self):
+        return self.extremes_history.transpose("datetime", "window", "Stations", "variables").copy()
+
+    def get_extremes_label(self):
+        return self.extremes_label.squeeze("Stations").transpose("datetime", "window").copy()
+
+    def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
+                          timedelta: Tuple[int, str] = (1, 'm')):
+        """
+        Extract extreme values from self.label as defined by the argument extreme_values. One can also decide to
+        extract extremes from the right tail of the distribution only. If extreme_values is a list of floats/ints,
+        all values larger than each entry (and, unless restricted to the right tail, smaller than its negative
+        counterpart; extraction is performed in standardised space) are extracted iteratively. If for example
+        extreme_values = [1., 2.], a value of 1.5 is extracted once (for the 0th entry in the list), while a value
+        of 2.5 is extracted twice (once for each entry). The timedelta argument marks the extracted values by adding
+        one minute to each timestamp. As TOAR data are hourly, these "artificial" data points can easily be
+        identified later. Extreme inputs and labels are stored in self.extremes_history and self.extremes_label,
+        respectively.
+
+        :param extreme_values: user definition of extreme
+        :param extremes_on_right_tail_only: if False, also extract values smaller than -extreme_values; if True,
+            only extract values larger than extreme_values
+        :param timedelta: used as argument for np.timedelta64 to mark extreme values on the datetime axis
+        """
+        # check type of inputs
+        extreme_values = helpers.to_list(extreme_values)
+        for i in extreme_values:
+            if not isinstance(i, number.__args__):
+                raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
+                                f"{i} is type {type(i)}")
+
+        for extr_val in sorted(extreme_values):
+            # check if some extreme values are already extracted
+            if (self.extremes_label is None) or (self.extremes_history is None):
+                # extract extremes based on occurrence in labels
+                if extremes_on_right_tail_only:
+                    extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1,)
+                else:
+                    extreme_label_idx = np.concatenate(((self.label < -extr_val).any(axis=0).values.reshape(-1, 1),
+                                                        (self.label > extr_val).any(axis=0).values.reshape(-1, 1)),
+                                                       axis=1).any(axis=1)
+                extremes_label = self.label[..., extreme_label_idx]
+                extremes_history = self.history[..., extreme_label_idx, :]
+                extremes_label.datetime.values += np.timedelta64(*timedelta)
+                extremes_history.datetime.values += np.timedelta64(*timedelta)
+                self.extremes_label = extremes_label  # .squeeze('Stations').transpose('datetime', 'window')
+                self.extremes_history = extremes_history  # .transpose('datetime', 'window', 'Stations', 'variables')
+            else:  # one extr value iteration is done already: self.extremes_label is NOT None...
+                if extremes_on_right_tail_only:
+                    extreme_label_idx = (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, )
+                else:
+                    extreme_label_idx = np.concatenate(((self.extremes_label < -extr_val).any(axis=0).values.reshape(-1, 1),
+                                                        (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, 1)
+                                                        ), axis=1).any(axis=1)
+                # reuse the already extracted extremes to minimise the computational cost of the comparison
+                extremes_label = self.extremes_label[..., extreme_label_idx]
+                extremes_history = self.extremes_history[..., extreme_label_idx, :]
+                extremes_label.datetime.values += np.timedelta64(*timedelta)
+                extremes_history.datetime.values += np.timedelta64(*timedelta)
+                self.extremes_label = xr.concat([self.extremes_label, extremes_label], dim='datetime')
+                self.extremes_history = xr.concat([self.extremes_history, extremes_history], dim='datetime')
+

 if __name__ == "__main__":
     dp = DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'],
                   statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
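
The iterative extraction means each threshold pass re-scans the already extracted extremes, so a sample exceeding k of the sorted thresholds ends up with 2^(k-1) copies in total (this is exactly what the new tests further below count). A standalone 1-D numpy paraphrase of the selection rule, without the xarray bookkeeping and timestamp shifting of the real method:

```python
import numpy as np

def multiply_extremes_sketch(labels, extreme_values, right_tail_only=False):
    """1-D paraphrase of DataPrep.multiply_extremes; same selection rule, no metadata."""
    extremes = None
    for extr_val in sorted(extreme_values):
        source = labels if extremes is None else extremes  # re-scan previous extremes
        if right_tail_only:
            idx = source > extr_val
        else:
            idx = (source > extr_val) | (source < -extr_val)
        extremes = source[idx] if extremes is None else np.concatenate([extremes, source[idx]])
    return extremes

print(multiply_extremes_sketch(np.array([0.5, 1.2, 2.5]), [1., 2.]))
# [1.2 2.5 2.5]: 1.2 is extracted once, 2.5 twice (once per threshold it exceeds)
```
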
diff --git a/src/helpers.py b/src/helpers.py
index 12ec837bfcc14753fb7a90ac34f4499f47cb485d..6e9d47d1040aa803358cb60439197fd48641e9e1 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -234,11 +234,11 @@ class Logger:
         self.formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'

         # set log path
-        log_file = self.setup_logging_path(log_path)
+        self.log_file = self.setup_logging_path(log_path)
         # set root logger as file handler
         logging.basicConfig(level=level_file,
                             format=self.formatter,
-                            filename=log_file,
+                            filename=self.log_file,
                             filemode='a')
         # add stream handler to the root logger
         logging.getLogger('').addHandler(self.logger_console(level_stream))
diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 2aafe6c693b8d10fe450df510539e0aef8f0487e..295d4342c4527aa54f6e302bbf77c92bcf760c56 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -34,8 +34,9 @@ class ExperimentSetup(RunEnvironment):
                  limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
                  test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
                  experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
-                 create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
-                 train_min_length=None, val_min_length=None, test_min_length=None):
+                 create_new_model=None, bootstrap_path=None, permute_data_on_training=False, transformation=None,
+                 train_min_length=None, val_min_length=None, test_min_length=None, extreme_values=None,
+                 extremes_on_right_tail_only=None):

         # create run framework
         super().__init__()
@@ -50,7 +51,11 @@
         self._set_param("bootstrap_path", bootstrap_path)
         self._set_param("trainable", trainable, default=True)
         self._set_param("fraction_of_training", fraction_of_train, default=0.8)
-        self._set_param("permute_data", permute_data_on_training, default=False, scope="general.train")
+        self._set_param("extreme_values", extreme_values, default=None, scope="general.train")
+        self._set_param("extremes_on_right_tail_only", extremes_on_right_tail_only, default=False, scope="general.train")
+        self._set_param("upsampling", extreme_values is not None, scope="general.train")
+        upsampling = self.data_store.get("upsampling", "general.train")
+        self._set_param("permute_data", max([permute_data_on_training, upsampling]), scope="general.train")

         # set experiment name
         exp_date = self._get_parser_args(parser_args).get("experiment_date")
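
The `max([...])` call enables permutation whenever upsampling is active: the appended extremes sit at the end of the arrays, so training without shuffling would feed them in a block. This is also why the default of `permute_data_on_training` changes from `None` to `False`: since Python orders `False < True`, `max` acts as a logical OR over the two flags, but it cannot compare `None` with a bool:

```python
# max over two bools behaves like "or": permute if either flag is set
for flag, upsampling in [(False, False), (False, True), (True, False), (True, True)]:
    assert max([flag, upsampling]) == (flag or upsampling)
```
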
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 20286bc43b3227291c66c7844ad43792a7a28480..439793f941e6aaf4085241f200a63614563b550a 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -12,7 +12,8 @@ from src.run_modules.run_environment import RunEnvironment

 DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
 DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "min_length",
-                       "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation"]
+                       "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation",
+                       "extreme_values", "extremes_on_right_tail_only"]


 class PreProcessing(RunEnvironment):
diff --git a/src/run_modules/run_environment.py b/src/run_modules/run_environment.py
index 56c017290eea4d11881b9b131378d8c5995f0b29..1c44786dfd4830c8053ae1673eac1473fbd19338 100644
--- a/src/run_modules/run_environment.py
+++ b/src/run_modules/run_environment.py
@@ -2,9 +2,13 @@ __author__ = "Lukas Leufen"
 __date__ = '2019-11-25'

 import logging
+import os
+import shutil
 import time

+from src.helpers import Logger
 from src.datastore import DataStoreByScope as DataStoreObject
+from src.datastore import NameNotFoundInDataStore
 from src.helpers import TimeTracking


@@ -16,6 +20,7 @@ class RunEnvironment(object):

     del_by_exit = False
     data_store = DataStoreObject()
+    logger = Logger()

     def __init__(self):
         """
@@ -34,6 +39,11 @@
             logging.info(f"{self.__class__.__name__} finished after {self.time}")
             self.del_by_exit = True
             if self.__class__.__name__ == "RunEnvironment":
+                try:
+                    new_file = os.path.join(self.data_store.get("experiment_path", "general"), "logging.log")
+                    shutil.copyfile(self.logger.log_file, new_file)
+                except (NameNotFoundInDataStore, FileNotFoundError):
+                    pass
                 self.data_store.clear_data_store()

     def __enter__(self):
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 55b5c2964de3155a8d34cf87a646c0d53deebbef..0d6279b132b64f287f541088c2675012a2d1e933 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -9,19 +9,21 @@ import pickle

 import keras

 from src.data_handling.data_distributor import Distributor
-from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced, CallbackHandler
+from src.model_modules.keras_extensions import LearningRateDecay, CallbackHandler
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from src.run_modules.run_environment import RunEnvironment

+from typing import Union
+

 class Training(RunEnvironment):

     def __init__(self):
         super().__init__()
         self.model: keras.Model = self.data_store.get("model", "general.model")
-        self.train_set = None
-        self.val_set = None
-        self.test_set = None
+        self.train_set: Union[Distributor, None] = None
+        self.val_set: Union[Distributor, None] = None
+        self.test_set: Union[Distributor, None] = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
         self.callbacks: CallbackHandler = self.data_store.get("callbacks", "general.model")
@@ -65,8 +67,9 @@
         :param mode: name of set, should be from ["train", "val", "test"]
         """
         gen = self.data_store.get("generator", f"general.{mode}")
-        permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False)
-        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, permute_data=permute_data))
+        # permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False)
+        kwargs = self.data_store.create_args_dict(["permute_data", "upsampling"], scope=f"general.{mode}")
+        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, **kwargs))

     def set_generators(self) -> None:
         """
@@ -86,6 +89,9 @@
         locally stored information and the corresponding model and proceed with the already started training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
+        logging.info(f"Train with option upsampling={self.train_set.upsampling}.")
+        logging.info(f"Train with option data_permutation={self.train_set.do_data_permutation}.")
+
         checkpoint = self.callbacks.get_checkpoint()
         if not os.path.exists(checkpoint.filepath) or self._create_new_model:
             history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
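
`create_args_dict` replaces the single `get_default` lookup so that both `permute_data` and `upsampling` reach the `Distributor`. Assuming it simply collects the named parameters from the given scope into a kwargs dict (a paraphrase of its presumed semantics, not the actual `DataStore` implementation), the call is roughly equivalent to:

```python
# Paraphrased semantics of data_store.create_args_dict (assumption, not the real body):
def create_args_dict_sketch(data_store, arg_list, scope):
    return {name: data_store.get(name, scope) for name in arg_list}

# kwargs = create_args_dict_sketch(data_store, ["permute_data", "upsampling"], "general.train")
# -> e.g. {"permute_data": True, "upsampling": True}, depending on ExperimentSetup
```
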
""" logging.info(f"Train with {len(self.train_set)} mini batches.") + logging.info(f"Train with option upsampling={self.train_set.upsampling}.") + logging.info(f"Train with option data_permutation={self.train_set.do_data_permutation}.") + checkpoint = self.callbacks.get_checkpoint() if not os.path.exists(checkpoint.filepath) or self._create_new_model: history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(), diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py index 9dd23893ef903bfbd0595a482dceb32724c3b437..e74499523751fd74e449bbb25455579f770d17bc 100644 --- a/test/test_data_handling/test_bootstraps.py +++ b/test/test_data_handling/test_bootstraps.py @@ -52,12 +52,12 @@ class TestBootstraps: boot_no_init.number_bootstraps = 50 assert boot_no_init.valid_bootstrap_file(station, variables, 20) == (False, 60) - def test_shuffle_single_variale(self, boot_no_init): + def test_shuffle_single_variable(self, boot_no_init): data = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) res = boot_no_init.shuffle_single_variable(data, chunks=(2, 3)).compute() assert res.shape == data.shape - assert res.max() == data.max() - assert res.min() == data.min() + assert res.max() <= data.max() + assert res.min() >= data.min() assert set(np.unique(res)).issubset({1, 2, 3}) def test_create_shuffled_data(self): diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index a26e76a0e7f3ef0f5cdbedc07d73a690116966c9..15344fd808a4aa9ee5774ad8ba647bf5ce06d015 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -46,7 +46,7 @@ class TestDistributor: distributor.model = 1 def test_get_number_of_mini_batches(self, distributor): - values = np.zeros((2, 2311, 19)) + values = np.zeros((2311, 19)) assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size) def test_distribute_on_batches_single_loop(self, generator_two_stations, model): @@ -98,3 +98,21 @@ class TestDistributor: assert np.testing.assert_equal(x, x_perm) is None assert np.testing.assert_equal(y, y_perm) is None + def test_distribute_on_batches_upsampling_no_extremes_given(self, generator, model): + d = Distributor(generator, model, upsampling=True) + gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0] + num_mini_batches = math.ceil(gen_len / d.batch_size) + i = 0 + for i, e in enumerate(d.distribute_on_batches(fit_call=False)): + assert e[0].shape[0] <= d.batch_size + assert i + 1 == num_mini_batches + + def test_distribute_on_batches_upsampling(self, generator, model): + generator.extreme_values = [1] + d = Distributor(generator, model, upsampling=True) + gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0] + extr_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_extremes_label().shape[0] + i = 0 + for i, e in enumerate(d.distribute_on_batches(fit_call=False)): + assert e[0].shape[0] <= d.batch_size + assert i + 1 == math.ceil((gen_len + extr_len) / d.batch_size) diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index 9bf11154609afa9ada2b488455f7a341a41d21ae..939f93cc9ee01c76a282e755aca14b39c6fc4ac9 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -238,6 +238,20 @@ class 
diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py
index 9bf11154609afa9ada2b488455f7a341a41d21ae..939f93cc9ee01c76a282e755aca14b39c6fc4ac9 100644
--- a/test/test_data_handling/test_data_generator.py
+++ b/test/test_data_handling/test_data_generator.py
@@ -238,6 +238,20 @@ class TestDataGenerator:
         assert data._transform_method == "standardise"
         assert data.mean is not None

+    def test_get_data_generator_extremes(self, gen_with_transformation):
+        gen = gen_with_transformation
+        gen.kwargs = {"statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}}
+        gen.extreme_values = [1.]
+        data = gen.get_data_generator("DEBW107", load_local_tmp_storage=False, save_local_tmp_storage=False)
+        assert data.extremes_label is not None
+        assert data.extremes_history is not None
+        assert data.extremes_label.shape[:2] == data.label.shape[:2]
+        assert data.extremes_label.shape[2] <= data.label.shape[2]
+        len_both_tails = data.extremes_label.shape[2]
+        gen.kwargs["extremes_on_right_tail_only"] = True
+        data = gen.get_data_generator("DEBW107", load_local_tmp_storage=False, save_local_tmp_storage=False)
+        assert data.extremes_label.shape[2] <= len_both_tails
+
     def test_save_pickle_data(self, gen):
         file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}_2010_2014_.pickle")
         if os.path.exists(file):
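
The shape assertions above rely on the label array keeping its first two dimensions while only the sample axis shrinks; judging from the `squeeze`/`transpose` calls elsewhere in the patch, the dimension order is (Stations, window, datetime), so axis 2 counts samples. Illustrative shapes (made up, not from the fixtures):

```python
label_shape = (1, 3, 1000)     # one station, 3 lead-time steps, 1000 samples
extremes_shape = (1, 3, 87)    # extremes keep the first two dims, fewer samples
assert extremes_shape[:2] == label_shape[:2]
assert extremes_shape[2] <= label_shape[2]
```
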
data.transform("datetime") + data.make_history_window("variables", 3, "datetime") + data.make_labels("variables", "o3", "datetime", 2) + data.make_observation("variables", "o3", "datetime") + data.remove_nan("datetime") + data.multiply_extremes([1, 2], extremes_on_right_tail_only=True) + assert (data.get_extremes_history() == + data.extremes_history.transpose("datetime", "window", "Stations", "variables")).all() + + def test_get_extremes_label(self, data): + data.transform("datetime") + data.make_history_window("variables", 3, "datetime") + data.make_labels("variables", "o3", "datetime", 2) + data.make_observation("variables", "o3", "datetime") + data.remove_nan("datetime") + data.multiply_extremes([1, 2], extremes_on_right_tail_only=True) + assert (data.get_extremes_label() == + data.extremes_label.squeeze("Stations").transpose("datetime", "window")).all() diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index c3f13e1ac1d7bdb0bdf17f81d3385472eaa46640..425ddecc135db75a3f2f624ed150e8dd8f566bdc 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -36,6 +36,7 @@ class TestPreProcessing: def test_init(self, caplog): ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + caplog.clear() caplog.set_level(logging.INFO) with PreProcessing(): assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') @@ -54,7 +55,8 @@ class TestPreProcessing: assert obj_with_exp_setup.data_store.search_name("generator") == [] obj_with_exp_setup.split_train_val_test() data_store = obj_with_exp_setup.data_store - expected_params = ["generator", "start", "end", "stations", "permute_data", "min_length"] + expected_params = ["generator", "start", "end", "stations", "permute_data", "min_length", "extreme_values", + "extremes_on_right_tail_only", "upsampling"] assert data_store.search_scope("general.train") == sorted(expected_params) assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test", "general.train_val"])