Skip to content
Snippets Groups Projects
Commit 2612c377 authored by Felix Kleinert's avatar Felix Kleinert
Browse files

check some settings for probnet

parent 0966d8e3
No related branches found
No related tags found
2 merge requests!474Draft: Resolve "DataHandler with multiple stats per variable",!466Draft: Resolve "Include CRPS analysis and other ens verif methods or plots"
Pipeline #109634 passed
...@@ -314,7 +314,7 @@ class MyUnetProb(AbstractModelClass): ...@@ -314,7 +314,7 @@ class MyUnetProb(AbstractModelClass):
pars = tf.keras.layers.Dense(params_size)(dl) pars = tf.keras.layers.Dense(params_size)(dl)
# pars = DenseVariationalCustom( # pars = DenseVariationalCustom(
# units=params_size, make_prior_fn=prior, make_posterior_fn=posterior, # units=params_size, make_prior_fn=prior, make_posterior_fn=posterior,
# kl_use_exact=True, kl_weight=1./self.x_train_shape)(dl) # kl_use_exact=False, kl_weight=1./self.num_of_training_samples)(dl)
# outputs = tfpl.MixtureSameFamily(self.k_mixed_components, # outputs = tfpl.MixtureSameFamily(self.k_mixed_components,
# tfpl.MultivariateNormalTriL( # tfpl.MultivariateNormalTriL(
...@@ -323,9 +323,22 @@ class MyUnetProb(AbstractModelClass): ...@@ -323,9 +323,22 @@ class MyUnetProb(AbstractModelClass):
# ) # )
# )(pars) # )(pars)
# outputs = tfpl.MultivariateNormalTriL(
# self._output_shape,
# convert_to_tensor_fn=tfp.distributions.Distribution.mode
# )(pars)
outputs = tfpl.MultivariateNormalTriL( outputs = tfpl.MultivariateNormalTriL(
self._output_shape, self._output_shape,
convert_to_tensor_fn=tfp.distributions.Distribution.mode # lambda s: s.sample(10),
sample_real(10),
activity_regularizer=tfpl.KLDivergenceRegularizer(
tfd.MultivariateNormalDiag(loc=tf.zeros(self._output_shape),
scale_diag=tf.ones(self._output_shape)),
weight=self.num_of_training_samples
)
# convert_to_tensor_fn=tfp.distributions.Distribution.mode
)(pars) )(pars)
self.model = keras.Model(inputs=input_train, outputs=outputs) self.model = keras.Model(inputs=input_train, outputs=outputs)
...@@ -846,6 +859,14 @@ class Convolution2DReparameterizationCustom(tfpl.Convolution2DReparameterization ...@@ -846,6 +859,14 @@ class Convolution2DReparameterizationCustom(tfpl.Convolution2DReparameterization
}) })
return config return config
def sample_real(n_real=10):
global sample
def sample(s):
return s.sample(n_real)
return sample
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment