post_processing.py
    """Post-processing module."""
    
    __author__ = "Lukas Leufen, Felix Kleinert"
    __date__ = '2019-12-11'
    
    import inspect
    import logging
    import os
    from typing import Dict, Tuple, Union, List, Callable
    
    import keras
    import numpy as np
    import pandas as pd
    import xarray as xr
    
    from mlair.configuration import path_config
    from mlair.data_handler import BootStraps, KerasIterator
    from mlair.helpers.datastore import NameNotFoundInDataStore
    from mlair.helpers import TimeTracking, statistics, extract_value, remove_items, to_list, tables
    from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
    from mlair.model_modules import AbstractModelClass
    from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
        PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotAvailabilityHistogram, \
        PlotConditionalQuantiles, PlotSeparationOfScales
    from mlair.run_modules.run_environment import RunEnvironment
    
    
    class PostProcessing(RunEnvironment):
        """
        Perform post-processing for performance evaluation.
    
        Schedule of post-processing:
            #. train an ordinary least squares model (OLS) for reference
            #. create forecasts for nn, ols, and persistence
            #. evaluate feature importance with bootstrapped predictions
            #. calculate skill scores
            #. create plots
    
        Required objects [scope] from data store:
            * `best_model` [.] or locally saved model plus `model_name` [model] and `model` [model]
            * `generator` [train, val, test, train_val]
            * `forecast_path` [.]
            * `plot_path` [postprocessing]
            * `model_path` [.]
            * `target_var` [.]
            * `sampling` [.]
            * `output_shape` [model]
            * `evaluate_bootstraps` [postprocessing] and if enabled:
    
                * `create_new_bootstraps` [postprocessing]
                * `bootstrap_path` [postprocessing]
                * `number_of_bootstraps` [postprocessing]
    
        Optional objects
            * `batch_size` [model]
    
        Creates
            * forecasts in `forecast_path` if enabled
            * bootstraps in `bootstrap_path` if enabled
            * plots in `plot_path`
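
        Example (minimal sketch, assuming the usual MLAir run-module workflow in which preceding run modules have
        already filled the data store)::

            from mlair.run_modules.run_environment import RunEnvironment
            from mlair.run_modules.post_processing import PostProcessing

            with RunEnvironment():
                # ... setup, preprocessing, and training run modules run here ...
                PostProcessing()  # runs the full post-processing schedule on initialisation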
    
        """
    
        def __init__(self):
            """Initialise and run post-processing."""
            super().__init__()
            self.model: keras.Model = self._load_model()
            self.model_name = self._get_model_name()
            self.ols_model = None
            self.batch_size: int = self.data_store.get_default("batch_size", "model", 64)
            self.test_data = self.data_store.get("data_collection", "test")
            batch_path = self.data_store.get("batch_path", scope="test")
            self.test_data_distributed = KerasIterator(self.test_data, self.batch_size, model=self.model, name="test",
                                                       batch_path=batch_path)
            self.train_data = self.data_store.get("data_collection", "train")
            self.val_data = self.data_store.get("data_collection", "val")
            self.train_val_data = self.data_store.get("data_collection", "train_val")
            self.plot_path: str = self.data_store.get("plot_path")
            self.target_var = self.data_store.get("target_var")
            self._sampling = self.data_store.get("sampling")
            self.window_lead_time = extract_value(self.data_store.get("output_shape", "model"))
            self.skill_scores = None
            self.bootstrap_skill_scores = None
            self.competitor_path = self.data_store.get("competitor_path")
            self.competitors = to_list(self.data_store.get_default("competitors", default=[]))
            self.forecast_indicator = "nn"
            self._run()
    
        def _run(self):
            # ols model
            self.train_ols_model()
    
            # forecasts on test and train_val data
            self.make_prediction(self.test_data)
            self.make_prediction(self.train_val_data)
    
            # calculate error metrics on test data
            self.calculate_test_score()
    
            # bootstraps
            if self.data_store.get("evaluate_bootstraps", "postprocessing"):
                with TimeTracking(name="calculate bootstraps"):
                    create_new_bootstraps = self.data_store.get("create_new_bootstraps", "postprocessing")
                    self.bootstrap_postprocessing(create_new_bootstraps)
    
            # skill scores and error metrics
            with TimeTracking(name="calculate skill scores"):
                skill_score_competitive, skill_score_climatological, errors = self.calculate_error_metrics()
                self.skill_scores = (skill_score_competitive, skill_score_climatological)
            self.report_error_metrics(errors)
    
            # plotting
            self.plot()
    
        def load_competitors(self, station_name: str) -> Union[xr.DataArray, None]:
            """
            Load all requested and available competitors for a given station. Forecasts must be available in the competitor
            path like `<competitor_path>/<competitor_name>/forecasts_<station_name>_test.nc`. The naming style is the same
            for all forecasts of MLAir, so that forecasts of a different experiment can easily be copied into the
            competitor path without any change.
    
            :param station_name: station indicator to load competitors for
    
            :return: a single xarray with all competing forecasts, or None if no competitor could be loaded
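
            Example layout of the competitor path (station and competitor names are illustrative)::

                <competitor_path>/
                    referenceModel/forecasts_DEBW107_test.nc
                    otherExperiment/forecasts_DEBW107_test.nc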
            """
            competing_predictions = []
            for competitor_name in self.competitors:
                try:
                    prediction = self._create_competitor_forecast(station_name, competitor_name)
                    competing_predictions.append(prediction)
                except (FileNotFoundError, KeyError):
                    logging.debug(f"No competitor found for combination '{station_name}' and '{competitor_name}'.")
                    continue
            return xr.concat(competing_predictions, "type") if len(competing_predictions) > 0 else None
    
        def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0) -> None:
            """
            Calculate skill scores of bootstrapped data.
    
            Create bootstrapped data if `create_new_bootstraps` is true or if a failure occurs during the skill score
            calculation (the latter happens by default if no bootstrapped data is available locally). Set the class
            attribute bootstrap_skill_scores. This method is implemented in a recursive fashion, but is only allowed to
            call itself once.
    
            :param create_new_bootstraps: calculate all bootstrap predictions and overwrite already available predictions
            :param _iter: internal counter to reduce unnecessary recursive calls (maximum number is 2, otherwise something
                went wrong).
            """
            try:
                if create_new_bootstraps:
                    self.create_bootstrap_forecast()
                self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
            except FileNotFoundError:
                if _iter != 0:
                    raise RuntimeError("bootstrap_postprocessing has been called for the 2nd time. This means that the "
                                       "bootstrap creation itself failed; check the reason for the failure manually.")
                logging.info("Couldn't load all files, restart bootstrap postprocessing with create_new_bootstraps=True.")
                self.bootstrap_postprocessing(True, _iter=1)
    
        def create_bootstrap_forecast(self) -> None:
            """
            Create bootstrapped predictions for all stations and variables.
    
            These forecasts are saved in `forecast_path` under the names `bootstraps_{station}_{var}.nc` and
            `bootstraps_{station}_labels.nc`.
            """
            # forecast
            with TimeTracking(name=inspect.stack()[0].function):
                # extract all requirements from data store
                bootstrap_path = self.data_store.get("bootstrap_path")
                forecast_path = self.data_store.get("forecast_path")
                number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
                dims = ["index", "ahead", "type"]
                for station in self.test_data:
                    logging.info(str(station))
                    X, Y = None, None
                    bootstraps = BootStraps(station, number_of_bootstraps)
                    for boot in bootstraps:
                        X, Y, (index, dimension) = boot
                        # make bootstrap predictions
                        bootstrap_predictions = self.model.predict(X)
                        if isinstance(bootstrap_predictions, list):  # if model is branched model
                            bootstrap_predictions = bootstrap_predictions[-1]
                        # save bootstrap predictions separately for each station and variable combination
                        bootstrap_predictions = np.expand_dims(bootstrap_predictions, axis=-1)
                        shape = bootstrap_predictions.shape
                        coords = (range(shape[0]), range(1, shape[1] + 1))
                        var = f"{index}_{dimension}"
                        tmp = xr.DataArray(bootstrap_predictions, coords=(*coords, [var]), dims=dims)
                        file_name = os.path.join(forecast_path, f"bootstraps_{station}_{var}.nc")
                        tmp.to_netcdf(file_name)
                    else:  # for-else: executed once after all bootstrap variables have been processed
                        # store also true labels for each station
                        labels = np.expand_dims(Y, axis=-1)
                        file_name = os.path.join(forecast_path, f"bootstraps_{station}_labels.nc")
                        labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=dims)
                        labels.to_netcdf(file_name)
    
        def calculate_bootstrap_skill_scores(self) -> Dict[str, xr.DataArray]:
            """
            Calculate skill score of bootstrapped variables.
    
            Use already created bootstrap predictions and the original predictions (the not-bootstrapped ones) and calculate
            skill scores for the bootstraps. The result is saved as a xarray DataArray in a dictionary structure separated
            for each station (keys of dictionary).
    
            :return: The result dictionary with station-wise skill scores
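
            Example access to the result (sketch; station id and variable name are illustrative)::

                scores = self.calculate_bootstrap_skill_scores()
                scores["DEBW107"].sel(boot_var="o3", ahead=1)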
            """
            with TimeTracking(name=inspect.stack()[0].function):
                # extract all requirements from data store
                bootstrap_path = self.data_store.get("bootstrap_path")
                forecast_path = self.data_store.get("forecast_path")
                number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
                forecast_file = "forecasts_norm_%s_test.nc"  # %s is replaced by the station name
                bootstraps = BootStraps(self.test_data[0], number_of_bootstraps).bootstraps()
                skill_scores = statistics.SkillScores(None)
                score = {}
                for station in self.test_data:
                    logging.info(station)
    
                    # get station labels
                    file_name = os.path.join(forecast_path, f"bootstraps_{str(station)}_labels.nc")
                    labels = xr.open_dataarray(file_name)
                    shape = labels.shape
    
                    # get original forecasts
                    orig = self.get_orig_prediction(forecast_path, forecast_file % str(station), number_of_bootstraps)
                    orig = orig.reshape(shape)
                    coords = (range(shape[0]), range(1, shape[1] + 1), ["orig"])
                    orig = xr.DataArray(orig, coords=coords, dims=["index", "ahead", "type"])
    
                    # calculate skill scores for each variable
                    skill = pd.DataFrame(columns=range(1, self.window_lead_time + 1))
                    for boot_set in bootstraps:
                        boot_var = f"{boot_set[0]}_{boot_set[1]}"
                        file_name = os.path.join(forecast_path, f"bootstraps_{station}_{boot_var}.nc")
                        boot_data = xr.open_dataarray(file_name)
                        boot_data = boot_data.combine_first(labels).combine_first(orig)
                        boot_scores = []
                        for ahead in range(1, self.window_lead_time + 1):
                            data = boot_data.sel(ahead=ahead)
                            boot_scores.append(
                                skill_scores.general_skill_score(data, forecast_name=boot_var, reference_name="orig"))
                        skill.loc[boot_var] = np.array(boot_scores)
    
                    # collect all results in single dictionary
                    score[str(station)] = xr.DataArray(skill, dims=["boot_var", "ahead"])
                return score
    
        def get_orig_prediction(self, path, file_name, number_of_bootstraps, prediction_name=None):
            """
            Load the original (not bootstrapped) prediction and tile it to match the number of bootstraps.

            Rows containing NaN values are dropped so that the tiled prediction aligns with the bootstrap samples.
            """
            if prediction_name is None:
                prediction_name = self.forecast_indicator
            file = os.path.join(path, file_name)
            prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
            vals = np.tile(prediction.data, (number_of_bootstraps, 1))
            return vals[~np.isnan(vals).any(axis=1), :]
    
        def _get_model_name(self):
            """Return model name without path information."""
            return self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0]
    
        def _load_model(self) -> keras.Model:
            """
            Load NN model either from data store or from local path.
    
            :return: the model
            """
            try:
                model = self.data_store.get("best_model")
            except NameNotFoundInDataStore:
                logging.info("No model was saved in data store. Try to load model from experiment path.")
                model_name = self.data_store.get("model_name", "model")
                model_class: AbstractModelClass = self.data_store.get("model", "model")
                model = keras.models.load_model(model_name, custom_objects=model_class.custom_objects)
            return model
    
        # noinspection PyBroadException
        def plot(self):
            """
            Create all plots.
    
            Plots are defined in the experiment setup via `plot_list`. By default, all of the following plots are enabled:
    
            * :py:class:`PlotBootstrapSkillScore <src.plotting.postprocessing_plotting.PlotBootstrapSkillScore>`
            * :py:class:`PlotConditionalQuantiles <src.plotting.postprocessing_plotting.PlotConditionalQuantiles>`
            * :py:class:`PlotStationMap <src.plotting.postprocessing_plotting.PlotStationMap>`
            * :py:class:`PlotMonthlySummary <src.plotting.postprocessing_plotting.PlotMonthlySummary>`
            * :py:class:`PlotClimatologicalSkillScore <src.plotting.postprocessing_plotting.PlotClimatologicalSkillScore>`
            * :py:class:`PlotCompetitiveSkillScore <src.plotting.postprocessing_plotting.PlotCompetitiveSkillScore>`
            * :py:class:`PlotTimeSeries <src.plotting.postprocessing_plotting.PlotTimeSeries>`
            * :py:class:`PlotAvailability <src.plotting.postprocessing_plotting.PlotAvailability>`
    
            .. note:: Bootstrap plots are only created if bootstraps are evaluated.
    
            """
            logging.info("Run plotting routines...")
            path = self.data_store.get("forecast_path")
    
            plot_list = self.data_store.get("plot_list", "postprocessing")
            time_dim = self.data_store.get("time_dim")
            window_dim = self.data_store.get("window_dim")
            target_dim = self.data_store.get("target_dim")
            iter_dim = self.data_store.get("iter_dim")
    
            try:
                if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and (
                        "PlotSeparationOfScales" in plot_list):
                    filter_dim = self.data_store.get("filter_dim", None)
                    PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path, time_dim=time_dim,
                                           window_dim=window_dim, target_dim=target_dim, filter_dim=filter_dim)
            except Exception as e:
                logging.error(f"Could not create plot PlotSeparationOfScales due to the following error: {e}")
    
            try:
                if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list):
                    PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path,
                                            model_setup=self.forecast_indicator)
            except Exception as e:
                logging.error(f"Could not create plot PlotBootstrapSkillScore due to the following error: {e}")
    
            try:
                if "PlotConditionalQuantiles" in plot_list:
                    PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=path, plot_folder=self.plot_path)
            except Exception as e:
                logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error: {e}")
    
            try:
                if "PlotStationMap" in plot_list:
                    if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get(
                            "hostname")[:6] in self.data_store.get("hpc_hosts"):
                        logging.warning(
                            f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}")
                    else:
                        gens = [(self.train_data, {"marker": 5, "ms": 9}),
                                (self.val_data, {"marker": 6, "ms": 9}),
                                (self.test_data, {"marker": 4, "ms": 9})]
                        PlotStationMap(generators=gens, plot_folder=self.plot_path)
                        gens = [(self.train_val_data, {"marker": 8, "ms": 9}),
                                (self.test_data, {"marker": 9, "ms": 9})]
                        PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var")
            except Exception as e:
                logging.error(f"Could not create plot PlotStationMap due to the following error: {e}")
    
            try:
                if "PlotMonthlySummary" in plot_list:
                    PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var,
                                       plot_folder=self.plot_path)
            except Exception as e:
                logging.error(f"Could not create plot PlotMonthlySummary due to the following error: {e}")
    
            try:
                if "PlotClimatologicalSkillScore" in plot_list:
                    PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path,
                                                 model_setup=self.forecast_indicator)
                    PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
                                                 extra_name_tag="all_terms_", model_setup=self.forecast_indicator)
            except Exception as e:
                logging.error(f"Could not create plot PlotClimatologicalSkillScore due to the following error: {e}")
    
            try:
                if "PlotCompetitiveSkillScore" in plot_list:
                    PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path,
                                              model_setup=self.forecast_indicator)
            except Exception as e:
                logging.error(f"Could not create plot PlotCompetitiveSkillScore due to the following error: {e}")
    
            try:
                if "PlotTimeSeries" in plot_list:
                    PlotTimeSeries(self.test_data.keys(), path, r"forecasts_%s_test.nc", plot_folder=self.plot_path,
                                   sampling=self._sampling)
            except Exception as e:
                logging.error(f"Could not create plot PlotTimeSeries due to the following error: {e}")
    
            try:
                if "PlotAvailability" in plot_list:
                    avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data}
                    PlotAvailability(avail_data, plot_folder=self.plot_path, time_dimension=time_dim,
                                     window_dimension=window_dim)
            except Exception as e:
                logging.error(f"Could not create plot PlotAvailability due to the following error: {e}")
    
            try:
                if "PlotAvailabilityHistogram" in plot_list:
                    avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data}
                    PlotAvailabilityHistogram(avail_data, plot_folder=self.plot_path, station_dim=iter_dim,
                                              history_dim=window_dim)
            except Exception as e:
                logging.error(f"Could not create plot PlotAvailabilityHistogram due to the following error: {e}")
    
        def calculate_test_score(self):
            """Evaluate test score of model and save locally."""
    
            # test scores on transformed data
            test_score = self.model.evaluate_generator(generator=self.test_data_distributed,
                                                       use_multiprocessing=True, verbose=0)
            path = self.data_store.get("model_path")
            with open(os.path.join(path, "test_scores.txt"), "a") as f:
                for index, item in enumerate(to_list(test_score)):
                    logging.info(f"{self.model.metrics_names[index]} (test), {item}")
                    f.write(f"{self.model.metrics_names[index]}, {item}\n")
    
        def train_ols_model(self):
            """Train ordinary least squared model on train data."""
            self.ols_model = OrdinaryLeastSquaredModel(self.train_data)
    
        def make_prediction(self, subset):
            """
            Create predictions for NN, OLS, and persistence and add true observation as reference.
    
            Predictions are filled into an array spanning the full index range and can therefore contain missing values.
            All predictions of a single station are stored locally under `<forecasts/forecasts_norm>_<station>_<subset>.nc`
            inside `forecast_path`.
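
            Example of opening a stored forecast file (sketch; the station id is illustrative)::

                import xarray as xr
                fc = xr.open_dataarray("<forecast_path>/forecasts_DEBW107_test.nc")
                fc.sel(type="nn", ahead=1)  # first lead time of the NN forecast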
            """
            subset_type = subset.name
            logging.info(f"start make_prediction for {subset_type}")
            time_dimension = self.data_store.get("time_dim")
            window_dim = self.data_store.get("window_dim")
            for i, data in enumerate(subset):
                input_data = data.get_X()
                target_data = data.get_Y(as_numpy=False)
                observation_data = data.get_observation()
    
                # get scaling parameters
                transformation_func = data.apply_transformation
    
                for normalised in [True, False]:
                    # create empty arrays
                    nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(
                        target_data, count=4)
    
                    # nn forecast
                    nn_prediction = self._create_nn_forecast(input_data, nn_prediction, transformation_func, normalised)
    
                    # persistence
                    persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction,
                                                                               transformation_func, normalised)
    
                    # ols
                    ols_prediction = self._create_ols_forecast(input_data, ols_prediction, transformation_func, normalised)
    
                    # observation
                    observation = self._create_observation(target_data, observation, transformation_func, normalised)
    
                    # merge all predictions
                    full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency())
                    prediction_dict = {self.forecast_indicator: nn_prediction,
                                       "persi": persistence_prediction,
                                       "obs": observation,
                                       "ols": ols_prediction}
                    all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]),
                                                                  time_dimension, **prediction_dict)
    
                    # save all forecasts locally
                    path = self.data_store.get("forecast_path")
                    prefix = "forecasts_norm" if normalised is True else "forecasts"
                    file = os.path.join(path, f"{prefix}_{str(data)}_{subset_type}.nc")
                    all_predictions.to_netcdf(file)
    
        def _get_frequency(self) -> str:
            """Get frequency abbreviation."""
            getter = {"daily": "1D", "hourly": "1H"}
            return getter.get(self._sampling, None)
    
        def _create_competitor_forecast(self, station_name: str, competitor_name: str) -> xr.DataArray:
            """
            Load and format the competing forecast of a distinct model indicated by `competitor_name` for a distinct station
            indicated by `station_name`. The name of the competitor is set as indicator in the `type` axis. This method
            raises either a `FileNotFoundError` or a `KeyError` if no competitor can be found for the given station: either
            there is no file at the expected path, or the forecast file does not contain the required forecast entry.
    
            :param station_name: name of the station to load data for
            :param competitor_name: name of the model
            :return: the forecast of the given competitor
            """
            path = os.path.join(self.competitor_path, competitor_name)
            file = os.path.join(path, f"forecasts_{station_name}_test.nc")
            data = xr.open_dataarray(file)
            forecast = data.sel(type=[self.forecast_indicator])
            forecast.coords["type"] = [competitor_name]
            return forecast
    
        def _create_observation(self, data, _, transformation_func: Callable, normalised: bool) -> xr.DataArray:
            """
            Create observation as ground truth from given data.
    
            Inverse transformation is applied to the ground truth to get the output in the original space.
    
            :param data: observation
            :param transformation_func: a callable function to apply inverse transformation
            :param normalised: return the normalised ground truth if true, or apply the inverse transformation to return
                it in the original space if false
    
            :return: filled data array with observation
            """
            if not normalised:
                data = transformation_func(data, "target", inverse=True)
            return data
    
        def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray,
                                 transformation_func: Callable, normalised: bool) -> xr.DataArray:
            """
            Create ordinary least square model forecast with given input data.
    
            Inverse transformation is applied to the forecast to get the output in the original space.
    
            :param input_data: transposed history from DataPrep
            :param ols_prediction: empty array in right shape to fill with data
            :param transformation_func: a callable function to apply inverse transformation
            :param normalised: return the normalised prediction if true, or transform it into the original space if false
    
            :return: filled data array with ols predictions
            """
            tmp_ols = self.ols_model.predict(input_data)
            target_shape = ols_prediction.values.shape
            ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
            if not normalised:
                ols_prediction = transformation_func(ols_prediction, "target", inverse=True)
            return ols_prediction
    
        def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, transformation_func: Callable,
                                         normalised: bool) -> xr.DataArray:
            """
            Create persistence forecast with given data.
    
            The persistence forecast repeats the observed value at t=0 for all lead times (t+1, ..., t+window); e.g., an
            observation of 42 at t=0 with a lead time of 3 yields the forecast (42, 42, 42). Inverse transformation is
            applied to the forecast to get the output in the original space.
    
            :param data: observation
            :param persistence_prediction: empty array in right shape to fill with data
            :param transformation_func: a callable function to apply inverse transformation
            :param normalised: return the normalised prediction if true, or transform it into the original space if false
    
            :return: filled data array with persistence predictions
            """
            tmp_persi = data.copy()
            persistence_prediction.values = np.tile(tmp_persi, (self.window_lead_time, 1)).T
            if not normalised:
                persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True)
            return persistence_prediction
    
        def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
                                normalised: bool) -> xr.DataArray:
            """
            Create NN forecast for given input data.
    
            Inverse transformation is applied to the forecast to get the output in the original space. Furthermore, only the
            output of the main branch is returned (not all minor branches, if the network has multiple output branches). The
            main branch is defined to be the last entry of all outputs.
    
            :param input_data: transposed history from DataPrep
            :param nn_prediction: empty array in right shape to fill with data
            :param transformation_func: a callable function to apply inverse transformation
            :param normalised: return the normalised prediction if true, or transform it into the original space if false
    
            :return: filled data array with nn predictions
            """
            tmp_nn = self.model.predict(input_data)
            if isinstance(tmp_nn, list):
                nn_prediction.values = tmp_nn[-1]
            elif tmp_nn.ndim == 3:
                nn_prediction.values = tmp_nn[-1, ...]
            elif tmp_nn.ndim == 2:
                nn_prediction.values = tmp_nn
            else:
                raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.")
            if not normalised:
                nn_prediction = transformation_func(nn_prediction, base="target", inverse=True)
            return nn_prediction
    
        @staticmethod
        def _create_empty_prediction_arrays(target_data, count=1):
            """
            Create array to collect all predictions. Expand target data by a station dimension. """
            return [target_data.copy() for _ in range(count)]
    
        @staticmethod
        def create_fullindex(df: Union[xr.DataArray, pd.DataFrame, pd.DatetimeIndex], freq: str) -> pd.DataFrame:
            """
            Create full index from first and last date inside df and resample with given frequency.
    
            :param df: use time range of this data set
            :param freq: frequency of full index
    
            :return: empty data frame with full index.
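
            Example (sketch)::

                full = PostProcessing.create_fullindex(pd.date_range("2020-01-01", "2020-01-10"), freq="1D")
                # -> empty DataFrame indexed by every day from 2020-01-01 to 2020-01-10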
            """
            if isinstance(df, pd.DataFrame):
                earliest = df.index[0]
                latest = df.index[-1]
            elif isinstance(df, xr.DataArray):
                earliest = df.index[0].values
                latest = df.index[-1].values
            elif isinstance(df, pd.DatetimeIndex):
                earliest = df[0]
                latest = df[-1]
            else:
                raise AttributeError(f"unknown array type. Only pandas dataframes, xarray dataarrays and pandas datetimes "
                                     f"are supported. Given type is {type(df)}.")
            index = pd.DataFrame(index=pd.date_range(earliest, latest, freq=freq))
            return index
    
        @staticmethod
        def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, **kwargs):
            """
            Combine different forecast types into single xarray.
    
            :param index: index for forecasts (e.g. time)
            :param ahead_names: names of ahead values (e.g. hours or days)
            :param time_dimension: name of the time dimension in the given forecast arrays
            :param kwargs: data of forecasts given as xarrays

            :return: 3-dimensional xarray with dimensions index, ahead, and type (one entry per given forecast)
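
            Example (sketch; `nn_data` and `obs_data` are hypothetical forecast arrays)::

                res = PostProcessing.create_forecast_arrays(full_index, [1, 2, 3], "datetime",
                                                            nn=nn_data, obs=obs_data)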
    
            """
            keys = list(kwargs.keys())
            res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
                               coords=[index.index, ahead_names, keys], dims=['index', 'ahead', 'type'])
            for k, v in kwargs.items():
                intersection = set(res.index.values) & set(v.indexes[time_dimension].values)
                match_index = np.array(list(intersection))
                res.loc[match_index, :, k] = v.loc[match_index]
            return res
    
        def _get_internal_data(self, station: str, path: str) -> Union[xr.DataArray, None]:
            """
            Get internal data for given station.
    
            Internal data is defined as data that is already known to the model. From an evaluation perspective, this
            refers to data that is not test data, i.e. the train and val data.
    
            :param station: name of station to load internal data.
            """
            try:
                file = os.path.join(path, f"forecasts_{str(station)}_train_val.nc")
                return xr.open_dataarray(file)
            except (IndexError, KeyError, FileNotFoundError):
                return None
    
        def _get_external_data(self, station: str, path: str) -> Union[xr.DataArray, None]:
            """
            Get external data for given station.
    
            External data is defined as data that is not known to the model. From an evaluation perspective, this refers
            to data that is neither train nor val data, i.e. the test data.
    
            :param station: name of station to load external data.
            """
            try:
                file = os.path.join(path, f"forecasts_{str(station)}_test.nc")
                return xr.open_dataarray(file)
            except (IndexError, KeyError, FileNotFoundError):
                return None
    
        @staticmethod
        def _combine_forecasts(forecast, competitor, dim="type"):
            """
            Combine forecast and competitor if both are xarrays. If competitor is None, this returns the forecast, and
            vice versa.
            """
            try:
                return xr.concat([forecast, competitor], dim=dim)
            except (TypeError, AttributeError):
                return forecast if competitor is None else competitor
    
        def calculate_error_metrics(self) -> Tuple[Dict, Dict, Dict]:
            """
            Calculate error metrics and skill scores of NN forecast.
    
            The competitive skill score compares the NN prediction with persistence and ordinary least squares forecasts,
            whereas the climatological skill score evaluates the NN prediction in terms of meaningfulness in comparison
            to different climatological references.
    
            :return: competitive and climatological skill scores, error metrics
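
            Example of accessing the returned error metrics (sketch; station id and metric name are illustrative)::

                _, _, errors = self.calculate_error_metrics()
                errors["DEBW107"]["mse"]  # per-station metric, assuming 'mse' is among the computed metrics
                errors["total"]["mse"]    # sample-weighted average over all stations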
            """
            path = self.data_store.get("forecast_path")
            all_stations = self.data_store.get("stations")
            skill_score_competitive = {}
            skill_score_climatological = {}
            errors = {}
            for station in all_stations:
                external_data = self._get_external_data(station, path)  # test data
    
                # test errors
                if external_data is not None:
                    errors[station] = statistics.calculate_error_metrics(*map(lambda x: external_data.sel(type=x),
                                                                              [self.forecast_indicator, "obs"]),
                                                                         dim="index")
                # skill score
                competitor = self.load_competitors(station)
                combined = self._combine_forecasts(external_data, competitor, dim="type")
                model_list = remove_items(list(combined.type.values), "obs") if combined is not None else None
                skill_score = statistics.SkillScores(combined, models=model_list)
                if external_data is not None:
                    skill_score_competitive[station] = skill_score.skill_scores(self.window_lead_time)
    
                internal_data = self._get_internal_data(station, path)
                if internal_data is not None:
                    skill_score_climatological[station] = skill_score.climatological_skill_scores(
                        internal_data, self.window_lead_time, forecast_name=self.forecast_indicator)
    
            errors.update({"total": self.calculate_average_errors(errors)})
            return skill_score_competitive, skill_score_climatological, errors
    
        @staticmethod
        def calculate_average_errors(errors):
            """
            Calculate the sample-weighted average of all error metrics over all stations.

            Each station's metrics are weighted by its number of samples `n`; e.g., two stations with (n=10, mse=2) and
            (n=30, mse=4) give a total mse of (2 * 10 + 4 * 30) / 40 = 3.5.
            """
            avg_error = {}
            n_total = sum([x.get("n", 0) for _, x in errors.items()])
            for station, station_errors in errors.items():
                n_station = station_errors.get("n")
                for error_metric, val in station_errors.items():
                    new_val = avg_error.get(error_metric, 0) + val * n_station / n_total
                    avg_error[error_metric] = new_val
            return avg_error
    
        def report_error_metrics(self, errors):
            """Save the error metrics of all stations as LaTeX and markdown tables inside the latex_report directory."""
            report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
            path_config.check_path_and_create(report_path)
            metric_collection = {}
            for station, station_errors in errors.items():
                for metric, vals in station_errors.items():
                    if metric == "n":
                        continue
                    pd_vals = pd.DataFrame.from_dict({station: vals}).T
                    pd_vals.columns = [f"{metric}(t+{x})" for x in vals.coords["ahead"].values]
                    mc = metric_collection.get(metric, pd.DataFrame())
                    mc = mc.append(pd_vals)
                    metric_collection[metric] = mc
            for metric, error_df in metric_collection.items():
                df = error_df.sort_index()
                df = df.reindex(df.index.drop(["total"]).to_list() + ["total"])
                column_format = tables.create_column_format_for_tex(df)
                tables.save_to_tex(report_path, f"error_report_{metric}.tex", column_format=column_format, df=df)
                tables.save_to_md(report_path, f"error_report_{metric}.md", df=df)