Skip to content
Snippets Groups Projects
Commit 371ffa1f authored by Felix Kleinert's avatar Felix Kleinert
Browse files

Merge branch '414-include-crps-analysis-and-other-ens-verif-methods-or-plots'...

Merge branch '414-include-crps-analysis-and-other-ens-verif-methods-or-plots' of ssh://gitlab.version.fz-juelich.de:10022/esde/machine-learning/mlair into 414-include-crps-analysis-and-other-ens-verif-methods-or-plots
parents 7b9f21f0 15a4aecc
No related branches found
No related tags found
1 merge request!466Draft: Resolve "Include CRPS analysis and other ens verif methods or plots"
Pipeline #126359 failed
...@@ -1369,28 +1369,37 @@ class MyUnetProbMulti(AbstractModelClass): ...@@ -1369,28 +1369,37 @@ class MyUnetProbMulti(AbstractModelClass):
# ) # )
# params_size = tfpl.IndependentNormal.params_size(self._output_shape) # params_size = tfpl.IndependentNormal.params_size(self._output_shape)
if self.k_mixed_components is None:
# Mulit B
params_size = tfpl.MultivariateNormalTriL.params_size(self._output_shape) params_size = tfpl.MultivariateNormalTriL.params_size(self._output_shape)
pars = tf.keras.layers.Dense(params_size)(dl) pars = tf.keras.layers.Dense(params_size)(dl)
# pars = DenseVariationalCustom(
# units=params_size, make_prior_fn=prior, make_posterior_fn=posterior,
# kl_use_exact=True, kl_weight=1./self.x_train_shape)(dl)
# outputs = tfpl.MixtureSameFamily(self.k_mixed_components,
# tfpl.MultivariateNormalTriL(
# self._output_shape,
# convert_to_tensor_fn=tfp.distributions.Distribution.mode
# )
# )(pars)
outputs = tfpl.MultivariateNormalTriL( outputs = tfpl.MultivariateNormalTriL(
self._output_shape, self._output_shape,
convert_to_tensor_fn=tfp.distributions.Distribution.mode convert_to_tensor_fn=tfp.distributions.Distribution.mode
)(pars) )(pars)
# Multi E
else:
# outputs = tfpl.IndependentNormal( # Mix B
# self._output_shape params_size = tfpl.MixtureSameFamily.params_size(
# )(pars) self.k_mixed_components,
component_params_size=tfpl.MultivariateNormalTriL.params_size(self._output_shape)
)
pars = tf.keras.layers.Dense(params_size)(dl)
# tfpl.MultivariateNormalTriL(self._output_shape,
# convert_to_tensor_fn=tfp.distributions.Distribution.mode
# )
outputs = tfpl.MixtureSameFamily(self.k_mixed_components,
tfpl.MultivariateNormalTriL(
self._output_shape,
convert_to_tensor_fn=tfp.distributions.Distribution.mode
))(pars)
# Mix E
self.model = keras.Model(inputs=input_train, outputs=outputs) self.model = keras.Model(inputs=input_train, outputs=outputs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment