diff --git a/HPC_setup/requirements_HDFML_additionals.txt b/HPC_setup/requirements_HDFML_additionals.txt
index 90ccd220394c9afb00fba7e069af8e677d4ae0b2..1a9e8524906115e02338dcf80137081ab7165697 100644
--- a/HPC_setup/requirements_HDFML_additionals.txt
+++ b/HPC_setup/requirements_HDFML_additionals.txt
@@ -16,4 +16,6 @@ xarray==0.16.2
 tabulate==0.8.10
 wget==3.2
 pydot==1.4.2
-netcdf4==1.6.0
\ No newline at end of file
+netcdf4==1.6.0
+tensorflow-probability==0.14.1
+tzwhere
\ No newline at end of file
diff --git a/HPC_setup/requirements_JUWELS_additionals.txt b/HPC_setup/requirements_JUWELS_additionals.txt
index 90ccd220394c9afb00fba7e069af8e677d4ae0b2..1a9e8524906115e02338dcf80137081ab7165697 100644
--- a/HPC_setup/requirements_JUWELS_additionals.txt
+++ b/HPC_setup/requirements_JUWELS_additionals.txt
@@ -16,4 +16,6 @@ xarray==0.16.2
 tabulate==0.8.10
 wget==3.2
 pydot==1.4.2
-netcdf4==1.6.0
\ No newline at end of file
+netcdf4==1.6.0
+tensorflow-probability==0.14.1
+tzwhere
\ No newline at end of file
diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py
index 7aae1837e490d116b8dbeae31c5aeb6b459f9287..af75905bd511a4cfb8d8b5023325c678f83f0799 100644
--- a/mlair/data_handler/iterator.py
+++ b/mlair/data_handler/iterator.py
@@ -97,7 +97,13 @@ class KerasIterator(keras.utils.Sequence):
 
     def _get_model_rank(self):
         if self.model is not None:
-            mod_out = self.model.output_shape
+            try:
+                mod_out = self.model.output_shape
+            except AttributeError:
+                # ToDo: replace this except statement with something meaningful. Depending on the BNN
+                # architecture, the attribute output_shape might not be defined. It is used here to check
+                # the number of tails; make sure multiple tails also work with BNNs in future versions.
+                mod_out = (None, None)
             if isinstance(mod_out, tuple):  # only one output branch: (None, ahead)
                 mod_rank = 1
             elif isinstance(mod_out, list):  # multiple output branches, e.g.: [(None, ahead), (None, ahead)]
diff --git a/mlair/model_modules/probability_models.py b/mlair/model_modules/probability_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ffe77a5c561a4903a548062f2b015c623e66c69
--- /dev/null
+++ b/mlair/model_modules/probability_models.py
@@ -0,0 +1,864 @@
+"""
+>>> MyCustomisedModel().model.compile(**kwargs) == MyCustomisedModel().compile(**kwargs)
+True
+
+"""
+
+import mlair.model_modules.keras_extensions
+
+__author__ = "Felix Kleinert"
+__date__ = '2022-07-08'
+
+import tensorflow as tf
+import tensorflow.keras as keras
+import tensorflow_probability as tfp
+tfd = tfp.distributions
+tfb = tfp.bijectors
+tfpl = tfp.layers
+
+import logging
+from mlair.model_modules import AbstractModelClass
+from mlair.model_modules.inception_model import InceptionModelBase
+from mlair.model_modules.flatten import flatten_tail
+from mlair.model_modules.advanced_paddings import PadUtils, Padding2D, SymmetricPadding2D
+from mlair.model_modules.loss import l_p_loss
+
+
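+# Usage sketch: these model classes are passed to an MLAir workflow as the model class,
+# e.g. DefaultWorkflow(model=MyUnetProb, ...); see run_bnn.py for a complete configuration.
+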
+class MyUnetProb(AbstractModelClass):
+    def __init__(self, input_shape: list, output_shape: list, num_of_training_samples: int):
+        super().__init__(input_shape[0], output_shape[0])
+        self.first_filter_size = 16  # 16*2#self._input_shape[-1]  # 16
+        self.lstm_units = 64 * 2  # * 2
+        self.kernel_size = (3, 1)  # (3,1)
+        self.activation = "elu"
+        self.pool_size = (2, 1)
+
+        self.num_of_training_samples = num_of_training_samples
+        # self.divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / input_shape[0][0]
+        self.divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / num_of_training_samples
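+        # Scaling the KL term by the total number of training samples is the usual mini-batch ELBO
+        # reweighting: summed over all batches of one epoch, the per-layer KL contributions then
+        # approximate a single full KL(q || p) penalty.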
+
+        # self.loss_fn = lambda y_true, y_pred: -y_pred.log_prob(y_true)
+        # self.loss = nll
+
+        self.dropout = .15  # .2
+        self.k_mixed_components = 2
+        self.kernel_regularizer = keras.regularizers.l1_l2(l1=0.01, l2=0.01)
+        self.bias_regularizer = keras.regularizers.l1_l2(l1=0.01, l2=0.01)
+
+        self.kernel_initializer = 'he_normal'
+
+        self.dense_units = 32 * 2
+        self.initial_lr = 0.001
+
+        # apply to model
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(SymmetricPadding2D=SymmetricPadding2D, loss=self.loss, divergence_fn=self.divergence_fn)
+
+    def set_model(self):
+        input_train = keras.layers.Input(shape=self._input_shape)
+        pad_size = PadUtils.get_padding_for_same(self.kernel_size)
+
+        c1 = Padding2D("SymPad2D")(padding=pad_size)(input_train)
+        c1 = tfpl.Convolution2DReparameterization(
+                    self.first_filter_size, self.kernel_size, padding='valid',
+                    kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+                    kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    kernel_divergence_fn=self.divergence_fn,
+                    bias_prior_fn=tfpl.default_multivariate_normal_fn,
+                    bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    bias_divergence_fn=self.divergence_fn,
+                    activation=self.activation,
+        )(c1)
+        #c1 = keras.layers.Conv2D(self.first_filter_size, self.kernel_size, activation=self.activation,
+        #                         kernel_initializer=self.kernel_initializer, kernel_regularizer=self.kernel_regularizer,
+        #                         bias_regularizer=self.bias_regularizer)(c1)
+        c1 = keras.layers.Dropout(self.dropout)(c1)
+        c1 = Padding2D("SymPad2D")(padding=pad_size)(c1)
+        c1 = tfpl.Convolution2DReparameterization(
+                    self.first_filter_size, self.kernel_size, padding='valid',
+                    kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+                    kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    kernel_divergence_fn=self.divergence_fn,
+                    bias_prior_fn=tfpl.default_multivariate_normal_fn,
+                    bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    bias_divergence_fn=self.divergence_fn,
+                    activation=self.activation,
+                    name='c1'
+        )(c1)
+        #c1 = keras.layers.Conv2D(self.first_filter_size, self.kernel_size, activation=self.activation,
+        #                         kernel_initializer=self.kernel_initializer, name='c1',
+        #                         kernel_regularizer=self.kernel_regularizer,
+        #                         bias_regularizer=self.bias_regularizer)(c1)
+        p1 = c1
+        # p1 = keras.layers.MaxPooling2D(self.pool_size)(c1)
+
+        c2 = Padding2D("SymPad2D")(padding=pad_size)(p1)
+        c2 = tfpl.Convolution2DReparameterization(
+                    self.first_filter_size * 2, self.kernel_size, padding='valid',
+                    kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+                    kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    kernel_divergence_fn=self.divergence_fn,
+                    bias_prior_fn=tfpl.default_multivariate_normal_fn,
+                    bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    bias_divergence_fn=self.divergence_fn,
+                    activation=self.activation,
+        )(c2)
+        # c2 = keras.layers.Conv2D(self.first_filter_size * 2, self.kernel_size, activation=self.activation,
+        #                          kernel_initializer=self.kernel_initializer, kernel_regularizer=self.kernel_regularizer,
+        #                          bias_regularizer=self.bias_regularizer)(c2)
+        c2 = keras.layers.Dropout(self.dropout)(c2)
+        c2 = Padding2D("SymPad2D")(padding=pad_size)(c2)
+        c2 = tfpl.Convolution2DReparameterization(
+                    self.first_filter_size * 2, self.kernel_size, padding='valid',
+                    kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+                    kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    kernel_divergence_fn=self.divergence_fn,
+                    bias_prior_fn=tfpl.default_multivariate_normal_fn,
+                    bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    bias_divergence_fn=self.divergence_fn,
+                    activation=self.activation, name="c2"
+        )(c2)
+        # c2 = keras.layers.Conv2D(self.first_filter_size * 2, self.kernel_size, activation=self.activation,
+        #                          kernel_initializer=self.kernel_initializer, name='c2',
+        #                          kernel_regularizer=self.kernel_regularizer,
+        #                          bias_regularizer=self.bias_regularizer)(c2)
+        p2 = c2
+        # p2 = keras.layers.MaxPooling2D(self.pool_size)(c2)
+
+        c3 = Padding2D("SymPad2D")(padding=pad_size)(p2)
+        c3 = tfpl.Convolution2DReparameterization(
+                    self.first_filter_size * 4, self.kernel_size, padding='valid',
+                    kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+                    kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    kernel_divergence_fn=self.divergence_fn,
+                    bias_prior_fn=tfpl.default_multivariate_normal_fn,
+                    bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    bias_divergence_fn=self.divergence_fn,
+                    activation=self.activation
+        )(c3)
+        # c3 = keras.layers.Conv2D(self.first_filter_size * 4, self.kernel_size, activation=self.activation,
+        #                          kernel_initializer=self.kernel_initializer, kernel_regularizer=self.kernel_regularizer,
+        #                          bias_regularizer=self.bias_regularizer)(c3)
+        c3 = keras.layers.Dropout(self.dropout * 2)(c3)
+        c3 = Padding2D("SymPad2D")(padding=pad_size)(c3)
+        c3 = tfpl.Convolution2DReparameterization(
+                    self.first_filter_size * 4, self.kernel_size, padding='valid',
+                    kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+                    kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    kernel_divergence_fn=self.divergence_fn,
+                    bias_prior_fn=tfpl.default_multivariate_normal_fn,
+                    bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    bias_divergence_fn=self.divergence_fn,
+                    activation=self.activation, name="c3"
+        )(c3)
+        # c3 = keras.layers.Conv2D(self.first_filter_size * 4, self.kernel_size, activation=self.activation,
+        #                          kernel_initializer=self.kernel_initializer, name='c3',
+        #                          kernel_regularizer=self.kernel_regularizer,
+        #                          bias_regularizer=self.bias_regularizer)(c3)
+        # p3 = c3
+        p3 = keras.layers.MaxPooling2D(self.pool_size)(c3)
+
+        ### own LSTM Block ###
+        ls1 = keras.layers.Reshape((p3.shape[1], p3.shape[-1]))(p3)
+        ls1 = keras.layers.LSTM(self.lstm_units, return_sequences=True)(ls1)
+        ls1 = keras.layers.LSTM(self.lstm_units, return_sequences=True)(ls1)
+        c4 = keras.layers.Reshape((p3.shape[1], 1, -1))(ls1)
+
+        ### own 2nd LSTM Block ###
+        ls2 = keras.layers.Reshape((c3.shape[1], c3.shape[-1]))(c3)
+        ls2 = keras.layers.LSTM(self.lstm_units, return_sequences=True)(ls2)
+        ls2 = keras.layers.LSTM(self.lstm_units, return_sequences=True)(ls2)
+        c4_2 = keras.layers.Reshape((c3.shape[1], 1, -1))(ls2)
+
+        u7 = keras.layers.UpSampling2D(size=(3, 1))(c4)
+        cn3 = Padding2D("SymPad2D")(padding=pad_size)(c3)
+        # u7 = c4
+        u7 = keras.layers.concatenate([u7, cn3], name="u7_c3")
+        c7 = u7
+        # c7 = Padding2D("SymPad2D")(padding=pad_size)(u7)
+        c7 = tfpl.Convolution2DReparameterization(
+                    self.first_filter_size * 4, self.kernel_size, padding='valid',
+                    kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+                    kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    kernel_divergence_fn=self.divergence_fn,
+                    bias_prior_fn=tfpl.default_multivariate_normal_fn,
+                    bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                    bias_divergence_fn=self.divergence_fn,
+                    activation=self.activation
+        )(c7)
+        # c7 = keras.layers.Conv2D(self.first_filter_size * 4, self.kernel_size, activation=self.activation,
+        #                          kernel_initializer=self.kernel_initializer, kernel_regularizer=self.kernel_regularizer,
+        #                          bias_regularizer=self.bias_regularizer)(c7)
+        c7 = keras.layers.concatenate([c7, c4_2], name="Concat_2nd_LSTM")
+        c7 = keras.layers.Dropout(self.dropout * 2)(c7)
+        c7 = Padding2D("SymPad2D")(padding=pad_size)(c7)
+        c7 = tfpl.Convolution2DReparameterization(
+            self.first_filter_size * 4, self.kernel_size, padding='valid',
+            kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+            kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+            kernel_divergence_fn=self.divergence_fn,
+            bias_prior_fn=tfpl.default_multivariate_normal_fn,
+            bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+            bias_divergence_fn=self.divergence_fn,
+            activation=self.activation, name='c7_to_u8'
+        )(c7)
+        # c7 = keras.layers.Conv2D(self.first_filter_size * 4, self.kernel_size, activation=self.activation,
+        #                          kernel_initializer=self.kernel_initializer, name='c7_to_u8',
+        #                          kernel_regularizer=self.kernel_regularizer,
+        #                          bias_regularizer=self.bias_regularizer)(c7)
+
+
+        # u8 = Padding2D("SymPad2D")(padding=pad_size)(c7)
+        # u8 = keras.layers.Conv2DTranspose(32, self.pool_size, strides=self.pool_size)(u8)
+        u8 = c7
+        # u8 = c3
+        u8 = keras.layers.concatenate([u8, c2], name="u8_c2")
+        c8 = Padding2D("SymPad2D")(padding=pad_size)(u8)
+        c8 = tfpl.Convolution2DReparameterization(
+            self.first_filter_size * 2, self.kernel_size, padding='valid',
+            kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+            kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+            kernel_divergence_fn=self.divergence_fn,
+            bias_prior_fn=tfpl.default_multivariate_normal_fn,
+            bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+            bias_divergence_fn=self.divergence_fn,
+            activation=self.activation
+        )(c8)
+        # c8 = keras.layers.Conv2D(self.first_filter_size * 2, self.kernel_size, activation=self.activation,
+        #                          kernel_initializer=self.kernel_initializer, kernel_regularizer=self.kernel_regularizer,
+        #                          bias_regularizer=self.bias_regularizer)(c8)
+        c8 = keras.layers.Dropout(self.dropout)(c8)
+        c8 = Padding2D("SymPad2D")(padding=pad_size)(c8)
+        c8 = tfpl.Convolution2DReparameterization(
+            self.first_filter_size * 2, self.kernel_size, padding='valid',
+            kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+            kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+            kernel_divergence_fn=self.divergence_fn,
+            bias_prior_fn=tfpl.default_multivariate_normal_fn,
+            bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+            bias_divergence_fn=self.divergence_fn,
+            activation=self.activation, name='c8_to_u9'
+        )(c8)
+        # c8 = keras.layers.Conv2D(self.first_filter_size * 2, self.kernel_size, activation=self.activation,
+        #                          kernel_initializer=self.kernel_initializer, name='c8_to_u9',
+        #                          kernel_regularizer=self.kernel_regularizer,
+        #                          bias_regularizer=self.bias_regularizer)(c8)
+
+        # u9 = Padding2D("SymPad2D")(padding=pad_size)(c8)
+        # u9 = keras.layers.Conv2DTranspose(16, self.pool_size, strides=self.pool_size)(u9)
+        u9 = c8
+        u9 = keras.layers.concatenate([u9, c1], name="u9_c1")
+        c9 = Padding2D("SymPad2D")(padding=pad_size)(u9)
+        c9 = tfpl.Convolution2DReparameterization(
+            self.first_filter_size, self.kernel_size, padding='valid',
+            kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+            kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+            kernel_divergence_fn=self.divergence_fn,
+            bias_prior_fn=tfpl.default_multivariate_normal_fn,
+            bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+            bias_divergence_fn=self.divergence_fn,
+            activation=self.activation,
+        )(c9)
+        # c9 = keras.layers.Conv2D(self.first_filter_size, self.kernel_size, activation=self.activation,
+        #                          kernel_initializer=self.kernel_initializer, kernel_regularizer=self.kernel_regularizer,
+        #                          bias_regularizer=self.bias_regularizer)(c9)
+        c9 = keras.layers.Dropout(self.dropout)(c9)
+        c9 = Padding2D("SymPad2D")(padding=pad_size)(c9)
+        c9 = tfpl.Convolution2DReparameterization(
+            self.first_filter_size, self.kernel_size, padding='valid',
+            kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+            kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+            kernel_divergence_fn=self.divergence_fn,
+            bias_prior_fn=tfpl.default_multivariate_normal_fn,
+            bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+            bias_divergence_fn=self.divergence_fn,
+            activation=self.activation,
+        )(c9)
+        # c9 = keras.layers.Conv2D(self.first_filter_size, self.kernel_size, activation=self.activation,
+        #                          kernel_initializer=self.kernel_initializer, name='c9',
+        #                          kernel_regularizer=self.kernel_regularizer,
+        #                          bias_regularizer=self.bias_regularizer)(c9)
+
+        # outputs = keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(c9)
+        dl = keras.layers.Flatten()(c9)
+        dl = keras.layers.Dropout(self.dropout)(dl)
+
+        # outputs = tfpl.DenseVariational(tfpl.MultivariateNormalTriL.params_size(self._output_shape),
+        #                                 make_posterior_fn=self.posterior,
+        #                                 make_prior_fn=self.prior)(dl)
+        # outputs = tfpl.MultivariateNormalTriL(self._output_shape)(outputs)
+        # outputs = keras.layers.Dense(units=self._output_shape)(dl)
+
+        #outputs = keras.layers.Dense(tfpl.IndependentNormal.params_size(self._output_shape),
+        #                                )(dl)
+        #outputs = tfpl.DenseVariational(units=tfpl.IndependentNormal.params_size(self._output_shape),
+        #                                #make_prior_fn=self.prior,
+        #                                make_prior_fn=prior_trainable,
+        #                                make_posterior_fn=self.posterior,
+        #                                )(dl)
+        #outputs = VarDense(units=tfpl.IndependentNormal.params_size(self._output_shape),
+        #                                 make_prior_fn=self.prior,
+        #                                 make_posterior_fn=self.posterior,
+        #                                 )(dl)
+
+
+        #outputs = tfpl.IndependentNormal(self._output_shape)(outputs)
+        params_size = tfpl.MixtureSameFamily.params_size(
+            self.k_mixed_components,
+            component_params_size=tfpl.MultivariateNormalTriL.params_size(self._output_shape)
+        )
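+        # Worked example (assuming k_mixed_components=2 and an output shape of 4): each
+        # MultivariateNormalTriL component needs 4 (loc) + 4 * 5 / 2 (scale_tril) = 14 parameters,
+        # so the mixture head needs 2 mixture logits + 2 * 14 component parameters = 30 dense outputs.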
+
+        pars = tf.keras.layers.Dense(params_size)(dl)
+        # pars = DenseVariationalCustom(
+        #     units=params_size, make_prior_fn=prior, make_posterior_fn=posterior,
+        #     kl_use_exact=True, kl_weight=1./self.x_train_shape)(dl)
+
+        outputs = tfpl.MixtureSameFamily(self.k_mixed_components,
+                                         tfpl.MultivariateNormalTriL(
+                                             self._output_shape,
+                                             convert_to_tensor_fn=tfp.distributions.Distribution.mode
+                                         )
+                                         )(pars)
+
+        self.model = keras.Model(inputs=input_train, outputs=outputs)
+
+    def set_compile_options(self):
+        # self.optimizer = keras.optimizers.Adam(lr=self.initial_lr,
+                                               # clipnorm=self.clipnorm,
+                                               # )
+
+        # loss = nll
+        # self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=self.initial_lr)
+        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.initial_lr)
+        self.loss = nll
+        self.compile_options = {"metrics": ["mse", "mae"]}
+
+        # loss = keras.losses.MeanSquaredError()
+        # self.compile_options = {"loss": [loss]}
+
+    @staticmethod
+    def prior(kernel_size, bias_size, dtype=None):
+        n = kernel_size + bias_size
+
+        prior_model = tf.keras.Sequential([
+
+            tfpl.DistributionLambda(
+                # Note: Our prior is a non-trainable distribution
+                lambda t: tfd.MultivariateNormalDiag(loc=tf.zeros(n), scale_diag=tf.ones(n)))
+        ])
+
+        return prior_model
+
+    @staticmethod
+    def posterior(kernel_size, bias_size, dtype=None):
+        n = kernel_size + bias_size
+
+        posterior_model = tf.keras.Sequential([
+
+            tfpl.VariableLayer(tfpl.MultivariateNormalTriL.params_size(n), dtype=dtype),
+            tfpl.MultivariateNormalTriL(n)
+        ])
+
+        return posterior_model
+
+
+class MyCNNProb(AbstractModelClass):
+    """
+    Taken from https://towardsdatascience.com/uncertainty-in-deep-learning-bayesian-cnn-tensorflow-probability-758d7482bef6
+    and modified to fit our data.
+    """
+    def __init__(self, input_shape: list, output_shape: list):
+        super().__init__(input_shape[0], output_shape[0])
+        self.initial_lr = 0.001
+
+        self.divergence_fn = lambda q, p, q_tensor: self.approximate_kl(q, p, q_tensor) / 1000  # ToDo: divide by the actual number of training samples here
+        # apply to model
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(loss_fn=self.loss_fn)
+
+
+    @staticmethod
+    def loss_fn(y_true, y_pred):
+        return -y_pred.log_prob(y_true)
+
+
+    # For Reparameterization Layers
+
+    @staticmethod
+    def custom_normal_prior(dtype, shape, name, trainable, add_variable_fn):
+        distribution = tfd.Normal(loc = 0.1 * tf.ones(shape, dtype),
+                                  scale = 1.5 * tf.ones(shape, dtype))
+        batch_ndims = tf.size(distribution.batch_shape_tensor())
+
+        distribution = tfd.Independent(distribution,
+                                       reinterpreted_batch_ndims = batch_ndims)
+        return distribution
+
+    @staticmethod
+    def laplace_prior(dtype, shape, name, trainable, add_variable_fn):
+        distribution = tfd.Laplace(loc = tf.zeros(shape, dtype),
+                                   scale = tf.ones(shape, dtype))
+        batch_ndims = tf.size(distribution.batch_shape_tensor())
+
+        distribution = tfd.Independent(distribution,
+                                   reinterpreted_batch_ndims = batch_ndims)
+        return distribution
+
+
+    @staticmethod
+    def approximate_kl(q, p, q_tensor):
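+        """Single-sample Monte-Carlo estimate of KL(q || p), averaged over the weight dimensions and
+        evaluated at the sampled weights q_tensor."""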
+        return tf.reduce_mean(q.log_prob(q_tensor) - p.log_prob(q_tensor))
+
+    
+    def conv_reparameterization_layer(self, filters, kernel_size, activation):
+        # For simplicity, we use default prior and posterior.
+        # In the next parts, we will use custom mixture prior and posteriors.
+        return tfpl.Convolution2DReparameterization(
+                filters = filters,
+                kernel_size = kernel_size,
+                activation = activation, 
+                padding = 'same',
+                kernel_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),
+                kernel_prior_fn = tfpl.default_multivariate_normal_fn,
+                
+                bias_prior_fn = tfpl.default_multivariate_normal_fn,
+                bias_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),
+                
+                kernel_divergence_fn = self.divergence_fn,
+                bias_divergence_fn = self.divergence_fn)
+
+    def set_model(self):
+        bayesian_cnn = tf.keras.Sequential([
+            tf.keras.layers.InputLayer(self._input_shape),
+            self.conv_reparameterization_layer(16, 3, 'swish'),
+            #tf.keras.layers.MaxPooling2D(2),
+            self.conv_reparameterization_layer(32, 3, 'swish'),
+            #tf.keras.layers.MaxPooling2D(2),
+            self.conv_reparameterization_layer(64, 3, 'swish'),
+            #tf.keras.layers.MaxPooling2D(2),
+            self.conv_reparameterization_layer(128, 3, 'swish'),
+            #tf.keras.layers.GlobalMaxPooling2D(),
+            tfpl.DenseReparameterization(
+                units=tfpl.IndependentNormal.params_size(self._output_shape), activation=None,
+                kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
+                kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+                bias_prior_fn=tfpl.default_multivariate_normal_fn,
+                bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
+                kernel_divergence_fn=self.divergence_fn,
+                bias_divergence_fn=self.divergence_fn),
+            tfpl.IndependentNormal(self._output_shape)
+        ])
+
+        input_train = keras.layers.Input(shape=self._input_shape)
+        x = self.conv_reparameterization_layer(16, 3, 'swish')(input_train)
+        # tf.keras.layers.MaxPooling2D(2),
+        x = self.conv_reparameterization_layer(32, 3, 'swish')(x)
+        # tf.keras.layers.MaxPooling2D(2),
+        x = self.conv_reparameterization_layer(64, 3, 'swish')(x)
+        # tf.keras.layers.MaxPooling2D(2),
+        x = self.conv_reparameterization_layer(128, 3, 'swish')(x)
+        x = tf.keras.layers.Flatten()(x)
+        # x = tfpl.DenseReparameterization(
+        #         units=tfpl.IndependentNormal.params_size(self._output_shape), activation=None,
+        #         kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
+        #         kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+        #         bias_prior_fn=tfpl.default_multivariate_normal_fn,
+        #         bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
+        #         kernel_divergence_fn=self.divergence_fn,
+        #         bias_divergence_fn=self.divergence_fn)(x)
+        # outputs = tfpl.IndependentNormal(self._output_shape)(x)
+        x = tf.keras.layers.Dense(tfpl.IndependentNormal.params_size(event_shape=self._output_shape))(x)
+        outputs = tfpl.IndependentNormal(event_shape=self._output_shape)(x)
+        # outputs = tfpl.DistributionLambda(
+        #     make_distribution_fn=lambda t: tfd.Normal(
+        #         loc=t[..., 0], scale=tf.exp(t[..., 1])),
+        #     convert_to_tensor_fn=lambda s: s.sample(30))(x)
+
+
+        bnn = keras.Model(inputs=input_train, outputs=outputs)
+        self.model = bnn
+
+
+        logging.info(f"model summary:\n{self.model.summary()}")
+    
+    def set_compile_options(self):
+        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.initial_lr,
+                                                  # clipnorm=self.clipnorm,
+                                                  )
+
+        loss = self.loss_fn
+        # self.compile_options = {"loss": [loss], "metrics": ["mse", "mae"]}
+
+        # loss = keras.losses.MeanSquaredError()
+        self.compile_options = {"loss": [loss]}
+
+
+class VarDense(tf.keras.layers.Layer):
+
+    def __init__(self,
+                 units,
+                 make_posterior_fn,
+                 make_prior_fn,
+                 kl_weight=None,
+                 kl_use_exact=False,
+                 activation=None,
+                 use_bias=True,
+                 activity_regularizer=None,
+                 **kwargs
+                 ):
+        super().__init__(**kwargs)
+        self.units = units
+        self.make_posterior_fn = make_posterior_fn
+        self.make_prior_fn = make_prior_fn
+        self.kl_weight = kl_weight
+        self.kl_use_exact = kl_use_exact
+        self.activation = activation
+        self.use_bias = use_bias
+        self.activity_regularizer = activity_regularizer
+        self.tfpllayer = tfpl.DenseVariational(units=self.units,
+                              make_prior_fn=self.make_prior_fn,
+                              make_posterior_fn=self.make_posterior_fn,
+                              kl_weight=self.kl_weight,
+                              kl_use_exact=self.kl_use_exact,
+                              use_bias=self.use_bias,
+                              activity_regularizer=self.activity_regularizer
+                              )
+
+    def call(self, inputs):
+        return self.tfpllayer(inputs)
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            "units": self.units,
+            "make_posterior_fn": self.make_posterior_fn,
+            "make_prior_fn": self.make_prior_fn,
+            "kl_weight": self.kl_weight,
+            "kl_use_exact": self.kl_use_exact,
+            "activation": self.activation,
+            "use_bias": self.use_bias,
+            "activity_regularizer": self.activity_regularizer,
+        })
+        return config
+
+
+def prior_trainable(kernel_size, bias_size=0, dtype=None):
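+    """Prior with trainable means and fixed unit scale, intended as make_prior_fn for tfpl.DenseVariational."""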
+    n = kernel_size + bias_size
+    return tf.keras.Sequential([
+        tfp.layers.VariableLayer(n, dtype=dtype),
+        tfp.layers.DistributionLambda(lambda t: tfd.Independent(
+            tfd.Normal(loc=t, scale=1),
+            reinterpreted_batch_ndims=1)),
+        ])
+
+
+class ProbTestModel(AbstractModelClass):
+    def __init__(self, input_shape: list, output_shape: list):
+        super().__init__(input_shape[0], output_shape[0])
+        self.initial_lr = 0.001
+        self.loss = nll
+
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(nll=nll)
+
+    def set_model(self):
+
+        x_in = keras.layers.Input(self._input_shape)
+        x = keras.layers.Conv2D(kernel_size=(3,1), filters=8,
+                                   activation='relu', padding="same")(x_in)
+        x = keras.layers.Flatten()(x)
+        x = keras.layers.Dense(tfpl.IndependentNormal.params_size(self._output_shape))(x)
+        out = tfpl.IndependentNormal(self._output_shape)(x)
+        model = keras.Model(inputs=x_in, outputs=out)
+
+
+        #model = tf.keras.Sequential([
+        #    keras.layers.InputLayer(self._input_shape),
+        #    keras.layers.Conv2D(kernel_size=(3,1), filters=8,
+        #                           activation='relu', padding="same"),
+
+        #    keras.layers.Flatten(),
+
+        #    keras.layers.Dense(tfpl.IndependentNormal.params_size(self._output_shape)),
+        #    tfpl.IndependentNormal(self._output_shape,
+        #                           convert_to_tensor_fn=tfp.distributions.Distribution.sample
+        #                           )
+
+        #])
+        self.model = model
+        self.model.summary(print_fn=logging.info)
+
+    def set_compile_options(self):
+        self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=self.initial_lr,
+                                                     # clipnorm=self.clipnorm,
+                                                     )
+
+
+class ProbTestModel2(AbstractModelClass):
+    def __init__(self, input_shape: list, output_shape: list):
+        super().__init__(input_shape[0], output_shape[0])
+        self.initial_lr = 0.001
+        self.loss = nll
+        self.divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / input_shape[0][0]
+
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(nll=nll)
+
+    def set_model(self):
+        model = tf.keras.Sequential([
+            tf.keras.layers.InputLayer(self._input_shape),
+            #tf.keras.layers.Conv2D(kernel_size=(3,1), filters=8,
+            #                       activation='relu', padding="same"),
+            Convolution2DReparameterizationCustom(
+                8, (3,1), padding='same',
+                kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+                kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                kernel_divergence_fn=self.divergence_fn,
+                bias_prior_fn=tfpl.default_multivariate_normal_fn,
+                bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                bias_divergence_fn=self.divergence_fn,
+                activation='relu'),
+
+            tf.keras.layers.Flatten(),
+
+            tf.keras.layers.Dense(tfpl.MultivariateNormalTriL.params_size(self._output_shape)),
+            tfpl.MultivariateNormalTriL(self._output_shape,
+                                   convert_to_tensor_fn=tfp.distributions.Distribution.mode
+                                   )
+
+        ])
+        self.model = model
+        self.model.summary(print_fn=logging.info)
+
+    def set_compile_options(self):
+        self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=self.initial_lr,
+                                                     # clipnorm=self.clipnorm,
+                                                     )
+
+
+class ProbTestModel3(AbstractModelClass):
+    def __init__(self, input_shape: list, output_shape: list):
+        super().__init__(input_shape[0], output_shape[0])
+
+        self.x_train_shape=100.
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(nll=nll)
+
+    def set_model(self):
+        model = tf.keras.Sequential([
+            keras.layers.Flatten(input_shape=self._input_shape),
+            # Epistemic uncertainty
+            tfpl.DenseVariational(units=8,
+                                  make_prior_fn=prior,
+                                  make_posterior_fn=posterior,
+                                  kl_weight=1/self.x_train_shape,
+                                  kl_use_exact=False,
+                                  activation='sigmoid'),
+        
+            tfpl.DenseVariational(units=tfpl.IndependentNormal.params_size(1),
+                                  make_prior_fn=prior,
+                                  make_posterior_fn=posterior,
+                                  kl_use_exact=False,
+                                  kl_weight=1/self.x_train_shape),
+
+            # Aleatoric uncertainty
+            tfpl.IndependentNormal(1)
+            ])
+        model.summary(print_fn=logging.warning)
+        self.model = model
+
+    def set_compile_options(self):
+        self.optimizer = tf.keras.optimizers.RMSprop()
+
+
+class ProbTestModel4(AbstractModelClass):
+    def __init__(self, input_shape: list, output_shape: list):
+        super().__init__(input_shape[0], output_shape[0])
+
+        self.x_train_shape = 100.
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(nll=nll)
+
+    def set_model(self):
+        model = tf.keras.Sequential([
+            keras.layers.Flatten(input_shape=self._input_shape),
+            # Epistemic uncertainty
+            DenseVariationalCustom(units=8,
+                                  make_prior_fn=prior,
+                                  make_posterior_fn=posterior,
+                                  kl_weight=1 / self.x_train_shape,
+                                  kl_use_exact=False,
+                                  activation='sigmoid'),
+
+            DenseVariationalCustom(units=tfpl.IndependentNormal.params_size(self._output_shape),
+                                  make_prior_fn=prior,
+                                  make_posterior_fn=posterior,
+                                  kl_use_exact=False,
+                                  kl_weight=1 / self.x_train_shape),
+
+            # Aleatoric uncertainty
+            tfpl.IndependentNormal(self._output_shape)
+        ])
+        model.summary(print_fn=logging.warning)
+        self.model = model
+
+    def set_compile_options(self):
+        self.optimizer = tf.keras.optimizers.RMSprop()
+        self.loss = nll
+
+
+
+class ProbTestModelMixture(AbstractModelClass):
+    def __init__(self, input_shape: list, output_shape: list):
+        super().__init__(input_shape[0], output_shape[0])
+        self.initial_lr = 0.001
+        self.loss = nll
+        self.divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / input_shape[0][0]
+        self.k_mixed_components = 2
+
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(nll=nll)
+
+    def set_model(self):
+        x_input = tf.keras.layers.Input(self._input_shape)
+            #tf.keras.layers.Conv2D(kernel_size=(3,1), filters=8,
+            #                       activation='relu', padding="same"),
+        x = Convolution2DReparameterizationCustom(
+                8, (3, 1), padding='same',
+                kernel_prior_fn=tfpl.default_multivariate_normal_fn,
+                kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                kernel_divergence_fn=self.divergence_fn,
+                bias_prior_fn=tfpl.default_multivariate_normal_fn,
+                bias_posterior_fn=tfpl.default_mean_field_normal_fn(),
+                bias_divergence_fn=self.divergence_fn,
+                activation='relu',
+        )(x_input)
+
+        x = tf.keras.layers.Flatten()(x)
+
+        params_size = tfpl.MixtureSameFamily.params_size(
+            self.k_mixed_components,
+            component_params_size=tfpl.MultivariateNormalTriL.params_size(self._output_shape)
+        )
+
+        x = tf.keras.layers.Dense(params_size)(x)
+        #    tfpl.MultivariateNormalTriL(self._output_shape,
+        #                           convert_to_tensor_fn=tfp.distributions.Distribution.mode
+        #                           )
+        out = tfpl.MixtureSameFamily(self.k_mixed_components, tfpl.MultivariateNormalTriL(self._output_shape,
+                                     convert_to_tensor_fn=tfp.distributions.Distribution.mode
+                                     ))(x)
+
+        self.model = tf.keras.Model(inputs=[x_input], outputs=out)
+        self.model.summary(print_fn=logging.info)
+
+    def set_compile_options(self):
+        self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=self.initial_lr,
+                                                     # clipnorm=self.clipnorm,
+                                                     )
+
+
+def nll(y_true, y_pred):
+    """
+    This function should return the negative log-likelihood of each sample
+    in y_true given the predicted distribution y_pred. If y_true is of shape
+    [B, E] and y_pred has batch shape [B] and event_shape [E], the output
+    should be a Tensor of shape [B].
+    """
+    return -y_pred.log_prob(y_true)
+
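+# Usage sketch for nll (all names below are illustrative): given a distribution-valued prediction, e.g.
+#     dist = tfd.Independent(tfd.Normal(loc=tf.zeros((4, 3)), scale=tf.ones((4, 3))),
+#                            reinterpreted_batch_ndims=1)
+#     nll(tf.zeros((4, 3)), dist)  # -> Tensor of shape [4], one value per sample
+# the distribution objects returned by the probabilistic output layers above (e.g. tfpl.IndependentNormal)
+# behave in the same way.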
+
+# Posterior
+def posterior(kernel_size, bias_size, dtype=None):
+
+    n = kernel_size + bias_size
+
+    posterior_model = tf.keras.Sequential([
+
+        tfpl.VariableLayer(tfpl.MultivariateNormalTriL.params_size(n), dtype=dtype),
+        tfpl.MultivariateNormalTriL(n)
+    ])
+
+    return posterior_model
+
+# Prior - diagonal MVN ~ N(0, 1)
+def prior(kernel_size, bias_size, dtype=None):
+
+    n = kernel_size + bias_size
+
+    prior_model = tf.keras.Sequential([
+
+        tfpl.DistributionLambda(
+            # Note: Our prior is a non-trainable distribution
+            lambda t: tfd.MultivariateNormalDiag(loc=tf.zeros(n), scale_diag=tf.ones(n)))
+    ])
+
+    return prior_model
+
+
+class DenseVariationalCustom(tfpl.DenseVariational):
+    """
+    Trying to implement a DensVar that can be stored:
+    https://github.com/tensorflow/probability/commit/0ca065fb526b50ce38b68f7d5b803f02c78c8f16#
+    """
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            'units': self.units,
+            'make_posterior_fn': self._make_posterior_fn,
+            'make_prior_fn': self._make_prior_fn
+        })
+        return config
+
+
+class Convolution2DReparameterizationCustom(tfpl.Convolution2DReparameterization):
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            # 'units': self.units,
+            # 'make_posterior_fn': self._make_posterior_fn,
+            # 'make_prior_fn': self._make_prior_fn,
+            # 'kernel_divergence_fn': self.divergence_fn,
+            })
+        return config
+
+
+if __name__ == "__main__":
+
+    mylayer = DenseVariationalCustom(units=8,
+                                  make_prior_fn=prior,
+                                  make_posterior_fn=posterior,
+                                  kl_weight=1/100.,
+                                  kl_use_exact=False,
+                                  activation='sigmoid')
+
+    print(mylayer)
+
+
+#### How to access mixture model parameters:
+# https://stackoverflow.com/questions/65918888/mixture-parameters-from-a-tensorflow-probability-mixture-density-network
+# From the MLAir perspective:
+# gm = self.model.model(input_data)
+#
+# mixing parameters:
+# gm.mixture_distribution.probs_parameter()
+#
+# for the component parameters, see the keys and select:
+# gm.components_distribution.parameters.keys()
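+#
+# component means and covariance matrices (assuming MultivariateNormalTriL components):
+# gm.components_distribution.mean()
+# gm.components_distribution.covariance()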
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index eab8012b983a0676620bbc66f65ff79b31165aeb..bf09ac6fc8c63bcfc31024dafb550c84e0ff5df4 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -74,6 +74,9 @@ class ModelSetup(RunEnvironment):
         # set channels depending on inputs
         self._set_shapes()
 
+        # set number of training samples (total)
+        self._set_num_of_training_samples()
+
         # build model graph using settings from my_model_settings()
         self.build_model()
 
@@ -103,6 +106,19 @@ class ModelSetup(RunEnvironment):
         shape = list(map(lambda y: y.shape[1:], self.data_store.get("data_collection", "train")[0].get_Y()))
         self.data_store.set("output_shape", shape, self.scope)
 
+    def _set_num_of_training_samples(self):
+        """Set the total number of training samples (needed, for example, by Bayesian NNs)."""
+        samples = 0
+        for s in self.data_store.get("data_collection", "train"):
+            y = s.get_Y()
+            if isinstance(y, (list, tuple)):  # multiple output branches: count samples of the first
+                s_sam = y[0].shape[0]
+            else:
+                # single array-like output with the sample dimension first
+                s_sam = y.shape[0]
+            samples += s_sam
+        self.num_of_training_samples = samples
+
     def compile_model(self):
         """
         Compiles the keras model. Compile options are mandatory and have to be set by implementing set_compile() method
@@ -162,6 +178,11 @@ class ModelSetup(RunEnvironment):
         """Build model using input and output shapes from data store."""
         model = self.data_store.get("model_class")
         args_list = model.requirements()
+        if "num_of_training_samples" in args_list:
+            self.data_store.set("num_of_training_samples", self.num_of_training_samples, scope=self.scope)
+            logging.info(f"Store number of training samples ({self.num_of_training_samples}) in data_store: "
+                         f"self.data_store.set('num_of_training_samples', {self.num_of_training_samples}, scope='{self.scope}')")
+
         args = self.data_store.create_args_dict(args_list, self.scope)
         self.model = model(**args)
         self.get_model_settings()
@@ -185,9 +206,12 @@ class ModelSetup(RunEnvironment):
 
     def plot_model(self):  # pragma: no cover
         """Plot model architecture as `<model_name>.pdf`."""
-        with tf.device("/cpu:0"):
-            file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
-            keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
+        try:
+            with tf.device("/cpu:0"):
+                file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
+                keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
+        except Exception as e:
+            logging.info(f"Can not plot model due to: {e}")
 
     def report_model(self):
         # report model settings
diff --git a/requirements.txt b/requirements.txt
index 4f911b37f5a27be1f30caf69a613df5deef62a29..f644ae9257c0b5a18492f8a2d0ef27d1246ec0d4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,6 +27,7 @@ six==1.15.0
 statsmodels==0.13.2
 tabulate==0.8.10
 tensorflow==2.6.0
+tensorflow-probability==0.14.1
 timezonefinder==5.2.0
 toolz==0.11.1
 typing_extensions~=3.7.4
diff --git a/run_bnn.py b/run_bnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3642ec6d522aff51516a7eb710d00b04ab137d50
--- /dev/null
+++ b/run_bnn.py
@@ -0,0 +1,61 @@
+__author__ = "Felix Kleinert"
+__date__ = '2022-08-05'
+
+import argparse
+from mlair.workflows import DefaultWorkflow
+# from mlair.model_modules.recurrent_networks import RNN as chosen_model
+from mlair.helpers import remove_items
+from mlair.configuration.defaults import DEFAULT_PLOT_LIST
+from mlair.model_modules.probability_models import ProbTestModel4, MyUnetProb, ProbTestModel2, ProbTestModelMixture
+import os
+import tensorflow as tf
+
+
+def load_stations(case=0):
+    import json
+    cases = {
+        0: 'supplement/station_list_north_german_plain_rural.json',
+        1: 'supplement/station_list_north_german_plain.json',
+        2: 'supplement/German_background_stations.json',
+    }
+    try:
+        filename = cases[case]
+        with open(filename, 'r') as jfile:
+            stations = json.load(jfile)
+    except FileNotFoundError:
+        stations = None
+    return stations
+
+
+def main(parser_args):
+    # tf.compat.v1.disable_v2_behavior()
+    plots = remove_items(DEFAULT_PLOT_LIST, ["PlotConditionalQuantiles", "PlotPeriodogram"])
+    stats_per_var = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
+                     'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
+                     'pblheight': 'maximum'}
+    workflow = DefaultWorkflow(  # stations=load_stations(),
+        #stations=["DEBW087","DEBW013", "DEBW107",  "DEBW076"],
+        stations=load_stations(2),
+        model=MyUnetProb,
+        window_lead_time=4,
+        window_history_size=6,
+        epochs=100,
+        batch_size=1024,
+        train_model=False, create_new_model=True, network="UBA",
+        evaluate_feature_importance=False,  # plot_list=["PlotCompetitiveSkillScore"],
+        # competitors=["test_model", "test_model2"],
+        competitor_path=os.path.join(os.getcwd(), "data", "comp_test"),
+        variables=list(stats_per_var.keys()),
+        statistics_per_var=stats_per_var,
+        target_var="o3",
+        target_var_unit="ppb",
+        **parser_args.__dict__, start_script=__file__)
+    workflow.run()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default="testrun",
+                        help="set experiment date as string")
+    args = parser.parse_args()
+    main(args)