diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index 9bb15068ce3a5ad934f7b0251b84cb19f37702f6..d0b3592905c32a4c7af875014532a5539805dd3a 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -48,7 +48,7 @@ DEFAULT_TEST_END = "2017-12-31" DEFAULT_TEST_MIN_LENGTH = 90 DEFAULT_TRAIN_VAL_MIN_LENGTH = 180 DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS = True -DEFAULT_COMPETITORS = ["ols"] +DEFAULT_COMPETITORS = ["ols", "persi"] DEFAULT_DO_UNCERTAINTY_ESTIMATE = True DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH = "1m" DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS = True diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index d65a200161a7593fe03df5053328aa3f8cd77310..e7d587cf5e87948f51bd2326fe84e495f940d7a8 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -72,6 +72,7 @@ class PostProcessing(RunEnvironment): self.model: AbstractModelClass = self._load_model() self.model_name = self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0] self.ols_model = None + self.persi_model = True self.batch_size: int = self.data_store.get_default("batch_size", "model", 64) self.test_data = self.data_store.get("data_collection", "test") batch_path = self.data_store.get("batch_path", scope="test") @@ -106,6 +107,9 @@ class PostProcessing(RunEnvironment): # ols model self.train_ols_model() + # persi model + self.setup_persistence() + # forecasts on test data self.make_prediction(self.test_data) self.make_prediction(self.train_val_data) @@ -715,6 +719,12 @@ class PostProcessing(RunEnvironment): else: logging.info(f"Skip train ols model as it is not present in competitors.") + def setup_persistence(self): + """Check if persistence is requested from competitors and store this information.""" + self.persi_model = any(x in map(str.lower, self.competitors) for x in ["persi", "persistence"]) + if self.persi_model is False: + logging.info(f"Persistence is not calculated as it is not present in competitors.") + @TimeTrackingWrapper def make_prediction(self, subset): """ @@ -748,8 +758,11 @@ class PostProcessing(RunEnvironment): nn_prediction = self._create_nn_forecast(copy.deepcopy(nn_output), nn_prediction, transformation_func, normalised) # persistence - persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, - transformation_func, normalised) + if self.persi_model is True: + persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, + transformation_func, normalised) + else: + persistence_prediction = None # ols if self.ols_model is not None: diff --git a/test/test_run_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py index 6646e1a4795756edd1792ef91f535132e8cde61d..9debec9f04583b401698cea2f79f78b523fe7927 100644 --- a/test/test_run_modules/test_pre_processing.py +++ b/test/test_run_modules/test_pre_processing.py @@ -45,14 +45,16 @@ class TestPreProcessing: with PreProcessing(): assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started (preprocessing)') - assert caplog.record_tuples[-6] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 3 ' + assert caplog.record_tuples[-7] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 3 ' r'station\(s\). Found 3/3 valid stations.')) - assert caplog.record_tuples[-5] == ('root', 20, "use serial create_info_df (train)") - assert caplog.record_tuples[-4] == ('root', 20, "use serial create_info_df (val)") - assert caplog.record_tuples[-3] == ('root', 20, "use serial create_info_df (test)") - assert caplog.record_tuples[-2] == ('root', 20, "Searching for competitors to be prepared for use.") - assert caplog.record_tuples[-1] == ('root', 20, "No preparation required for competitor ols as no specific " + assert caplog.record_tuples[-6] == ('root', 20, "use serial create_info_df (train)") + assert caplog.record_tuples[-5] == ('root', 20, "use serial create_info_df (val)") + assert caplog.record_tuples[-4] == ('root', 20, "use serial create_info_df (test)") + assert caplog.record_tuples[-3] == ('root', 20, "Searching for competitors to be prepared for use.") + assert caplog.record_tuples[-2] == ('root', 20, "No preparation required for competitor ols as no specific " "instruction is provided.") + assert caplog.record_tuples[-1] == ('root', 20, "No preparation required for competitor persi as no " + "specific instruction is provided.") RunEnvironment().__del__() def test_run(self, obj_with_exp_setup):