diff --git a/mlair/model_modules/probability_models.py b/mlair/model_modules/probability_models.py
index 79300a1d332f8d45f6e85c89ba582a479c3cf989..dd0c3035ebd59957937fe35e8a81f9a3c5913166 100644
--- a/mlair/model_modules/probability_models.py
+++ b/mlair/model_modules/probability_models.py
@@ -314,7 +314,7 @@ class MyUnetProb(AbstractModelClass):
         pars = tf.keras.layers.Dense(params_size)(dl)
         # pars = DenseVariationalCustom(
         #     units=params_size, make_prior_fn=prior, make_posterior_fn=posterior,
-        #     kl_use_exact=True, kl_weight=1./self.x_train_shape)(dl)
+        #     kl_use_exact=False, kl_weight=1./self.num_of_training_samples)(dl)
 
         # outputs = tfpl.MixtureSameFamily(self.k_mixed_components,
         #                                  tfpl.MultivariateNormalTriL(
@@ -323,9 +323,22 @@ class MyUnetProb(AbstractModelClass):
         #                                      )
         #                                  )(pars)
 
+
+        # outputs = tfpl.MultivariateNormalTriL(
+        #     self._output_shape,
+        #     convert_to_tensor_fn=tfp.distributions.Distribution.mode
+        # )(pars)
+
         outputs = tfpl.MultivariateNormalTriL(
             self._output_shape,
-            convert_to_tensor_fn=tfp.distributions.Distribution.mode
+            # lambda s: s.sample(10),
+            sample_real(10),
+            activity_regularizer=tfpl.KLDivergenceRegularizer(
+                tfd.MultivariateNormalDiag(loc=tf.zeros(self._output_shape),
+                                           scale_diag=tf.ones(self._output_shape)),
+                weight=self.num_of_training_samples
+            )
+            # convert_to_tensor_fn=tfp.distributions.Distribution.mode
         )(pars)
 
         self.model = keras.Model(inputs=input_train, outputs=outputs)
@@ -846,6 +859,14 @@ class Convolution2DReparameterizationCustom(tfpl.Convolution2DReparameterization
         })
         return config
 
+def sample_real(n_real=10):
+
+    global sample
+    def sample(s):
+        return s.sample(n_real)
+
+    return sample
+
 
 
 if __name__ == "__main__":
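
Note on the pattern above: the change swaps the deterministic `convert_to_tensor_fn=tfp.distributions.Distribution.mode` for a sampling function built by `sample_real`, and attaches a `tfpl.KLDivergenceRegularizer` that pulls the predicted distribution toward a standard-normal prior. The `global sample` in `sample_real` presumably pins the closure to a module-level name so a saved model can be deserialized, where a plain lambda would not round-trip. Below is a minimal, self-contained sketch of the same pattern, not the repository's code: `make_sampler`, `event_size`, `n_train`, and the input shape are illustrative, and the sketch uses an ELBO-style `1./n_train` KL weight, whereas the diff passes `self.num_of_training_samples` unscaled; that discrepancy with the commented-out `kl_weight=1./self.num_of_training_samples` is worth double-checking in review.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


def make_sampler(n_real=10):
    # Mirrors the diff's `sample_real`: a factory returning a module-level
    # function rather than a lambda, so the model remains serializable.
    global sample

    def sample(s):
        return s.sample(n_real)

    return sample


event_size = 2    # stand-in for self._output_shape
n_train = 1000    # stand-in for self.num_of_training_samples

inputs = tf.keras.Input(shape=(8,))
pars = tf.keras.layers.Dense(
    tfpl.MultivariateNormalTriL.params_size(event_size))(inputs)
outputs = tfpl.MultivariateNormalTriL(
    event_size,
    convert_to_tensor_fn=make_sampler(10),  # draw 10 realizations per input
    activity_regularizer=tfpl.KLDivergenceRegularizer(
        tfd.MultivariateNormalDiag(loc=tf.zeros(event_size),
                                   scale_diag=tf.ones(event_size)),
        # conventional ELBO scaling is 1/N; the diff passes N itself
        weight=1. / n_train),
)(pars)

model = tf.keras.Model(inputs, outputs)
# train against the negative log likelihood of the predictive distribution
model.compile(optimizer="adam", loss=lambda y, rv_y: -rv_y.log_prob(y))
```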