From e49b74d614a795061892653397e127844334da87 Mon Sep 17 00:00:00 2001
From: Alex Lee <alexleegk@gmail.com>
Date: Wed, 11 Apr 2018 19:38:57 -0700
Subject: [PATCH] Change default context_frames and sequence_lengths defaults
 to -1. Set default batch_size=32 for mocogan.

---
 video_prediction/models/base_model.py    | 12 ++++++------
 video_prediction/models/mocogan_model.py |  1 +
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/video_prediction/models/base_model.py b/video_prediction/models/base_model.py
index c5da777f..3facdbcb 100644
--- a/video_prediction/models/base_model.py
+++ b/video_prediction/models/base_model.py
@@ -47,10 +47,10 @@ class BaseVideoPredictionModel(object):
         self.eval_num_samples = eval_num_samples
         self.eval_parallel_iterations = eval_parallel_iterations
         self.hparams = self.parse_hparams(hparams_dict, hparams)
-        if not self.hparams.context_frames:
+        if self.hparams.context_frames == -1:
             raise ValueError('Invalid context_frames %r. It might have to be '
                              'specified.' % self.hparams.context_frames)
-        if not self.hparams.sequence_length:
+        if self.hparams.sequence_length == -1:
             raise ValueError('Invalid sequence_length %r. It might have to be '
                              'specified.' % self.hparams.sequence_length)
 
@@ -84,8 +84,8 @@ class BaseVideoPredictionModel(object):
             repeat: the number of repeat actions (if applicable).
         """
         hparams = dict(
-            context_frames=0,
-            sequence_length=0,
+            context_frames=-1,
+            sequence_length=-1,
             repeat=1,
         )
         return hparams
@@ -348,8 +348,8 @@ class VideoPredictionModel(BaseVideoPredictionModel):
             max_steps=300000,
             beta1=0.9,
             beta2=0.999,
-            context_frames=0,
-            sequence_length=0,
+            context_frames=-1,
+            sequence_length=-1,
             clip_length=10,
             l1_weight=0.0,
             l2_weight=1.0,
diff --git a/video_prediction/models/mocogan_model.py b/video_prediction/models/mocogan_model.py
index ac9c470a..22490c4d 100644
--- a/video_prediction/models/mocogan_model.py
+++ b/video_prediction/models/mocogan_model.py
@@ -291,6 +291,7 @@ class MoCoGANVideoPredictionModel(VideoPredictionModel):
     def get_default_hparams_dict(self):
         default_hparams = super(MoCoGANVideoPredictionModel, self).get_default_hparams_dict()
         hparams = dict(
+            batch_size=32,
             l1_weight=10.0,
             l2_weight=0.0,
             image_gan_weight=1.0,
-- 
GitLab