Commit af82c5fd authored by lukas leufen 👻
Browse files

apply some changes from #384 here

parent cad70f04
Pipeline #105198 passed with stages
in 12 minutes and 8 seconds
......@@ -419,10 +419,13 @@ class SkillScores:
skill_score.loc[["CASE III", "AIII"], iahead] = np.stack(self._climatological_skill_score(
data, mu_type=3, forecast_name=forecast_name, observation_name=self.observation_name,
skill_score.loc[["CASE IV", "AIV", "BIV", "CIV"], iahead] = np.stack(self._climatological_skill_score(
data, mu_type=4, forecast_name=forecast_name, observation_name=self.observation_name,
skill_score.loc[["CASE IV", "AIV", "BIV", "CIV"], iahead] = np.stack(
self._climatological_skill_score(data, mu_type=4, forecast_name=forecast_name,
except ValueError:
return skill_score
......@@ -13,6 +13,7 @@ from typing import Dict, Tuple, Union, List, Callable
import numpy as np
import pandas as pd
import xarray as xr
import datetime as dt
from mlair.configuration import path_config
from mlair.data_handler import Bootstraps, KerasIterator
......@@ -261,11 +262,17 @@ class PostProcessing(RunEnvironment):
"""Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped."""
start_data = data.coords[dim].values[0]
freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq))
_ind = pd.date_range(start, end, freq=freq) # two steps required to include all hours of end interval
datetime_index = pd.DataFrame(index=pd.date_range(_ind.min(), _ind.max() + dt.timedelta(days=1),
closed="left", freq=freq))
t = data.sel({dim: start_data}, drop=True)
res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
res = res.transpose(*data.dims)
res.loc[data.coords] = data
if data.shape == res.shape:
res.loc[data.coords] = data
_d = data.sel({dim: slice(start, end)})
res.loc[_d.coords] = _d
return res
def load_competitors(self, station_name: str) -> xr.DataArray:
......@@ -761,6 +768,7 @@ class PostProcessing(RunEnvironment):
indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will
raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either
there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file.
The forecast is trimmed to the interval between the start and end of the test subset.
:param station_name: name of the station to load data for
:param competitor_name: name of the model
......@@ -775,7 +783,9 @@ class PostProcessing(RunEnvironment):
forecast.coords[self.model_type_dim] = [competitor_name]
forecast = data.sel({self.model_type_dim: [competitor_name]})
return forecast
# limit forecast to time range of test subset
start, end = self.data_store.get("start", "test"), self.data_store.get("end", "test")
return self.create_full_time_dim(forecast, self.index_dim, self._sampling, start, end)
def _create_observation(self, data, _, transformation_func: Callable, normalised: bool) -> xr.DataArray:
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment