diff --git a/video_prediction/models/sv2p_model.py b/video_prediction/models/sv2p_model.py
index 135a2e04ee8ae65a744e7c102713020fe064d8d0..f3205f33e0164cb66aaf8681b476db076a087b4b 100644
--- a/video_prediction/models/sv2p_model.py
+++ b/video_prediction/models/sv2p_model.py
@@ -610,6 +610,7 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
     def get_default_hparams_dict(self):
         default_hparams = super(SV2PVideoPredictionModel, self).get_default_hparams_dict()
         hparams = dict(
+            batch_size=32,
             l1_weight=0.0,
             l2_weight=1.0,
             kl_weight=1e-3 * 10 * 8,  # equivalent to latent_loss_multiplier up to a factor (see below)