diff --git a/video_prediction/models/sna_model.py b/video_prediction/models/sna_model.py
index 11d89e5ad9e0f9853903f73f50010a5aba0116fe..ddb04deafc73f49d0466acd74ac4a43d94ac72f0 100644
--- a/video_prediction/models/sna_model.py
+++ b/video_prediction/models/sna_model.py
@@ -603,7 +603,7 @@ def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
                              [ground_truth_examps, generated_examps])
 
 
-def generator_fn(inputs, hparams=None):
+def generator_fn(inputs, mode, hparams):
     images = tf.unstack(inputs['images'], axis=0)
     actions = tf.unstack(inputs['actions'], axis=0)
     states = tf.unstack(inputs['states'], axis=0)
@@ -617,13 +617,14 @@ def generator_fn(inputs, hparams=None):
     else:
         kern_size = hparams.kernel_size
 
+    schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1  # -1 disables scheduled sampling outside of training
     conf = {
         'context_frames': hparams.context_frames,  # of frames before predictions.' ,
         'use_state': 1,  # 'Whether or not to give the state+action to the model' ,
         'ngf': hparams.ngf,
         'model': hparams.transformation.upper(),  # 'model architecture to use - CDNA, DNA, or STP' ,
         'num_masks': hparams.num_masks,  # 'number of masks, usually 1 for DNA, 10 for CDNA, STN.' ,
-        'schedsamp_k': hparams.schedule_sampling_k,  # 'The k hyperparameter for scheduled sampling -1 for no scheduled sampling.' ,
+        'schedsamp_k': schedule_sampling_k,  # the k hyperparameter for scheduled sampling; -1 for no scheduled sampling
         'kern_size': kern_size,  # size of DNA kerns
     }
     if hparams.first_image_background:
@@ -641,9 +642,7 @@ def generator_fn(inputs, hparams=None):
     }
     if 'pix_distribs' in inputs:
         outputs['gen_pix_distribs'] = tf.stack(m.gen_distrib1, axis=0)
-    outputs = {name: output[hparams.context_frames - 1:] for name, output in outputs.items()}
-    gen_images = outputs['gen_images'][hparams.context_frames - 1:]
-    return gen_images, outputs
+    return outputs
 
 
 class SNAVideoPredictionModel(VideoPredictionModel):
@@ -666,15 +665,3 @@ class SNAVideoPredictionModel(VideoPredictionModel):
             schedule_sampling_k=900.0,
         )
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
-
-    def parse_hparams(self, hparams_dict, hparams):
-        hparams = super(SNAVideoPredictionModel, self).parse_hparams(hparams_dict, hparams)
-        if self.mode == 'test':
-            def override_hparams_maybe(name, value):
-                orig_value = hparams.values()[name]
-                if orig_value != value:
-                    print('Overriding hparams from %s=%r to %r for mode=%s.' %
-                          (name, orig_value, value, self.mode))
-                    hparams.set_hparam(name, value)
-            override_hparams_maybe('schedule_sampling_k', -1)
-        return hparams
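Note on the sna_model.py hunks above: the test-time override of schedule_sampling_k moves out of the deleted parse_hparams and into generator_fn, which now receives the mode explicitly instead of mutating hparams after parsing. The generator also returns the full outputs dict rather than slicing it by context_frames - 1 and then slicing gen_images a second time, as the removed lines did. For reference, here is a minimal framework-free sketch of the decay that schedsamp_k controls. It assumes the usual inverse-sigmoid schedule of Bengio et al. (2015); the function name is illustrative, not code from this repository:

    import math

    def num_ground_truth_frames(iter_num, batch_size, k):
        """Expected number of ground-truth frames fed to the model per batch.

        For k > 0, the fraction k / (k + exp(iter_num / k)) decays from
        roughly 1 towards 0 as training progresses. Any k <= 0 (the -1
        convention above) disables scheduled sampling, so after the context
        frames the model is always conditioned on its own predictions,
        which is the desired behavior at test time.
        """
        if k <= 0:
            return 0
        probability = k / (k + math.exp(iter_num / k))
        return int(round(batch_size * probability))
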
diff --git a/video_prediction/models/sv2p_model.py b/video_prediction/models/sv2p_model.py
index c2134e413cfd44c5db181feb8c240bd6283725cc..e7a06364178dc380456e47569787ab693ad121de 100644
--- a/video_prediction/models/sv2p_model.py
+++ b/video_prediction/models/sv2p_model.py
@@ -585,32 +585,34 @@ def generator_fn(inputs, mode, hparams):
                         stp=hparams.transformation == 'stp',
                         context_frames=hparams.context_frames,
                         hparams=hparams)
-
-    outputs_enc = encoder_fn(inputs, hparams)
-    tf.get_variable_scope().reuse_variables()
-    gen_images_enc, gen_states_enc = \
-        construct_model(images,
-                        actions,
-                        states,
-                        outputs_enc=outputs_enc,
-                        iter_num=iter_num,
-                        k=schedule_sampling_k,
-                        use_state='actions' in inputs,
-                        num_masks=hparams.num_masks,
-                        cdna=hparams.transformation == 'cdna',
-                        dna=hparams.transformation == 'dna',
-                        stp=hparams.transformation == 'stp',
-                        context_frames=hparams.context_frames,
-                        hparams=hparams)
-
     outputs = {
         'gen_images': tf.stack(gen_images, axis=0),
         'gen_states': tf.stack(gen_states, axis=0),
-        'gen_images_enc': tf.stack(gen_images_enc, axis=0),
-        'gen_states_enc': tf.stack(gen_states_enc, axis=0),
-        'zs_mu_enc': outputs_enc['zs_mu_enc'],
-        'zs_log_sigma_sq_enc': outputs_enc['zs_log_sigma_sq_enc'],
     }
+
+    if mode == 'train':  # the encoder (inference network) is only used during training
+        outputs_enc = encoder_fn(inputs, hparams)
+        tf.get_variable_scope().reuse_variables()
+        gen_images_enc, gen_states_enc = \
+            construct_model(images,
+                            actions,
+                            states,
+                            outputs_enc=outputs_enc,
+                            iter_num=iter_num,
+                            k=schedule_sampling_k,
+                            use_state='actions' in inputs,
+                            num_masks=hparams.num_masks,
+                            cdna=hparams.transformation == 'cdna',
+                            dna=hparams.transformation == 'dna',
+                            stp=hparams.transformation == 'stp',
+                            context_frames=hparams.context_frames,
+                            hparams=hparams)
+        outputs.update({
+            'gen_images_enc': tf.stack(gen_images_enc, axis=0),
+            'gen_states_enc': tf.stack(gen_states_enc, axis=0),
+            'zs_mu_enc': outputs_enc['zs_mu_enc'],
+            'zs_log_sigma_sq_enc': outputs_enc['zs_log_sigma_sq_enc'],
+        })
     return outputs
 
 
@@ -655,3 +657,22 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
         # Based on Figure 4 and the Appendix, it seems that in the 3rd stage, the kl_weight is
         # linearly increased for the first 20k iterations of this stage.
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def parse_hparams(self, hparams_dict, hparams):
+        # backwards compatibility: drop deprecated hparams so that hparams saved by older versions of the code still parse
+        deprecated_hparams_keys = [
+            'num_gpus',
+            'acvideo_gan_weight',
+            'acvideo_vae_gan_weight',
+            'image_gan_weight',
+            'image_vae_gan_weight',
+            'tuple_gan_weight',
+            'tuple_vae_gan_weight',
+            'gan_weight',
+            'vae_gan_weight',
+            'video_gan_weight',
+            'video_vae_gan_weight',
+        ]
+        for deprecated_hparams_key in deprecated_hparams_keys:
+            hparams_dict.pop(deprecated_hparams_key, None)
+        return super(SV2PVideoPredictionModel, self).parse_hparams(hparams_dict, hparams)
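
Note on the sv2p_model.py hunks: the encoder pass that produces the posterior parameters zs_mu_enc and zs_log_sigma_sq_enc is now built only when mode == 'train', so inference graphs skip the second construct_model call entirely. Separately, parse_hparams now strips deprecated keys before delegating to the parent parser, so hparams dicts saved by older versions of the code still load. Below is a minimal framework-free sketch of that compatibility pattern; parse and known_keys are hypothetical stand-ins for the parent class's strict parsing:

    DEPRECATED_KEYS = {'num_gpus', 'gan_weight', 'vae_gan_weight'}

    def parse(hparams_dict, known_keys):
        # Stand-in for the strict parent parser: it rejects unknown keys.
        unknown = set(hparams_dict) - set(known_keys)
        if unknown:
            raise ValueError('unknown hparams: %s' % sorted(unknown))
        return dict(hparams_dict)

    def parse_with_compat(hparams_dict, known_keys):
        # Drop deprecated keys first, as in SV2PVideoPredictionModel.parse_hparams.
        cleaned = {key: value for key, value in hparams_dict.items()
                   if key not in DEPRECATED_KEYS}
        return parse(cleaned, known_keys)

    saved = {'context_frames': 2, 'gan_weight': 0.0}  # written by an older version
    assert parse_with_compat(saved, {'context_frames'}) == {'context_frames': 2}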