diff --git a/mlair/model_modules/probability_models.py b/mlair/model_modules/probability_models.py index 1f51b0853147e5a638505375e789fec79f9739a3..4204d26bc9b0edb63cb07a8e03abfc1feae727b8 100644 --- a/mlair/model_modules/probability_models.py +++ b/mlair/model_modules/probability_models.py @@ -1369,28 +1369,37 @@ class MyUnetProbMulti(AbstractModelClass): # ) # params_size = tfpl.IndependentNormal.params_size(self._output_shape) - params_size = tfpl.MultivariateNormalTriL.params_size(self._output_shape) - pars = tf.keras.layers.Dense(params_size)(dl) - # pars = DenseVariationalCustom( - # units=params_size, make_prior_fn=prior, make_posterior_fn=posterior, - # kl_use_exact=True, kl_weight=1./self.x_train_shape)(dl) - # outputs = tfpl.MixtureSameFamily(self.k_mixed_components, - # tfpl.MultivariateNormalTriL( - # self._output_shape, - # convert_to_tensor_fn=tfp.distributions.Distribution.mode - # ) - # )(pars) + if self.k_mixed_components is None: + # Mulit B + params_size = tfpl.MultivariateNormalTriL.params_size(self._output_shape) + + pars = tf.keras.layers.Dense(params_size)(dl) + + outputs = tfpl.MultivariateNormalTriL( + self._output_shape, + convert_to_tensor_fn=tfp.distributions.Distribution.mode + )(pars) + # Multi E + else: - outputs = tfpl.MultivariateNormalTriL( - self._output_shape, - convert_to_tensor_fn=tfp.distributions.Distribution.mode - )(pars) + # Mix B + params_size = tfpl.MixtureSameFamily.params_size( + self.k_mixed_components, + component_params_size=tfpl.MultivariateNormalTriL.params_size(self._output_shape) + ) - # outputs = tfpl.IndependentNormal( - # self._output_shape - # )(pars) + pars = tf.keras.layers.Dense(params_size)(dl) + # tfpl.MultivariateNormalTriL(self._output_shape, + # convert_to_tensor_fn=tfp.distributions.Distribution.mode + # ) + outputs = tfpl.MixtureSameFamily(self.k_mixed_components, + tfpl.MultivariateNormalTriL( + self._output_shape, + convert_to_tensor_fn=tfp.distributions.Distribution.mode + ))(pars) + # Mix E self.model = keras.Model(inputs=input_train, outputs=outputs)