Skip to content
Snippets Groups Projects
Commit cb5ebf89 authored by leufen1's avatar leufen1
Browse files

can now skip persi reference if not provided in competitor list

parent cb3248d6
Branches
Tags
3 merge requests!500Develop,!499Resolve "release v2.3.0",!486Resolve "enable persi only if requested"
Pipeline #112147 failed
...@@ -48,7 +48,7 @@ DEFAULT_TEST_END = "2017-12-31" ...@@ -48,7 +48,7 @@ DEFAULT_TEST_END = "2017-12-31"
DEFAULT_TEST_MIN_LENGTH = 90 DEFAULT_TEST_MIN_LENGTH = 90
DEFAULT_TRAIN_VAL_MIN_LENGTH = 180 DEFAULT_TRAIN_VAL_MIN_LENGTH = 180
DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS = True DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS = True
DEFAULT_COMPETITORS = ["ols"] DEFAULT_COMPETITORS = ["ols", "persi"]
DEFAULT_DO_UNCERTAINTY_ESTIMATE = True DEFAULT_DO_UNCERTAINTY_ESTIMATE = True
DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH = "1m" DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH = "1m"
DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS = True DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS = True
......
...@@ -72,6 +72,7 @@ class PostProcessing(RunEnvironment): ...@@ -72,6 +72,7 @@ class PostProcessing(RunEnvironment):
self.model: AbstractModelClass = self._load_model() self.model: AbstractModelClass = self._load_model()
self.model_name = self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0] self.model_name = self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0]
self.ols_model = None self.ols_model = None
self.persi_model = True
self.batch_size: int = self.data_store.get_default("batch_size", "model", 64) self.batch_size: int = self.data_store.get_default("batch_size", "model", 64)
self.test_data = self.data_store.get("data_collection", "test") self.test_data = self.data_store.get("data_collection", "test")
batch_path = self.data_store.get("batch_path", scope="test") batch_path = self.data_store.get("batch_path", scope="test")
...@@ -106,6 +107,9 @@ class PostProcessing(RunEnvironment): ...@@ -106,6 +107,9 @@ class PostProcessing(RunEnvironment):
# ols model # ols model
self.train_ols_model() self.train_ols_model()
# persi model
self.setup_persistence()
# forecasts on test data # forecasts on test data
self.make_prediction(self.test_data) self.make_prediction(self.test_data)
self.make_prediction(self.train_val_data) self.make_prediction(self.train_val_data)
...@@ -715,6 +719,12 @@ class PostProcessing(RunEnvironment): ...@@ -715,6 +719,12 @@ class PostProcessing(RunEnvironment):
else: else:
logging.info(f"Skip train ols model as it is not present in competitors.") logging.info(f"Skip train ols model as it is not present in competitors.")
def setup_persistence(self):
"""Check if persistence is requested from competitors and store this information."""
self.persi_model = any(x in map(str.lower, self.competitors) for x in ["persi", "persistence"])
if self.persi_model is False:
logging.info(f"Persistence is not calculated as it is not present in competitors.")
@TimeTrackingWrapper @TimeTrackingWrapper
def make_prediction(self, subset): def make_prediction(self, subset):
""" """
...@@ -748,8 +758,11 @@ class PostProcessing(RunEnvironment): ...@@ -748,8 +758,11 @@ class PostProcessing(RunEnvironment):
nn_prediction = self._create_nn_forecast(copy.deepcopy(nn_output), nn_prediction, transformation_func, normalised) nn_prediction = self._create_nn_forecast(copy.deepcopy(nn_output), nn_prediction, transformation_func, normalised)
# persistence # persistence
if self.persi_model is True:
persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction,
transformation_func, normalised) transformation_func, normalised)
else:
persistence_prediction = None
# ols # ols
if self.ols_model is not None: if self.ols_model is not None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment