diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index 9bb15068ce3a5ad934f7b0251b84cb19f37702f6..d0b3592905c32a4c7af875014532a5539805dd3a 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -48,7 +48,7 @@ DEFAULT_TEST_END = "2017-12-31" DEFAULT_TEST_MIN_LENGTH = 90 DEFAULT_TRAIN_VAL_MIN_LENGTH = 180 DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS = True -DEFAULT_COMPETITORS = ["ols"] +DEFAULT_COMPETITORS = ["ols", "persi"] DEFAULT_DO_UNCERTAINTY_ESTIMATE = True DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH = "1m" DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS = True diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index d65a200161a7593fe03df5053328aa3f8cd77310..e7d587cf5e87948f51bd2326fe84e495f940d7a8 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -72,6 +72,7 @@ class PostProcessing(RunEnvironment): self.model: AbstractModelClass = self._load_model() self.model_name = self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0] self.ols_model = None + self.persi_model = True self.batch_size: int = self.data_store.get_default("batch_size", "model", 64) self.test_data = self.data_store.get("data_collection", "test") batch_path = self.data_store.get("batch_path", scope="test") @@ -106,6 +107,9 @@ class PostProcessing(RunEnvironment): # ols model self.train_ols_model() + # persi model + self.setup_persistence() + # forecasts on test data self.make_prediction(self.test_data) self.make_prediction(self.train_val_data) @@ -715,6 +719,12 @@ class PostProcessing(RunEnvironment): else: logging.info(f"Skip train ols model as it is not present in competitors.") + def setup_persistence(self): + """Check if persistence is requested from competitors and store this information.""" + self.persi_model = any(x in map(str.lower, self.competitors) for x in ["persi", "persistence"]) + if self.persi_model is False: + logging.info(f"Persistence is not calculated as it is not present in competitors.") + @TimeTrackingWrapper def make_prediction(self, subset): """ @@ -748,8 +758,11 @@ class PostProcessing(RunEnvironment): nn_prediction = self._create_nn_forecast(copy.deepcopy(nn_output), nn_prediction, transformation_func, normalised) # persistence - persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, - transformation_func, normalised) + if self.persi_model is True: + persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, + transformation_func, normalised) + else: + persistence_prediction = None # ols if self.ols_model is not None: