From 90b0425116996e112c5d3c7907bc61a1ae2f5d73 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 30 Jun 2022 11:23:30 +0200
Subject: [PATCH] added trim method as applied in #384

---
 mlair/run_modules/post_processing.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 00d82f3c..8c5080f2 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -261,11 +261,17 @@ class PostProcessing(RunEnvironment):
         """Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped."""
         start_data = data.coords[dim].values[0]
         freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
-        datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq))
+        _ind = pd.date_range(start, end, freq=freq)  # two steps required to include all hours of end interval
+        datetime_index = pd.DataFrame(index=pd.date_range(_ind.min(), _ind.max() + dt.timedelta(days=1), closed="left",
+                                                          freq=freq))
         t = data.sel({dim: start_data}, drop=True)
         res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
         res = res.transpose(*data.dims)
-        res.loc[data.coords] = data
+        if data.shape == res.shape:
+            res.loc[data.coords] = data
+        else:
+            _d = data.sel({dim: slice(start, end)})
+            res.loc[_d.coords] = _d
         return res
 
     def load_competitors(self, station_name: str) -> xr.DataArray:
@@ -761,6 +767,7 @@ class PostProcessing(RunEnvironment):
         indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will
         raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either
         there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file.
+        Forecast is trimmed on interval start and end of test subset.
 
         :param station_name: name of the station to load data for
         :param competitor_name: name of the model
@@ -769,10 +776,12 @@ class PostProcessing(RunEnvironment):
         path = os.path.join(self.competitor_path, competitor_name)
         file = os.path.join(path, f"forecasts_{station_name}_test.nc")
         with xr.open_dataarray(file) as da:
-            data = da.load()
+            data = da.load() 
         forecast = data.sel(type=[self.forecast_indicator])
         forecast.coords[self.model_type_dim] = [competitor_name]
-        return forecast
+        # limit forecast to time range of test subset
+        start, end = self.data_store.get("start", "test"), self.data_store.get("end", "test")
+        return self.create_full_time_dim(forecast, self.index_dim, self._sampling, start, end)
 
     def _create_observation(self, data, _, transformation_func: Callable, normalised: bool) -> xr.DataArray:
         """
-- 
GitLab