Commit 90b04251 authored by lukas leufen 👻
Browse files

Added trim method as applied in #384

parent a58a6487
Pipeline #104313 passed with stages
in 12 minutes and 15 seconds
......@@ -261,11 +261,17 @@ class PostProcessing(RunEnvironment):
"""Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped."""
start_data = data.coords[dim].values[0]
freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq))
_ind = pd.date_range(start, end, freq=freq) # two steps required to include all hours of end interval
datetime_index = pd.DataFrame(index=pd.date_range(_ind.min(), _ind.max() + dt.timedelta(days=1), closed="left",
t = data.sel({dim: start_data}, drop=True)
res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
res = res.transpose(*data.dims)
res.loc[data.coords] = data
if data.shape == res.shape:
res.loc[data.coords] = data
_d = data.sel({dim: slice(start, end)})
res.loc[_d.coords] = _d
return res
def load_competitors(self, station_name: str) -> xr.DataArray:
......@@ -761,6 +767,7 @@ class PostProcessing(RunEnvironment):
indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will
raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either
there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file.
Forecast is trimmed on interval start and end of test subset.
:param station_name: name of the station to load data for
:param competitor_name: name of the model
......@@ -769,10 +776,12 @@ class PostProcessing(RunEnvironment):
path = os.path.join(self.competitor_path, competitor_name)
file = os.path.join(path, f"forecasts_{station_name}")
with xr.open_dataarray(file) as da:
data = da.load()
data = da.load()
forecast = data.sel(type=[self.forecast_indicator])
forecast.coords[self.model_type_dim] = [competitor_name]
return forecast
# limit forecast to time range of test subset
start, end = self.data_store.get("start", "test"), self.data_store.get("end", "test")
return self.create_full_time_dim(forecast, self.index_dim, self._sampling, start, end)
def _create_observation(self, data, _, transformation_func: Callable, normalised: bool) -> xr.DataArray:
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment