diff --git a/HPC_setup/requirements_HDFML_additionals.txt b/HPC_setup/requirements_HDFML_additionals.txt
index 1a9e8524906115e02338dcf80137081ab7165697..af3b34f1799609747afb8f2ddda8014c696706e0 100644
--- a/HPC_setup/requirements_HDFML_additionals.txt
+++ b/HPC_setup/requirements_HDFML_additionals.txt
@@ -1,4 +1,5 @@
 astropy==5.1
+ensverif==0.0.8
 pytz==2022.1
 python-dateutil==2.8.2
 requests==2.28.1
diff --git a/HPC_setup/requirements_JUWELS_additionals.txt b/HPC_setup/requirements_JUWELS_additionals.txt
index 1a9e8524906115e02338dcf80137081ab7165697..337db90a9ae1fe0b4e22a1b9e8f883c88739c6c0 100644
--- a/HPC_setup/requirements_JUWELS_additionals.txt
+++ b/HPC_setup/requirements_JUWELS_additionals.txt
@@ -1,5 +1,6 @@
 astropy==5.1
+ensverif==0.0.8
 pytz==2022.1
 python-dateutil==2.8.2
 requests==2.28.1
 werkzeug>=0.11.15
diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index 9bb15068ce3a5ad934f7b0251b84cb19f37702f6..4a4f550b4505fcd15411f9cbb0a294142f05952d 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -30,10 +30,13 @@ DEFAULT_EARLY_STOPPING_EPOCHS = np.inf
 DEFAULT_RESTORE_BEST_MODEL_WEIGHTS = True
 DEFAULT_TARGET_VAR = "o3"
 DEFAULT_TARGET_DIM = "variables"
+DEFAULT_TARGET_VAR_UNIT = "ppb"
 DEFAULT_WINDOW_LEAD_TIME = 3
 DEFAULT_WINDOW_DIM = "window"
 DEFAULT_TIME_DIM = "datetime"
 DEFAULT_ITER_DIM = "Stations"
+DEFAULT_ENS_REALIZ_DIM = ("realizations", None)
+DEFAULT_ENS_MOMENT_DIM = ("moments", None)
 DEFAULT_DIMENSIONS = {"new_index": [DEFAULT_TIME_DIM, DEFAULT_ITER_DIM]}
 DEFAULT_INTERPOLATION_METHOD = "linear"
 DEFAULT_INTERPOLATION_LIMIT = 1
@@ -70,6 +73,7 @@ DEFAULT_USE_MULTIPROCESSING = True
 DEFAULT_USE_MULTIPROCESSING_ON_DEBUG = False
 DEFAULT_MAX_NUMBER_MULTIPROCESSING = 16
 DEFAULT_CREATE_SNAPSHOT = False
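+# None disables the ensemble (probabilistic) postprocessing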
+DEFAULT_NUMBER_OF_REALIZATIONS = None
 
 
 def get_defaults():
diff --git a/mlair/helpers/__init__.py b/mlair/helpers/__init__.py
index cf50fa05885d576bd64de67b83df3c8ed6d272e2..1c194d09cdd1fcfa5d158b81c7d15dc16205961c 100644
--- a/mlair/helpers/__init__.py
+++ b/mlair/helpers/__init__.py
@@ -4,4 +4,4 @@ from .testing import PyTestRegex, PyTestAllEqual, check_nested_equality
 from .time_tracking import TimeTracking, TimeTrackingWrapper
 from .logger import Logger
 from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, \
-    make_keras_pickable, sort_like, filter_dict_by_value
+    make_keras_pickable, sort_like, filter_dict_by_value, get_sampling
diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py
index 6bd616c6f17081544a2eb379a427b091dca6c9b1..c2e88417c5b9fcd8b28b3ceb081769ec81d2cb97 100644
--- a/mlair/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -283,6 +283,24 @@ def str2bool(v):
         raise argparse.ArgumentTypeError('Boolean value expected.')
 
 
+def get_sampling(sampling):
+    """
+    Get letter abbreviation for sampling
+    :param sampling: long name of sampling frequency
+    :return:
+    :rtype:
+    """
+    implemented_samplings = {
+        "daily": "d",
+        "hourly": "h"
+    }
+    return implemented_samplings[sampling]
+    # if sampling == "daily":
+    #     return "D"
+    # elif sampling == "hourly":
+    #     return "h"
+
+
 # def convert_size(size_bytes):
 #     if size_bytes == 0:
 #         return "0B"
diff --git a/mlair/model_modules/probability_models.py b/mlair/model_modules/probability_models.py
index 9ffe77a5c561a4903a548062f2b015c623e66c69..dd0c3035ebd59957937fe35e8a81f9a3c5913166 100644
--- a/mlair/model_modules/probability_models.py
+++ b/mlair/model_modules/probability_models.py
@@ -304,22 +304,42 @@ class MyUnetProb(AbstractModelClass):
        
         
         #outputs = tfpl.IndependentNormal(self._output_shape)(outputs)
-        params_size = tfpl.MixtureSameFamily.params_size(
-            self.k_mixed_components,
-            component_params_size=tfpl.MultivariateNormalTriL.params_size(self._output_shape)
-        )
+
+        params_size = tfpl.MultivariateNormalTriL.params_size(self._output_shape)
 
         pars = tf.keras.layers.Dense(params_size)(dl)
         # pars = DenseVariationalCustom(
         #     units=params_size, make_prior_fn=prior, make_posterior_fn=posterior,
-        #     kl_use_exact=True, kl_weight=1./self.x_train_shape)(dl)
-
-        outputs = tfpl.MixtureSameFamily(self.k_mixed_components,
-                                         tfpl.MultivariateNormalTriL(
-                                             self._output_shape,
-                                             convert_to_tensor_fn=tfp.distributions.Distribution.mode
-                                         )
-                                         )(pars)
+        #     kl_use_exact=False, kl_weight=1./self.num_of_training_samples)(dl)
+
+        # distribution output layer: a multivariate normal whose divergence from a standard
+        # normal prior is penalised; the convert_to_tensor_fn draws 10 realizations per call
+        outputs = tfpl.MultivariateNormalTriL(
+            self._output_shape,
+            sample_real(10),
+            activity_regularizer=tfpl.KLDivergenceRegularizer(
+                tfd.MultivariateNormalDiag(loc=tf.zeros(self._output_shape),
+                                           scale_diag=tf.ones(self._output_shape)),
+                weight=self.num_of_training_samples
+            )
+        )(pars)
 
         self.model = keras.Model(inputs=input_train, outputs=outputs)
 
@@ -839,6 +859,14 @@ class Convolution2DReparameterizationCustom(tfpl.Convolution2DReparameterization
             })
         return config
 
+
+def sample_real(n_real=10):
+    """Return a `convert_to_tensor_fn` that draws `n_real` samples from the predicted distribution.
+
+    The inner function is promoted to module scope via `global` so that it can be resolved by
+    name, e.g. when the model is serialised.
+    """
+    global sample
+
+    def sample(s):
+        return s.sample(n_real)
+
+    return sample
+
 
 if __name__ == "__main__":
 
diff --git a/mlair/plotting/abstract_plot_class.py b/mlair/plotting/abstract_plot_class.py
index a26023bb6cb8772623479491ac8bcc731dd42223..d0c177cd9b988882a8b7ed0756b6e4bd952dc33b 100644
--- a/mlair/plotting/abstract_plot_class.py
+++ b/mlair/plotting/abstract_plot_class.py
@@ -6,6 +6,7 @@ import logging
 import os
 
 from matplotlib import pyplot as plt
+from mlair.helpers.helpers import get_sampling
 
 
 class AbstractPlotClass:  # pragma: no cover
@@ -93,10 +94,11 @@ class AbstractPlotClass:  # pragma: no cover
 
     @staticmethod
     def _get_sampling(sampling):
-        if sampling == "daily":
-            return "D"
-        elif sampling == "hourly":
-            return "h"
+        return get_sampling(sampling)
+        # if sampling == "daily":
+        #     return "D"
+        # elif sampling == "hourly":
+        #     return "h"
 
     @staticmethod
     def get_dataset_colors():
@@ -105,3 +107,9 @@ class AbstractPlotClass:  # pragma: no cover
         """
         colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"}  # hex code
         return colors
+
+    @staticmethod
+    def _get_target_sampling(sampling, pos):
+        """Return the sampling tuple and the letter abbreviation of the sampling at position `pos`."""
+        sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
+        sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
+        return sampling, sampling_letter
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 7540ac50f347b1c037f07fdf18bacd16859fa0bc..2bcc7c3df3f64aa5af1e15d96aaf7f6e67eace00 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -17,7 +17,6 @@ import seaborn as sns
 import xarray as xr
 from matplotlib.backends.backend_pdf import PdfPages
 from matplotlib.offsetbox import AnchoredText
-from scipy.stats import mannwhitneyu
 
 from mlair import helpers
 from mlair.data_handler.iterator import DataCollection
@@ -759,12 +758,6 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
         self._number_of_bootstraps = np.unique(data.coords[self._boot_dim].values).shape[0]
         return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()).dropna()
 
-    @staticmethod
-    def _get_target_sampling(sampling, pos):
-        sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
-        sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
-        return sampling, sampling_letter
-
     def _return_vars_without_number_tag(self, values, split_by, keep, as_unique=False):
         arr = np.array([v.split(split_by) for v in values])
         num = arr[:, 0]
@@ -1325,6 +1318,48 @@ class PlotTimeEvolutionMetric(AbstractPlotClass):
         self._save()
 
 
+class PlotRankHistogram(AbstractPlotClass):
+    """
+
+    """
+
+    def __init__(self, data: dict, plot_folder: str = ".", ahead_dim: str = "ahead", target_unit: str = None,
+                 target_var: str = None):
+        super().__init__(plot_folder, "rank_histogram_plot")
+        self.data = data
+        self.ahead_dim = ahead_dim
+        self.target_unit = target_unit
+        self.target_var = target_var
+        self._plot()
+
+    def _plot(self):
+        for key, value in self.data.items():
+            plot_name = f"{self.plot_name}_{key}_plot.pdf"
+            plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name)
+            pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
+            for ah in value[self.ahead_dim]:
+                fig, ax = plt.subplots()
+                bins = value.sel({"rank_hist_type": "bins", self.ahead_dim: ah.values})
+                freq = value.sel({"rank_hist_type": "freq", self.ahead_dim: ah.values})
+                relfreq = freq / freq.sum()
+                ax.bar(bins, relfreq, align='center', color="gray")
+                ax.hlines(1. / len(bins), 0, len(bins), color="black", linestyle="dashed")
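+                # the dashed line marks the relative frequency expected for a calibrated ensemble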
+                ax.set_xlabel(r"Verification Rank ($n_{ens}+1$)")
+                ax.set_ylabel("Relative Frequency")
+
+                plt.title(f"{ah.values}")
+                pdf_pages.savefig()
+                # close all open figures before plotting the next lead time
+                plt.close('all')
+            pdf_pages.close()
+
+    def _load_data(self, subset):
+        # placeholder: the data is currently provided via the constructor
+        pass
+
+
 if __name__ == "__main__":
     stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
     path = "../../testrun_network/forecasts"
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index f89633cbe0f80f26dbb2481ca24a7fd294ee6888..4f220cf5416ca5559dd593bfd39f2023d4da2d06 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -24,7 +24,8 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, DEFAULT_OVERWRITE_LAZY_DATA, \
     DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, \
     DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE, DEFAULT_CREATE_SNAPSHOT, \
-    DEFAULT_EARLY_STOPPING_EPOCHS, DEFAULT_RESTORE_BEST_MODEL_WEIGHTS, DEFAULT_COMPETITORS
+    DEFAULT_EARLY_STOPPING_EPOCHS, DEFAULT_RESTORE_BEST_MODEL_WEIGHTS, DEFAULT_COMPETITORS, DEFAULT_NUMBER_OF_REALIZATIONS, \
+    DEFAULT_ENS_MOMENT_DIM, DEFAULT_ENS_REALIZ_DIM, DEFAULT_TARGET_VAR_UNIT
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
@@ -68,6 +69,9 @@ class ExperimentSetup(RunEnvironment):
         * `target_var` [.]
         * `target_dim` [.]
         * `window_lead_time` [.]
+        * `num_realizations` [postprocessing]
+        * `ens_moment_dim` [postprocessing]
+        * `ens_realization_dim` [postprocessing]
 
     Creates
         * plot of model architecture in `<model_name>.pdf`
@@ -96,6 +100,7 @@ class ExperimentSetup(RunEnvironment):
     :param target_var: target variable to predict by model, currently only a single target variable is supported.
         Because this framework was originally designed to predict ozone, default is `"o3"`.
     :param target_dim: dimension of target variable (default `"variables"`).
+    :param target_var_unit: unit of the target variable (e.g. used in plots)
     :param window_lead_time: number of time steps to predict by model (default 3). Time steps `t_0+1` to `t_0+w` are
         predicted.
     :param dimensions:
@@ -202,6 +207,9 @@ class ExperimentSetup(RunEnvironment):
         only for storing a snapshot, `snapshot_load_path` indicates where to load the snapshot from. If this parameter
         is not provided at all, no snapshot is loaded. Note, the workflow will apply the default preprocessing without
         loading a snapshot only if this parameter is None!
+    :param num_realizations: number of realizations to draw from a probabilistic neural network during
+        postprocessing
+    :param ens_moment_dim: name of the dimension used to extract the ensemble moments (mean, stddev)
+    :param ens_realization_dim: name of the dimension holding the ensemble realizations
     """
 
     def __init__(self,
@@ -242,7 +250,8 @@ class ExperimentSetup(RunEnvironment):
                  uncertainty_estimate_evaluate_competitors: bool = None, uncertainty_estimate_n_boots: int = None,
                  do_uncertainty_estimate: bool = None, model_display_name: str = None, transformation_file: str = None,
                  calculate_fresh_transformation: bool = None, snapshot_load_path: str = None,
-                 create_snapshot: bool = None, snapshot_path: str = None, **kwargs):
+                 create_snapshot: bool = None, snapshot_path: str = None, num_realizations: int = None,
+                 ens_moment_dim: str = None, ens_realization_dim: str = None, target_var_unit: str = None, **kwargs):
 
         # create run framework
         super().__init__()
@@ -352,6 +361,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR)
         self._set_param("target_dim", target_dim, default=DEFAULT_TARGET_DIM)
         self._set_param("window_lead_time", window_lead_time, default=DEFAULT_WINDOW_LEAD_TIME)
+        self._set_param("target_var_unit", target_var_unit, default=DEFAULT_TARGET_VAR_UNIT)
 
         # interpolation
         self._set_param("dimensions", dimensions, default=DEFAULT_DIMENSIONS)
@@ -408,6 +418,22 @@ class ExperimentSetup(RunEnvironment):
                         default=DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, scope="feature_importance")
 
         self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")
+
+        # set ensemble related parameters
+        self._set_param("num_realizations", num_realizations, default=DEFAULT_NUMBER_OF_REALIZATIONS,
+                        scope="general.postprocessing")
+        # ensemble dimension names default to None unless ensemble predictions are requested
+        pos = 0 if num_realizations is not None else 1
+        self._set_param("ens_moment_dim", ens_moment_dim,
+                        default=DEFAULT_ENS_MOMENT_DIM[pos], scope="general.postprocessing")
+        self._set_param("ens_realization_dim", ens_realization_dim,
+                        default=DEFAULT_ENS_REALIZ_DIM[pos], scope="general.postprocessing")
+
         if model_display_name is not None:
             self._set_param("model_display_name", model_display_name)
         self._set_param("neighbors", ["DEBW030"])  # TODO: just for testing
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index bff77e86fadf50d3f0616aaa3c1834715340599d..b6a0b33ea1d8f03d6e8c8049365b6af933869d9a 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -6,26 +6,30 @@ __date__ = '2019-12-11'
 import inspect
 import logging
 import os
+import glob
+import pickle
 import sys
 import traceback
 import copy
 from typing import Dict, Tuple, Union, List, Callable
 
 import numpy as np
 import pandas as pd
 import xarray as xr
 import datetime as dt
+import ensverif
+import tensorflow as tf
 
 from mlair.configuration import path_config
 from mlair.data_handler import Bootstraps, KerasIterator
 from mlair.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope
 from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables, \
-    data_sources
+    data_sources, get_sampling
 from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
 from mlair.model_modules import AbstractModelClass
 from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \
     PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \
-    PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric
+    PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric, PlotRankHistogram
 from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \
     PlotPeriodogram, PlotDataHistogram
 from mlair.run_modules.run_environment import RunEnvironment
@@ -86,13 +90,19 @@ class PostProcessing(RunEnvironment):
         self.target_var = self.data_store.get("target_var")
         self.target_var_with_stat = data_sources.get_single_var_with_stat_name(
             self.data_store.get("statistics_per_var"), self.target_var)
+        self.target_var_unit = self.data_store.get("target_var_unit")
         self._sampling = self.data_store.get("sampling")
+        self.num_realizations = self.data_store.get("num_realizations", "postprocessing")
+        self.ens_realization_dim = self.data_store.get("ens_realization_dim", "postprocessing")
+        self.ens_moment_dim = self.data_store.get("ens_moment_dim", "postprocessing")
+        self.iter_dim = self.data_store.get("iter_dim")
         self.window_lead_time = extract_value(self.data_store.get("output_shape", "model"))
         self.skill_scores = None
         self.feature_importance_skill_scores = None
         self.uncertainty_estimate = None
         self.uncertainty_estimate_seasons = {}
         self.block_mse_per_station = None
+        self.rank_hist = {}
         self.competitor_path = self.data_store.get("competitor_path")
         self.competitors = to_list(self.data_store.get_default("competitors", default=[]))
         self.forecast_indicator = "nn"
@@ -116,6 +126,12 @@ class PostProcessing(RunEnvironment):
         # calculate error metrics on test data
         self.calculate_test_score()
 
+        # calculate and report ensemble scores (CRPS, rank histograms)
+        if self.num_realizations is not None:
+            for subset in ["test", "train_val"]:
+                self.report_crps(subset)
+
         # sample uncertainty
         if self.data_store.get("do_uncertainty_estimate", "postprocessing"):
             self.estimate_sample_uncertainty(separate_ahead=True)
@@ -143,6 +159,156 @@ class PostProcessing(RunEnvironment):
         # plotting
         self.plot()
 
+    @TimeTrackingWrapper
+    def report_crps2(self, subset):
+        """
+        Calculate the CRPS for all lead times, aggregated over all stations and per station.
+
+        Alternative implementation of :meth:`report_crps` that opens all forecast files at once.
+
+        :param subset: name of the data subset to evaluate (e.g. "test")
+        """
+        file_pattern = os.path.join(self.forecast_path, f"forecasts_*_ens_{subset}_values.nc")
+        # collect ensemble files with predictions (not normalised)
+        ens_files = [f for f in glob.glob(file_pattern) if "_norm" not in f]
+
+        ds = xr.open_mfdataset(ens_files)
+        crps = {}
+        crps_times = {}
+        for i in range(1, self.window_lead_time+1):
+            ens = ds["ens"].sel(
+                {self.ahead_dim: i, self.ens_moment_dim: "ens_dist_mean",
+                 self.model_type_dim: "ens"}
+            ).dropna(self.index_dim)
+            obs = ds["det"].sel(
+                {self.ahead_dim: i, self.model_type_dim: "obs"}
+            ).dropna(self.index_dim)
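+            # flatten stations and times into one sample dimension for the score over all stations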
+            crps[f"{i}{get_sampling(self._sampling)}"] = ensverif.crps.crps(
+                ens.values.reshape(-1, self.num_realizations), obs.values.reshape(-1), distribution="emp")
+            crps_stations = {}
+            for station in ens.coords[self.iter_dim].values:
+                ens_station = ens.sel({self.iter_dim: station}).values
+                obs_station = obs.sel({self.iter_dim: station}).values
+                crps_stations[station] = ensverif.crps.crps(ens_station, obs_station, distribution="emp")
+            crps_times[f"{i}{get_sampling(self._sampling)}"] = crps_stations
+
+        df_tot = pd.DataFrame(crps, index=[subset])
+        df_stations = pd.DataFrame(crps_times)
+
+        report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+        path_config.check_path_and_create(report_path)
+        self.store_crps_reports(df_tot, report_path, subset, station=False)
+        self.store_crps_reports(df_stations, report_path, subset, station=True)
+
+    @TimeTrackingWrapper
+    def report_crps(self, subset):
+        """
+        Calculate CRPS for all lead times
+        :return:
+        :rtype:
+        """
+        file_pattern = os.path.join(self.forecast_path, f"forecasts_*_ens_{subset}_values.nc")
+        # get ens files with predictions (not normalized)
+        # ens_files = [e for e in filter(lambda x: not "_norm" in x, glob.glob(file_pattern))]
+
+        # ds = xr.open_mfdataset(ens_files)
+        collector_ens = []
+        collector_obs = []
+
+        # crps[f"{i}{get_sampling(self._sampling)}"] = ensverif.crps.crps(
+        #     ens.values.reshape(-1, self.num_realizations), obs.values.reshape(-1), distribution="emp")
+        crps_stations = {}
+        rank_stations = {}
+        idx_counter = 0
+        generators = {"train": self.train_data, "val": self.val_data,
+                      "test": self.test_data, "train_val": self.train_val_data}
+        for station in generators[subset].keys():
+            station_based_file_name = os.path.join(self.forecast_path, f"forecasts_{station}_ens_{subset}_values.nc")
+            ds = xr.open_mfdataset(station_based_file_name)
+            ens_station = ds["ens"].sel({self.iter_dim: station, self.ens_moment_dim: "ens_dist_mean",
+                                         self.model_type_dim: "ens"}).dropna(self.index_dim)
+            obs_station = ds["det"].sel({self.model_type_dim: "obs"}).dropna(self.index_dim)
+
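+            # re-index the station data onto a running integer index so that all stations can be
+            # concatenated along one unique index dimension afterwards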
+            if len(ens_station.coords[self.index_dim]) != 0:
+                new_index = range(idx_counter, idx_counter+len(ens_station[self.index_dim]))
+                ens_reindex = xr.DataArray(data=ens_station.data,
+                                           dims=[self.index_dim, self.ens_realization_dim, self.ahead_dim],
+                                           coords={self.index_dim: new_index,
+                                                   self.ens_realization_dim: ens_station.coords[self.ens_realization_dim],
+                                                   self.ahead_dim: ens_station.coords[self.ahead_dim]})
+                obs_reindex = xr.DataArray(data=obs_station.data,
+                                           dims=[self.index_dim, self.ahead_dim],
+                                           coords={self.index_dim: new_index,
+                                                   self.ahead_dim: ens_station.coords[self.ahead_dim]})
+                collector_ens.append(ens_reindex)
+                collector_obs.append(obs_reindex)
+                idx_counter = new_index[-1] + 1
+
+            crps_times, rank_times = self._calc_crps_for_lead_times(ens_station, obs_station)
+
+            crps_stations[station] = crps_times
+            rank_stations[station] = rank_times
+
+        df_stations = pd.DataFrame(crps_stations)
+        report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+        path_config.check_path_and_create(report_path)
+        self.store_crps_reports(df_stations, report_path, subset, station=True)
+
+        try:
+            full_ens = xr.concat(collector_ens, dim=self.index_dim)
+            full_obs = xr.concat(collector_obs, dim=self.index_dim)
+            crps, ranks = self._calc_crps_for_lead_times(full_ens, full_obs)
+            df_crps_tot = pd.DataFrame(crps, index=to_list(subset))
+
+            self.store_crps_reports(df_crps_tot, report_path, subset, station=False)
+            rh = self._get_rank_hist_from_dict(ranks)
+            self.rank_hist[subset] = rh
+            self._save_rank_hist(rh, subset)
+        except Exception as e:
+            logging.info(f"Can't calc crps for all stations together due to: {e}")
+
+    def _get_rank_hist_from_dict(self, data):
+        """Stack the per-lead-time (frequencies, bins) tuples into one labelled DataArray.
+
+        The frequencies are padded with a leading NaN so that they align with the bin coordinates.
+        """
+        d = np.stack([np.stack([np.insert(v[0].astype(np.float32), 0, np.nan), v[1]]) for v in data.values()])
+        dxr = xr.DataArray(d, dims=[self.ahead_dim, "rank_hist_type", "idx"],
+                           coords={self.ahead_dim: list(data.keys()), "rank_hist_type": ["freq", "bins"],
+                                   "idx": range(d.shape[-1])})
+        return dxr
+
+    def _save_rank_hist(self, data, subset):
+        """Save the rank histogram data of a subset as netcdf into the experiment data folder."""
+        data_path = os.path.join(self.data_store.get("experiment_path"), "data")
+        data.to_netcdf(os.path.join(data_path, f"ens_rank_hist_{subset}_data.nc"))
+
+    def _calc_crps_for_lead_times(self, ens, obs):
+        """Calculate CRPS and rank histogram per lead time; keys are lead time plus sampling letter (e.g. "1d")."""
+        crps_collector = {}
+        rank_collector = {}
+        for i in range(1, self.window_lead_time + 1):
+            ens_res = ens.sel({self.ahead_dim: i})
+            obs_res = obs.sel({self.ahead_dim: i})
+            collector_key = f"{i}{get_sampling(self._sampling)}"
+            crps_collector[collector_key] = ensverif.crps.crps(ens_res, obs_res, distribution="emp")
+            rank_collector[collector_key] = ensverif.rankhist.rankhist(ens_res, obs_res)
+
+        return crps_collector, rank_collector
+
+    @staticmethod
+    def store_crps_reports(df, report_path, subset, station=False):
+        """Store CRPS results as tex, md, and csv reports, either per station or as a summary."""
+        if station is True:
+            file_name = f"crps_stations_{subset}.%s"
+            df = df.transpose()
+        else:
+            file_name = f"crps_summary_{subset}.%s"
+        column_format = tables.create_column_format_for_tex(df)
+        tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df)
+        tables.save_to_md(report_path, file_name % "md", df=df)
+        df.to_csv(os.path.join(report_path, file_name % "csv"), sep=";")
+
     @TimeTrackingWrapper
     def estimate_sample_uncertainty(self, separate_ahead=False):
         """
@@ -684,6 +850,14 @@ class PostProcessing(RunEnvironment):
             logging.error(f"Could not create plot PlotTimeEvolutionMetric due to the following error: {e}"
                           f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
 
+        try:
+            if "PlotRankHistogram" in plot_list:
+                PlotRankHistogram(self.rank_hist, plot_folder=self.plot_path, target_unit=self.target_var_unit,
+                                  target_var=self.target_var)
+        except Exception as e:
+            logging.error(f"Could not create plot PlotRankHistogram due to the following error: {e}"
+                          f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
+
     @TimeTrackingWrapper
     def calculate_test_score(self):
         """Evaluate test score of model and save locally."""
@@ -730,7 +904,15 @@ class PostProcessing(RunEnvironment):
             # get scaling parameters
             transformation_func = data.apply_transformation
 
-            nn_output = self.model.predict(input_data)
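+            # call the model directly (instead of `predict`) so that probabilistic models return
+            # their distribution object, from which realizations are drawn below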
+            if self.num_realizations is None:
+                nn_output = self.model.model(input_data)
+                nn_output = nn_output.numpy()
+            else:
+                with TimeTracking(name=f"Create ensemble predictions for {data.__repr__()} ({subset_type}), {i+1}/{len(subset)}"):
+                    ens_collector = []
+                    for r in range(self.num_realizations):
+                        ens_collector.append(self.model.model(input_data))
+                    nn_output = self._create_ens_mean_pred(ens_collector)
 
             for normalised in [True, False]:
                 # create empty arrays
@@ -769,6 +951,41 @@ class PostProcessing(RunEnvironment):
                 prefix = "forecasts_norm" if normalised is True else "forecasts"
                 file = os.path.join(self.forecast_path, f"{prefix}_{str(data)}_{subset_type}.nc")
                 all_predictions.to_netcdf(file)
+                if self.num_realizations is not None:
+                    nn_ens_dist_prediction = nn_prediction.expand_dims(
+                        {self.ens_realization_dim: range(self.num_realizations),
+                         self.ens_moment_dim: ["ens_dist_mean", "ens_dist_stddev"],
+                         })
+                    nn_ens_dist_predictions = self._create_nn_ens_forecast(ens_collector, nn_ens_dist_prediction,
+                                                                           transformation_func, normalised)
+
+                    nn_ens_dist_predictions_full = self.create_forecast_arrays(
+                        full_index, list(target_data.indexes[window_dim]), time_dimension,
+                        ahead_dim=self.ahead_dim,
+                        index_dim=self.index_dim, type_dim=self.model_type_dim,
+                        ens_dims=[
+                            self.ens_realization_dim,
+                            self.ens_moment_dim],
+                        ens_coords=[
+                            range(self.num_realizations),
+                            ["ens_dist_mean", "ens_dist_stddev"]],
+                        **{"ens": nn_ens_dist_predictions.transpose("datetime", ...)}
+                    )
+                    nn_ens_dist_predictions_full = nn_ens_dist_predictions_full.expand_dims(
+                        {self.iter_dim: to_list(str(nn_ens_dist_predictions[self.iter_dim].values))}
+                    ).transpose(self.index_dim, ...)
+                    all_predictions_ens = xr.Dataset({"ens": nn_ens_dist_predictions_full,
+                                                      "det":  all_predictions,
+                                                      })
+                    file_ens = os.path.join(self.forecast_path, f"{prefix}_{str(data)}_ens_{subset_type}")
+                    all_predictions_ens.to_netcdf(f"{file_ens}_values.nc")
+                    with open(f"{file_ens}_dist.pkl", 'wb') as outp:
+                        pickle.dump(ens_collector, outp, pickle.HIGHEST_PROTOCOL)
+
+    @staticmethod
+    def _create_ens_mean_pred(collector):
+        """Calculates the ens. mean from a list containing ens. members of type tfp.distributions._TensorCoercible"""
+        return tf.reduce_mean(tf.stack(collector), axis=0).numpy()
 
     def _get_frequency(self) -> str:
         """Get frequency abbreviation."""
@@ -865,7 +1082,7 @@ class PostProcessing(RunEnvironment):
             persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True)
         return persistence_prediction
 
-    def _create_nn_forecast(self, nn_output: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
+    def _create_nn_forecast(self, nn_output: Union[np.ndarray, List], nn_prediction: xr.DataArray,
+                            transformation_func: Callable,
                             normalised: bool) -> xr.DataArray:
         """
         Create NN forecast for given input data.
@@ -882,18 +1099,62 @@ class PostProcessing(RunEnvironment):
         :return: filled data array with nn predictions
         """
 
+        nn_prediction = self._set_nn_prediction_based_on_output_shape(nn_output, nn_prediction)
+        if not normalised:
+            nn_prediction = transformation_func(nn_prediction, base="target", inverse=True)
+        return nn_prediction
+
+    @staticmethod
+    def _set_nn_prediction_based_on_output_shape(nn_output, nn_prediction):
+        """
+        Fill `nn_prediction` with the output of the main branch.
+
+        If the network has multiple output branches, only the main branch is used, which is defined
+        to be the last entry of all outputs.
+
+        :param nn_output: raw model output, either a list of branch outputs or an array
+        :param nn_prediction: empty prediction array to fill
+        :return: the filled prediction array
+        """
         if isinstance(nn_output, list):
             nn_prediction.values = nn_output[-1]
         elif nn_output.ndim == 3:
             nn_prediction.values = nn_output[-1, ...]
         elif nn_output.ndim == 2:
             nn_prediction.values = nn_output
+        elif nn_prediction.ndim == nn_output.ndim:
+            nn_prediction.values = nn_output
         else:
             raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {nn_output.dims}.")
-        if not normalised:
-            nn_prediction = transformation_func(nn_prediction, base="target", inverse=True)
         return nn_prediction
 
+    def _create_nn_ens_forecast(self, collector, nn_ens_dist_prediction, transformation_func: Callable,
+                                normalised: bool):
+        """
+        Create an xarray containing the ensemble distribution mean and stddev for each realization.
+
+        Mean and stddev are calculated per realization based on the distribution layer of the BNN.
+
+        :param collector: list of distribution outputs, one per realization
+        :param nn_ens_dist_prediction: empty target array to fill with mean and stddev
+        :param transformation_func: function to apply the inverse transformation
+        :param normalised: if False, the inverse transformation is applied to the result
+        :return: filled ensemble distribution prediction
+        """
+        ens_dist_collector = []
+        for realization in collector:
+            ens_dist_collector.append([realization.mean(), realization.stddev()])
+        ens_dist_vals = np.stack(ens_dist_collector)
+        nn_ens_dist_prediction = self._set_nn_prediction_based_on_output_shape(ens_dist_vals, nn_ens_dist_prediction)
+        if not normalised:
+            nn_ens_dist_prediction = transformation_func(nn_ens_dist_prediction, base="target", inverse=True)
+        return nn_ens_dist_prediction
+
     @staticmethod
     def _create_empty_prediction_arrays(target_data, count=1):
         """
@@ -927,7 +1188,8 @@ class PostProcessing(RunEnvironment):
 
     @staticmethod
     def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension,
-                               ahead_dim="ahead", index_dim="index", type_dim="type", **kwargs):
+                               ahead_dim="ahead", index_dim="index", type_dim="type",
+                               ens_coords=None, ens_dims=None, **kwargs):
         """
         Combine different forecast types into single xarray.
 
@@ -940,12 +1202,22 @@ class PostProcessing(RunEnvironment):
         """
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         keys = list(kwargs.keys())
-        res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
-                           coords=[index.index, ahead_names, keys], dims=[index_dim, ahead_dim, type_dim])
+        res_coords = [index.index, ahead_names, keys]
+        res_dims = [index_dim, ahead_dim, type_dim]
+        if (ens_coords is not None) and (ens_dims is not None):
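+            # insert the ensemble dimensions between the index and the ahead/type dimensions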
+            ens_coords = to_list(ens_coords)
+            ens_dims = to_list(ens_dims)
+            res_coords = to_list(res_coords[0]) + ens_coords + to_list(res_coords[1:])
+            res_dims = to_list(res_dims[0]) + to_list(ens_dims) + to_list(res_dims[1:])
+
+        res_fill_shape = [len(i) for i in res_coords]
+        res = xr.DataArray(np.full(res_fill_shape, np.nan),
+                           coords=res_coords, dims=res_dims)
         for k, v in kwargs.items():
             intersection = set(res.index.values) & set(v.indexes[time_dimension].values)
             match_index = np.array(list(intersection))
-            res.loc[match_index, :, k] = v.loc[match_index]
+            res.loc[match_index, ..., k] = v.loc[match_index]
         return res
 
     def _get_internal_data(self, station: str, path: str) -> Union[xr.DataArray, None]:
diff --git a/requirements.txt b/requirements.txt
index f644ae9257c0b5a18492f8a2d0ef27d1246ec0d4..d772f0847a5fdd3246956e1f5e62c95ee0c1701d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ auto_mix_prep==0.2.0
 Cartopy==0.20.0
 dask==2021.9.1
 dill==0.3.3
+ensverif==0.0.8
 fsspec==2021.10.1
 Keras==2.6.0
 locket==0.2.1
diff --git a/run_bnn.py b/run_bnn.py
index 3642ec6d522aff51516a7eb710d00b04ab137d50..d2e94b6f4cc454a14e01489cf6d540c68d5f5250 100644
--- a/run_bnn.py
+++ b/run_bnn.py
@@ -29,26 +29,40 @@ def load_stations(case=0):
 
 def main(parser_args):
     # tf.compat.v1.disable_v2_behavior()
-    plots = remove_items(DEFAULT_PLOT_LIST, ["PlotConditionalQuantiles", "PlotPeriodogram"])
+    # plots = remove_items(DEFAULT_PLOT_LIST, ["PlotConditionalQuantiles", "PlotPeriodogram"])
     stats_per_var = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
      'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
      'pblheight': 'maximum'}
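+    # bounded variables (relhum, cloudcover) use min-max scaling, all others are standardised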
+    transformation = {'o3': {'method': 'standardise'},
+                      'relhum': {'method': 'min_max'},
+                      'temp': {'method': 'standardise'},
+                      'u': {'method': 'standardise'},
+                      'v': {'method': 'standardise'},
+                      'no': {'method': 'standardise'},
+                      'no2': {'method': 'standardise'},
+                      'cloudcover': {'method': 'min_max'},
+                      'pblheight': {'method': 'standardise'}
+                      }
     workflow = DefaultWorkflow(  # stations=load_stations(),
         #stations=["DEBW087","DEBW013", "DEBW107",  "DEBW076"],
         stations=load_stations(2),
         model=MyUnetProb,
         window_lead_time=4,
         window_history_size=6,
-        epochs=100,
-        batch_size=1024,
+        epochs=200,
+        batch_size=512,  # 1024
+        permute_data_on_training=True,
+        transformation=transformation,
         train_model=False, create_new_model=True, network="UBA",
         evaluate_feature_importance=False,  # plot_list=["PlotCompetitiveSkillScore"],
-        # competitors=["test_model", "test_model2"],
-        competitor_path=os.path.join(os.getcwd(), "data", "comp_test"),
+        competitors=["IntelliO3-ts"],#["test_model", "test_model2"],
+        competitor_path="/p/scratch/deepacf/intelliaq/kleinert1/MLAIR_competitors/comp_from_toar/o3",
+        #competitor_path=os.path.join(os.getcwd(), "data", "comp_test"),
         variables=list(stats_per_var.keys()),
         statistics_per_var=stats_per_var,
         target_var="o3",
         target_var_unit="ppb",
+        num_realizations=30,
         **parser_args.__dict__, start_script=__file__)
     workflow.run()