diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index f7fa12f541d6b5a09ccc3d9ca8addec813468c78..d4b23f347bde6153a5a7fd6850db96067cbe6cfb 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -15,7 +15,7 @@ import xarray as xr
 
 from mlair.data_handler import BootStraps, KerasIterator
 from mlair.helpers.datastore import NameNotFoundInDataStore
-from mlair.helpers import TimeTracking, statistics, extract_value
+from mlair.helpers import TimeTracking, statistics, extract_value, remove_items
 from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
 from mlair.model_modules.model_class import AbstractModelClass
 from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
@@ -85,40 +85,61 @@ class PostProcessing(RunEnvironment):
             self.target_var  # ToDo: make sure this is a string, multiple vars are joined by underscore
         )
-        self.competitor_name = "test_model"
-        self.competitor_forecast_name = "CNN"  # ToDo: another refac, rename the CNN field to something like forecast to be more general
+        self.competitors = ["test_model", "test_model2"]  # ToDo: this shouldn't be hardcoded
+        self.competitor_forecast_name = "cnn"  # ToDo: another refac, rename the cnn field to something like forecast to be more general
+        # ToDo: this must be applied to all predictions. Maybe there should be a model name parameter in the data store
+        # that can be requested. I would also argue that forecast_name and the string in competitors should be the same
+        # name.
         self._run()
 
     def _run(self):
         # ols model
-        # self.train_ols_model()  # ToDo: remove for final commit
+        self.train_ols_model()  # ToDo: remove comment for final commit
 
         # forecasts
-        # self.make_prediction()  # ToDo: remove for final commit
+        self.make_prediction()  # ToDo: remove comment for final commit
 
         # competitors
-        # self.load_competitors()  # ToDo: remove for final commit
+        self.load_competitors("DEBW107")  # ToDo: remove this line for final commit, it is not required to load the competitors here
 
-        # skill scores on test data
-        # self.calculate_test_score()  # ToDo: remove for final commit
+        # calculate error metrics on test data
+        # self.calculate_test_score()  # ToDo: remove comment for final commit
 
         # bootstraps
-        if self.data_store.get("evaluate_bootstraps", "postprocessing"):
-            with TimeTracking(name="calculate bootstraps"):
-                create_new_bootstraps = self.data_store.get("create_new_bootstraps", "postprocessing")
-                self.bootstrap_postprocessing(create_new_bootstraps)
+        # ToDo: remove comment for final commit
+        # if self.data_store.get("evaluate_bootstraps", "postprocessing"):
+        #     with TimeTracking(name="calculate bootstraps"):
+        #         create_new_bootstraps = self.data_store.get("create_new_bootstraps", "postprocessing")
+        #         self.bootstrap_postprocessing(create_new_bootstraps)
 
         # skill scores  # ToDo: remove for final commit
-        # with TimeTracking(name="calculate skill scores"):
-        #     self.skill_scores = self.calculate_skill_scores()
+        with TimeTracking(name="calculate skill scores"):
+            self.skill_scores = self.calculate_skill_scores()
 
         # plotting
         self.plot()
 
-    def load_competitors(self):
-        for station in self.test_data:
-            competing_prediction = self._create_competitor_forecast(str(station))
+    def load_competitors(self, station_name: str) -> xr.DataArray:
+        """
+        Load all requested and available competitors for a given station. Forecasts must be available in the
+        competitor path like `<competitor_path>/<competitor_name>/forecasts_<station_name>_test.nc`. The naming style
+        is the same for all MLAir forecasts, so forecasts from a different experiment can be copied into the
+        competitor path without any change.
+
+        :param station_name: station indicator to load competitors for
+
+        :return: a single xarray with all competing forecasts
+        """
+        competing_predictions = []
+        for competitor_name in self.competitors:
+            try:
+                prediction = self._create_competitor_forecast(station_name, competitor_name)
+                competing_predictions.append(prediction)
+            except FileNotFoundError:
+                continue
+        return xr.concat(competing_predictions, "type") if len(competing_predictions) > 0 else None
 
     def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0) -> None:
         """
@@ -235,7 +256,7 @@ class PostProcessing(RunEnvironment):
         return score
 
     @staticmethod
-    def get_orig_prediction(path, file_name, number_of_bootstraps, prediction_name="CNN"):
+    def get_orig_prediction(path, file_name, number_of_bootstraps, prediction_name="cnn"):
         file = os.path.join(path, file_name)
         prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
         vals = np.tile(prediction.data, (number_of_bootstraps, 1))
@@ -290,7 +311,7 @@ class PostProcessing(RunEnvironment):
 
         try:
             if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list):
-                PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="CNN")
+                PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="cnn")
         except Exception as e:
             logging.error(f"Could not create plot PlotBootstrapSkillScore due to the following error: {e}")
 
@@ -326,15 +347,15 @@ class PostProcessing(RunEnvironment):
 
         try:
             if "PlotClimatologicalSkillScore" in plot_list:
-                PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN")
+                PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="cnn")
                 PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
-                                             extra_name_tag="all_terms_", model_setup="CNN")
+                                             extra_name_tag="all_terms_", model_setup="cnn")
         except Exception as e:
             logging.error(f"Could not create plot PlotClimatologicalSkillScore due to the following error: {e}")
 
         try:
             if "PlotCompetitiveSkillScore" in plot_list:
-                PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN")
+                PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="cnn")
         except Exception as e:
             logging.error(f"Could not create plot PlotCompetitiveSkillScore due to the following error: {e}")
 
@@ -415,10 +436,10 @@ class PostProcessing(RunEnvironment):
 
             full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency())
             all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes['window']),
                                                           time_dimension,
-                                                          CNN=nn_prediction,
+                                                          cnn=nn_prediction,
                                                           persi=persistence_prediction,
                                                           obs=observation,
-                                                          OLS=ols_prediction)
+                                                          ols=ols_prediction)
 
             # save all forecasts locally
             path = self.data_store.get("forecast_path")
@@ -431,12 +452,22 @@ class PostProcessing(RunEnvironment):
         getter = {"daily": "1D", "hourly": "1H"}
         return getter.get(self._sampling, None)
 
-    def _create_competitor_forecast(self, station_name):
-        path = os.path.join(self.competitor_path, self.competitor_name)
+    def _create_competitor_forecast(self, station_name: str, competitor_name: str) -> xr.DataArray:
+        """
+        Load and format the competing forecast of the model indicated by `competitor_name` for the station indicated
+        by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method raises a
+        `FileNotFoundError` if no competitor forecast could be found for the given station.
+
+        :param station_name: name of the station to load data for
+        :param competitor_name: name of the model
+        :return: the forecast of the given competitor
+        """
+        path = os.path.join(self.competitor_path, competitor_name)
         file = os.path.join(path, f"forecasts_{station_name}_test.nc")
         data = xr.open_dataarray(file)
+        # data = data.expand_dims(Stations=[station_name])
         forecast = data.sel(type=[self.competitor_forecast_name])
-        forecast.coords["type"] = ["competitor"]
+        forecast.coords["type"] = [competitor_name]
         return forecast
 
     @staticmethod
@@ -584,7 +615,8 @@ class PostProcessing(RunEnvironment):
         res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
                            coords=[index.index, ahead_names, keys], dims=['index', 'ahead', 'type'])
         for k, v in kwargs.items():
-            match_index = np.stack(set(res.index.values) & set(v.indexes[time_dimension].values))
+            intersection = set(res.index.values) & set(v.indexes[time_dimension].values)
+            match_index = np.array(list(intersection))
             res.loc[match_index, :, k] = v.loc[match_index]
         return res
 
@@ -623,12 +655,13 @@ class PostProcessing(RunEnvironment):
         skill_score_competitive = {}
         skill_score_climatological = {}
         for station in self.test_data:
+            logging.info(str(station))
             file = os.path.join(path, f"forecasts_{str(station)}_test.nc")
             data = xr.open_dataarray(file)
-            competitor = self._create_competitor_forecast(str(station))
-            combined = xr.concat([data, competitor], dim="type")
-            skill_score = statistics.SkillScores(combined)
-            external_data = self._get_external_data(station)
+            competitor = self.load_competitors(str(station))
+            combined = xr.concat([data, competitor], dim="type") if competitor is not None else data
+            skill_score = statistics.SkillScores(combined, models=remove_items(list(combined.type.values), "obs"))
+            external_data = self._get_external_data(station)  # ToDo: check if external is still right?
             skill_score_competitive[station] = skill_score.skill_scores(self.window_lead_time)
             skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data,
                                                                                           self.window_lead_time)
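
Reviewer note: a minimal sketch of the competitor lookup this diff introduces, to make the assumed file layout explicit. The base directory is illustrative; `test_model`, the `cnn` forecast label, and station `DEBW107` are the values hardcoded in the diff above.

    import os
    import xarray as xr

    competitor_path = "/path/to/competitors/o3"  # assumed; in MLAir this comes from the experiment setup
    competitor_name = "test_model"               # one entry of self.competitors
    station_name = "DEBW107"                     # hardcoded test station from this diff

    # expected layout: <competitor_path>/<competitor_name>/forecasts_<station_name>_test.nc
    file = os.path.join(competitor_path, competitor_name, f"forecasts_{station_name}_test.nc")
    try:
        data = xr.open_dataarray(file)
        forecast = data.sel(type=["cnn"])            # select the model forecast from the "type" axis
        forecast.coords["type"] = [competitor_name]  # relabel so each competitor is distinguishable
    except FileNotFoundError:
        forecast = None  # missing competitors are skipped, as in load_competitors

Because load_competitors concatenates these arrays along `type`, a forecast file copied from another MLAir experiment should show up in PlotCompetitiveSkillScore without further code changes.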