diff --git a/video_prediction/models/base_model.py b/video_prediction/models/base_model.py index c5da777f97b34d7d5bf277c45070fafc9dcd056f..3facdbcb4b2314f89dc585a9fa2da890390fb17d 100644 --- a/video_prediction/models/base_model.py +++ b/video_prediction/models/base_model.py @@ -47,10 +47,10 @@ class BaseVideoPredictionModel(object): self.eval_num_samples = eval_num_samples self.eval_parallel_iterations = eval_parallel_iterations self.hparams = self.parse_hparams(hparams_dict, hparams) - if not self.hparams.context_frames: + if self.hparams.context_frames == -1: raise ValueError('Invalid context_frames %r. It might have to be ' 'specified.' % self.hparams.context_frames) - if not self.hparams.sequence_length: + if self.hparams.sequence_length == -1: raise ValueError('Invalid sequence_length %r. It might have to be ' 'specified.' % self.hparams.sequence_length) @@ -84,8 +84,8 @@ class BaseVideoPredictionModel(object): repeat: the number of repeat actions (if applicable). """ hparams = dict( - context_frames=0, - sequence_length=0, + context_frames=-1, + sequence_length=-1, repeat=1, ) return hparams @@ -348,8 +348,8 @@ class VideoPredictionModel(BaseVideoPredictionModel): max_steps=300000, beta1=0.9, beta2=0.999, - context_frames=0, - sequence_length=0, + context_frames=-1, + sequence_length=-1, clip_length=10, l1_weight=0.0, l2_weight=1.0, diff --git a/video_prediction/models/mocogan_model.py b/video_prediction/models/mocogan_model.py index ac9c470a76b82981c2885fdd39c157aaf4b68c61..22490c4d831e65286438bb66fdc4855f8349f91c 100644 --- a/video_prediction/models/mocogan_model.py +++ b/video_prediction/models/mocogan_model.py @@ -291,6 +291,7 @@ class MoCoGANVideoPredictionModel(VideoPredictionModel): def get_default_hparams_dict(self): default_hparams = super(MoCoGANVideoPredictionModel, self).get_default_hparams_dict() hparams = dict( + batch_size=32, l1_weight=10.0, l2_weight=0.0, image_gan_weight=1.0,