diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index f3e0496120fa15e958e3d62e68f7d0915bbcf938..d1563b23fd7c4eefc25fdc1debc3e940ded648c5 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -48,6 +48,7 @@ DEFAULT_TEST_END = "2017-12-31" DEFAULT_TEST_MIN_LENGTH = 90 DEFAULT_TRAIN_VAL_MIN_LENGTH = 180 DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS = True +DEFAULT_COMPETITORS = ["ols"] DEFAULT_DO_UNCERTAINTY_ESTIMATE = True DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH = "1m" DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS = True diff --git a/mlair/plotting/abstract_plot_class.py b/mlair/plotting/abstract_plot_class.py index 21e5d9413b490a4be5281c2a80308be558fe64c8..a26023bb6cb8772623479491ac8bcc731dd42223 100644 --- a/mlair/plotting/abstract_plot_class.py +++ b/mlair/plotting/abstract_plot_class.py @@ -72,7 +72,10 @@ class AbstractPlotClass: # pragma: no cover self._update_rc_params() def __del__(self): - plt.close('all') + try: + plt.close('all') + except ImportError: + pass def _plot(self, *args): """Abstract plot class needs to be implemented in inheritance.""" diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 5e7efde64dda069c57e2b7bb63faa8064f65a57d..adef978498b619d744f5b06f9fdbb219c52ee5ec 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -24,7 +24,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, DEFAULT_OVERWRITE_LAZY_DATA, \ DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, \ DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE, DEFAULT_EARLY_STOPPING_EPOCHS, \ - DEFAULT_RESTORE_BEST_MODEL_WEIGHTS + DEFAULT_RESTORE_BEST_MODEL_WEIGHTS, DEFAULT_COMPETITORS from mlair.data_handler import DefaultDataHandler from mlair.run_modules.run_environment import RunEnvironment from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel @@ -404,7 +404,7 @@ class ExperimentSetup(RunEnvironment): raise IndexError(f"Given model_display_name {model_display_name} is also present in the competitors " f"variable {competitors}. To assure a proper workflow it is required to have unique names " f"for each model and competitor. Please use a different model display name or competitor.") - self._set_param("competitors", competitors, default=[]) + self._set_param("competitors", competitors, default=DEFAULT_COMPETITORS) competitor_path_default = os.path.join(self.data_store.get("data_path"), "competitors", "_".join(self.data_store.get("target_var"))) self._set_param("competitor_path", competitor_path, default=competitor_path_default) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 97d1817f5eb884e80c042d56f02dd7a61f88d935..d4a3c4f012c8499e24f50256bd0d77c33c8776d6 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -34,7 +34,7 @@ class PostProcessing(RunEnvironment): Perform post-processing for performance evaluation. Schedule of post-processing: - #. train a ordinary least squared model (ols) for reference + #. train an ordinary least squared model (ols) for reference #. create forecasts for nn, ols, and persistence #. evaluate feature importance with bootstrapped predictions #. calculate skill scores @@ -695,8 +695,12 @@ class PostProcessing(RunEnvironment): @TimeTrackingWrapper def train_ols_model(self): """Train ordinary least squared model on train data.""" - logging.info(f"start train_ols_model on train data") - self.ols_model = OrdinaryLeastSquaredModel(self.train_data) + if "ols" in map(lambda x: x.lower(), self.competitors): + logging.info(f"start train_ols_model on train data") + self.ols_model = OrdinaryLeastSquaredModel(self.train_data) + self.competitors = [e for e in self.competitors if e.lower() != "ols"] + else: + logging.info(f"Skip train ols model as it is not present in competitors.") @TimeTrackingWrapper def make_prediction(self, subset): @@ -733,7 +737,11 @@ class PostProcessing(RunEnvironment): transformation_func, normalised) # ols - ols_prediction = self._create_ols_forecast(input_data, ols_prediction, transformation_func, normalised) + if self.ols_model is not None: + ols_prediction = self._create_ols_forecast(input_data, ols_prediction, transformation_func, + normalised) + else: + ols_prediction = None # observation observation = self._create_observation(target_data, observation, transformation_func, normalised) @@ -817,8 +825,8 @@ class PostProcessing(RunEnvironment): tmp_ols = self.ols_model.predict(input_data) target_shape = ols_prediction.values.shape if target_shape != tmp_ols.shape: - if len(target_shape)==2: - new_values = np.swapaxes(tmp_ols,1,0) + if len(target_shape) == 2: + new_values = np.swapaxes(tmp_ols, 1, 0) else: new_values = np.swapaxes(tmp_ols, 2, 0) else: @@ -922,6 +930,7 @@ class PostProcessing(RunEnvironment): :return: xarray of dimension 3: index, ahead_names, # predictions """ + kwargs = {k: v for k, v in kwargs.items() if v is not None} keys = list(kwargs.keys()) res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan), coords=[index.index, ahead_names, keys], dims=[index_dim, ahead_dim, type_dim]) diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index ff21e27fd7804cde0cbd59b2f90e5ea704f4b0ef..2b3bfa123dda3a07ab572ab34e94b18f38d20fcb 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -380,8 +380,9 @@ class PreProcessing(RunEnvironment): data_collection = self.data_store.get("data_collection", subset) stations.update({str(s): s.get_coordinates() for s in data_collection if s not in stations}) CAMSforecast("CAMS", ref_store_path=path, data_path=data_path).make_reference_available_locally(stations) - - + else: + logging.info(f"No preparation required for competitor {competitor_name} as no specific instruction " + f"is provided.") else: logging.info("No preparation required because no competitor was provided to the workflow.") diff --git a/test/test_run_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py index 4618a5e4f3f5eaf2a419e68a5a0e18156aa7fb0d..743900bb4b0eab96a35e0f263d7525fcf060b597 100644 --- a/test/test_run_modules/test_pre_processing.py +++ b/test/test_run_modules/test_pre_processing.py @@ -52,8 +52,8 @@ class TestPreProcessing: assert caplog.record_tuples[-4] == ('root', 20, "use serial create_info_df (val)") assert caplog.record_tuples[-3] == ('root', 20, "use serial create_info_df (test)") assert caplog.record_tuples[-2] == ('root', 20, "Searching for competitors to be prepared for use.") - assert caplog.record_tuples[-1] == ('root', 20, "No preparation required because no competitor was provided" - " to the workflow.") + assert caplog.record_tuples[-1] == ('root', 20, "No preparation required for competitor ols as no specific " + "instruction is provided.") RunEnvironment().__del__() def test_run(self, obj_with_exp_setup): @@ -71,7 +71,7 @@ class TestPreProcessing: "extreme_values", "extremes_on_right_tail_only", "upsampling"] assert data_store.search_scope("general.train") == sorted(expected_params) assert data_store.search_name("data_collection") == sorted(["general.train", "general.val", "general.test", - "general.train_val"]) + "general.train_val"]) def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup): caplog.set_level(logging.DEBUG)