diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 827a53ce869af7d7935f375c6c5007624f40cc66..7a31f83fdf09ceed649aa02d4eecb74a9165eba0 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -79,6 +79,7 @@ class PostProcessing(RunEnvironment): self.train_data = self.data_store.get("data_collection", "train") self.val_data = self.data_store.get("data_collection", "val") self.train_val_data = self.data_store.get("data_collection", "train_val") + self.forecast_path = self.data_store.get("forecast_path") self.plot_path: str = self.data_store.get("plot_path") self.target_var = self.data_store.get("target_var") self._sampling = self.data_store.get("sampling") @@ -202,7 +203,6 @@ class PostProcessing(RunEnvironment): station or actual data contained. This is intended to analyze not only the robustness against the time but also against the number of observations and diversity ot stations. """ - path = self.data_store.get("forecast_path") all_stations = self.data_store.get("stations", "test") start = self.data_store.get("start", "test") end = self.data_store.get("end", "test") @@ -211,7 +211,7 @@ class PostProcessing(RunEnvironment): collector = [] for station in all_stations: # test data - external_data = self._get_external_data(station, path) + external_data = self._get_external_data(station, self.forecast_path) if external_data is not None: pass # competitors @@ -335,7 +335,6 @@ class PostProcessing(RunEnvironment): # forecast with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"): # extract all requirements from data store - forecast_path = self.data_store.get("forecast_path") number_of_bootstraps = self.data_store.get("n_boots", "feature_importance") dims = [self.uncertainty_estimate_boot_dim, self.index_dim, self.ahead_dim, self.model_type_dim] for station in self.test_data: @@ -355,13 +354,13 @@ class PostProcessing(RunEnvironment): coords = (range(number_of_bootstraps), range(shape[0]), range(1, shape[1] + 1)) var = f"{index}_{dimension}" if index is not None else str(dimension) tmp = xr.DataArray(bootstrap_predictions, coords=(*coords, [var]), dims=dims) - file_name = os.path.join(forecast_path, + file_name = os.path.join(self.forecast_path, f"bootstraps_{station}_{var}_{bootstrap_type}_{bootstrap_method}.nc") tmp.to_netcdf(file_name) else: # store also true labels for each station labels = np.expand_dims(Y[..., 0], axis=-1) - file_name = os.path.join(forecast_path, f"bootstraps_{station}_{bootstrap_method}_labels.nc") + file_name = os.path.join(self.forecast_path, f"bootstraps_{station}_{bootstrap_method}_labels.nc") labels = xr.DataArray(labels, coords=(*coords[1:], [self.observation_indicator]), dims=dims[1:]) labels.to_netcdf(file_name) @@ -377,7 +376,6 @@ class PostProcessing(RunEnvironment): """ with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"): # extract all requirements from data store - forecast_path = self.data_store.get("forecast_path") number_of_bootstraps = self.data_store.get("n_boots", "feature_importance") forecast_file = f"forecasts_norm_%s_test.nc" reference_name = "orig" @@ -393,19 +391,20 @@ class PostProcessing(RunEnvironment): score = {} for station in self.test_data: # get station labels - file_name = os.path.join(forecast_path, f"bootstraps_{str(station)}_{bootstrap_method}_labels.nc") + file_name = os.path.join(self.forecast_path, f"bootstraps_{str(station)}_{bootstrap_method}_labels.nc") with xr.open_dataarray(file_name) as da: labels = da.load() # get original forecasts - orig = self.get_orig_prediction(forecast_path, forecast_file % str(station), reference_name=reference_name) + orig = self.get_orig_prediction(self.forecast_path, forecast_file % str(station), + reference_name=reference_name) orig.coords[self.index_dim] = labels.coords[self.index_dim] # calculate skill scores for each variable skill = [] for boot_set in bootstrap_iter: boot_var = f"{boot_set[0]}_{boot_set[1]}" if isinstance(boot_set, tuple) else str(boot_set) - file_name = os.path.join(forecast_path, + file_name = os.path.join(self.forecast_path, f"bootstraps_{station}_{boot_var}_{bootstrap_type}_{bootstrap_method}.nc") with xr.open_dataarray(file_name) as da: boot_data = da.load() @@ -508,7 +507,6 @@ class PostProcessing(RunEnvironment): """ logging.info("Run plotting routines...") - path = self.data_store.get("forecast_path") use_multiprocessing = self.data_store.get("use_multiprocessing") plot_list = self.data_store.get("plot_list", "postprocessing") @@ -547,8 +545,8 @@ class PostProcessing(RunEnvironment): try: if "PlotConditionalQuantiles" in plot_list: - PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=path, plot_folder=self.plot_path, - forecast_indicator=self.forecast_indicator, + PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=self.forecast_path, + plot_folder=self.plot_path, forecast_indicator=self.forecast_indicator, obs_indicator=self.observation_indicator) except Exception as e: logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error:" @@ -556,7 +554,7 @@ class PostProcessing(RunEnvironment): try: if "PlotMonthlySummary" in plot_list: - PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var, + PlotMonthlySummary(self.test_data.keys(), self.forecast_path, r"forecasts_%s_test.nc", self.target_var, plot_folder=self.plot_path) except Exception as e: logging.error(f"Could not create plot PlotMonthlySummary due to the following error:" @@ -582,8 +580,8 @@ class PostProcessing(RunEnvironment): try: if "PlotTimeSeries" in plot_list: - PlotTimeSeries(self.test_data.keys(), path, r"forecasts_%s_test.nc", plot_folder=self.plot_path, - sampling=self._sampling, ahead_dim=self.ahead_dim) + PlotTimeSeries(self.test_data.keys(), self.forecast_path, r"forecasts_%s_test.nc", + plot_folder=self.plot_path, sampling=self._sampling, ahead_dim=self.ahead_dim) except Exception as e: logging.error(f"Could not create plot PlotTimeSeries due to the following error:\n{sys.exc_info()[0]}\n" f"{sys.exc_info()[1]}\n{sys.exc_info()[2]}\n{traceback.format_exc()}") @@ -687,7 +685,6 @@ class PostProcessing(RunEnvironment): logging.info(f"start make_prediction for {subset_type}") time_dimension = self.data_store.get("time_dim") window_dim = self.data_store.get("window_dim") - path = self.data_store.get("forecast_path") subset_type = subset.name for i, data in enumerate(subset): input_data = data.get_X() @@ -728,7 +725,7 @@ class PostProcessing(RunEnvironment): # save all forecasts locally prefix = "forecasts_norm" if normalised is True else "forecasts" - file = os.path.join(path, f"{prefix}_{str(data)}_{subset_type}.nc") + file = os.path.join(self.forecast_path, f"{prefix}_{str(data)}_{subset_type}.nc") all_predictions.to_netcdf(file) def _get_frequency(self) -> str: @@ -956,14 +953,13 @@ class PostProcessing(RunEnvironment): :return: competitive and climatological skill scores, error metrics """ - path = self.data_store.get("forecast_path") all_stations = self.data_store.get("stations") skill_score_competitive = {} skill_score_competitive_count = {} skill_score_climatological = {} errors = {} for station in all_stations: - external_data = self._get_external_data(station, path) # test data + external_data = self._get_external_data(station, self.forecast_path) # test data # test errors if external_data is not None: @@ -1002,7 +998,7 @@ class PostProcessing(RunEnvironment): if external_data is not None: skill_score_competitive[station], skill_score_competitive_count[station] = skill_score.skill_scores() - internal_data = self._get_internal_data(station, path) + internal_data = self._get_internal_data(station, self.forecast_path) if internal_data is not None: skill_score_climatological[station] = skill_score.climatological_skill_scores( internal_data, forecast_name=self.forecast_indicator)