diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 491aa52e0a9fe0010f77cde315d1f9b7ddb76dfb..cd898e9e03bc60df481011524ead0aa674f0effa 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -81,6 +81,8 @@ class PlotMonthlySummary(AbstractPlotClass): data_nn = data.sel(type=self._model_name).squeeze() if len(data_nn.shape) > 1: data_nn = data_nn.assign_coords(ahead=[f"{days}d" for days in data_nn.coords["ahead"].values]) + else: + data_nn.coords["ahead"].values = str(data_nn.coords["ahead"].values) + "d" data_obs = data.sel(type="obs", ahead=1).squeeze() data_obs.coords["ahead"] = "obs" diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 0d7bfeb4c411eeeb4550bf33e187053ca84cd551..2c31fba97a6ee542cb706c55e7fa11c948f69a40 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -533,7 +533,14 @@ class PostProcessing(RunEnvironment): """ tmp_ols = self.ols_model.predict(input_data) target_shape = ols_prediction.values.shape - ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols + if target_shape != tmp_ols.shape: + if len(target_shape)==2: + new_values = np.swapaxes(tmp_ols,1,0) + else: + new_values = np.swapaxes(tmp_ols, 2, 0) + else: + new_values = tmp_ols + ols_prediction.values = new_values if not normalised: ols_prediction = transformation_func(ols_prediction, "target", inverse=True) return ols_prediction