diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index b460decd60fe31aecf817fed8fe1992f89548c42..1d375c32be06b583abbfb06a20ea482e6775b232 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -297,7 +297,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("sampling", sampling) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) self._set_param("transformation", None, scope="preprocessing") - self._set_param("DataPrep", data_preparation, default=DataPrepJoin) + self._set_param("data_preparation", data_preparation, default=DataPrepJoin) # target self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR) diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index 6843ae2ef63d10946184ee9d02fb8bb0c9e894b8..db7fff2ab9e385ce769f86ef95d1565ea783cc95 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -19,7 +19,7 @@ from src.run_modules.run_environment import RunEnvironment DEFAULT_ARGS_LIST = ["data_path", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "min_length", "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation", - "extreme_values", "extremes_on_right_tail_only", "network", "DataPrep"] + "extreme_values", "extremes_on_right_tail_only", "network", "data_preparation"] class PreProcessing(RunEnvironment): diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index 60d140f8845b25432184de1f3890b3ee4d0b034e..6de61b2dbe88e24eb3caccf6de575d6340129b91 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -2,6 +2,7 @@ import os import pytest +from src.data_handling import DataPrepJoin from src.data_handling.data_generator import DataGenerator from src.helpers.datastore import EmptyScope from src.model_modules.keras_extensions import CallbackHandler @@ -29,8 +30,9 @@ class TestModelSetup: @pytest.fixture def gen(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], - 'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'], + 'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, + data_preparation=DataPrepJoin) @pytest.fixture def setup_with_gen(self, setup, gen): diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index 29811fb04789f32a1e2cc1b3affb6f8d4ae99730..0b439e9e9ad54ca3aef70e27b2017482706383c0 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -2,6 +2,7 @@ import logging import pytest +from src.data_handling import DataPrepJoin from src.data_handling.data_generator import DataGenerator from src.helpers.datastore import NameNotFoundInScope from src.helpers import PyTestRegex @@ -27,7 +28,8 @@ class TestPreProcessing: @pytest.fixture def obj_with_exp_setup(self): ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], - statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background") + statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background", + data_preparation=DataPrepJoin) pre = object.__new__(PreProcessing) super(PreProcessing, pre).__init__() yield pre diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index eb5dfe5adb170981d5d67c94ca1fbcb55e326550..d58c1a973ec474b2ec786271dff9d35ce5ca94d9 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -9,6 +9,7 @@ import mock import pytest from keras.callbacks import History +from src.data_handling import DataPrepJoin from src.data_handling.data_distributor import Distributor from src.data_handling.data_generator import DataGenerator from src.helpers import PyTestRegex @@ -108,9 +109,9 @@ class TestTraining: @pytest.fixture def generator(self, path): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', - ['DEBW107'], ['o3', 'temp'], 'datetime', 'variables', - 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), ['DEBW107'], ['o3', 'temp'], 'datetime', + 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, + data_preparation=DataPrepJoin) @pytest.fixture def model(self):