diff --git a/mlair/model_modules/probability_models.py b/mlair/model_modules/probability_models.py
index 79300a1d332f8d45f6e85c89ba582a479c3cf989..dd0c3035ebd59957937fe35e8a81f9a3c5913166 100644
--- a/mlair/model_modules/probability_models.py
+++ b/mlair/model_modules/probability_models.py
@@ -314,7 +314,7 @@ class MyUnetProb(AbstractModelClass):
         pars = tf.keras.layers.Dense(params_size)(dl)
         # pars = DenseVariationalCustom(
         #     units=params_size, make_prior_fn=prior, make_posterior_fn=posterior,
-        #     kl_use_exact=True, kl_weight=1./self.x_train_shape)(dl)
+        #     kl_use_exact=False, kl_weight=1./self.num_of_training_samples)(dl)
 
         # outputs = tfpl.MixtureSameFamily(self.k_mixed_components,
         #                                 tfpl.MultivariateNormalTriL(
@@ -323,9 +323,22 @@ class MyUnetProb(AbstractModelClass):
         #                                 )
         #                                 )(pars)
 
+
+        # outputs = tfpl.MultivariateNormalTriL(
+        #     self._output_shape,
+        #     convert_to_tensor_fn=tfp.distributions.Distribution.mode
+        # )(pars)
+
         outputs = tfpl.MultivariateNormalTriL(
             self._output_shape,
-            convert_to_tensor_fn=tfp.distributions.Distribution.mode
+            # lambda s: s.sample(10),
+            sample_real(10),
+            activity_regularizer=tfpl.KLDivergenceRegularizer(
+                tfd.MultivariateNormalDiag(loc=tf.zeros(self._output_shape),
+                                           scale_diag=tf.ones(self._output_shape)),
+                weight=self.num_of_training_samples
+            )
+            # convert_to_tensor_fn=tfp.distributions.Distribution.mode
         )(pars)
 
         self.model = keras.Model(inputs=input_train, outputs=outputs)
@@ -846,6 +859,14 @@ class Convolution2DReparameterizationCustom(tfpl.Convolution2DReparameterization
             })
         return config
 
+def sample_real(n_real=10):
+
+    global sample
+    def sample(s):
+        return s.sample(n_real)
+
+    return sample
+
 
 if __name__ == "__main__":