From 90b0425116996e112c5d3c7907bc61a1ae2f5d73 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Thu, 30 Jun 2022 11:23:30 +0200 Subject: [PATCH] added trimm method as applied in #384 --- mlair/run_modules/post_processing.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 00d82f3c..8c5080f2 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -261,11 +261,17 @@ class PostProcessing(RunEnvironment): """Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped.""" start_data = data.coords[dim].values[0] freq = {"daily": "1D", "hourly": "1H"}.get(sampling) - datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq)) + _ind = pd.date_range(start, end, freq=freq) # two steps required to include all hours of end interval + datetime_index = pd.DataFrame(index=pd.date_range(_ind.min(), _ind.max() + dt.timedelta(days=1), closed="left", + freq=freq)) t = data.sel({dim: start_data}, drop=True) res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords]) res = res.transpose(*data.dims) - res.loc[data.coords] = data + if data.shape == res.shape: + res.loc[data.coords] = data + else: + _d = data.sel({dim: slice(start, end)}) + res.loc[_d.coords] = _d return res def load_competitors(self, station_name: str) -> xr.DataArray: @@ -761,6 +767,7 @@ class PostProcessing(RunEnvironment): indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file. + Forecast is trimmed on interval start and end of test subset. :param station_name: name of the station to load data for :param competitor_name: name of the model @@ -769,10 +776,12 @@ class PostProcessing(RunEnvironment): path = os.path.join(self.competitor_path, competitor_name) file = os.path.join(path, f"forecasts_{station_name}_test.nc") with xr.open_dataarray(file) as da: - data = da.load() + data = da.load() forecast = data.sel(type=[self.forecast_indicator]) forecast.coords[self.model_type_dim] = [competitor_name] - return forecast + # limit forecast to time range of test subset + start, end = self.data_store.get("start", "test"), self.data_store.get("end", "test") + return self.create_full_time_dim(forecast, self.index_dim, self._sampling, start, end) def _create_observation(self, data, _, transformation_func: Callable, normalised: bool) -> xr.DataArray: """ -- GitLab