diff --git a/video_prediction/models/sv2p_model.py b/video_prediction/models/sv2p_model.py index 135a2e04ee8ae65a744e7c102713020fe064d8d0..f3205f33e0164cb66aaf8681b476db076a087b4b 100644 --- a/video_prediction/models/sv2p_model.py +++ b/video_prediction/models/sv2p_model.py @@ -610,6 +610,7 @@ class SV2PVideoPredictionModel(VideoPredictionModel): def get_default_hparams_dict(self): default_hparams = super(SV2PVideoPredictionModel, self).get_default_hparams_dict() hparams = dict( + batch_size=32, l1_weight=0.0, l2_weight=1.0, kl_weight=1e-3 * 10 * 8, # equivalent to latent_loss_multiplier up to a factor (see below)