From 2612c377d94ad309938223f464573dc6c7e304a1 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Mon, 22 Aug 2022 14:34:30 +0200
Subject: [PATCH] check some settings for probnet
---
mlair/model_modules/probability_models.py | 25 +++++++++++++++++++++--
1 file changed, 23 insertions(+), 2 deletions(-)
diff --git a/mlair/model_modules/probability_models.py b/mlair/model_modules/probability_models.py
index 79300a1d..dd0c3035 100644
--- a/mlair/model_modules/probability_models.py
+++ b/mlair/model_modules/probability_models.py
@@ -314,7 +314,7 @@ class MyUnetProb(AbstractModelClass):
pars = tf.keras.layers.Dense(params_size)(dl)
# pars = DenseVariationalCustom(
# units=params_size, make_prior_fn=prior, make_posterior_fn=posterior,
- # kl_use_exact=True, kl_weight=1./self.x_train_shape)(dl)
+ # kl_use_exact=False, kl_weight=1./self.num_of_training_samples)(dl)
# outputs = tfpl.MixtureSameFamily(self.k_mixed_components,
# tfpl.MultivariateNormalTriL(
@@ -323,9 +323,22 @@ class MyUnetProb(AbstractModelClass):
# )
# )(pars)
+
+ # outputs = tfpl.MultivariateNormalTriL(
+ # self._output_shape,
+ # convert_to_tensor_fn=tfp.distributions.Distribution.mode
+ # )(pars)
+
outputs = tfpl.MultivariateNormalTriL(
self._output_shape,
- convert_to_tensor_fn=tfp.distributions.Distribution.mode
+ # lambda s: s.sample(10),
+ sample_real(10),
+ activity_regularizer=tfpl.KLDivergenceRegularizer(
+ tfd.MultivariateNormalDiag(loc=tf.zeros(self._output_shape),
+ scale_diag=tf.ones(self._output_shape)),
+ weight=self.num_of_training_samples
+ )
+ # convert_to_tensor_fn=tfp.distributions.Distribution.mode
)(pars)
self.model = keras.Model(inputs=input_train, outputs=outputs)
@@ -846,6 +859,14 @@ class Convolution2DReparameterizationCustom(tfpl.Convolution2DReparameterization
})
return config
+def sample_real(n_real=10):
+
+ global sample
+ def sample(s):
+ return s.sample(n_real)
+
+ return sample
+
if __name__ == "__main__":
--
GitLab