diff --git a/video_prediction/models/sv2p_model.py b/video_prediction/models/sv2p_model.py
index f3205f33e0164cb66aaf8681b476db076a087b4b..6cbcafd7c4ac488096cd3f20ab1ff440940557e9 100644
--- a/video_prediction/models/sv2p_model.py
+++ b/video_prediction/models/sv2p_model.py
@@ -191,10 +191,10 @@ def construct_latent_tower(images, hparams):
     return latent_mean, latent_std
 
 
-def encoder_fn(inputs, hparams=None):
+def encoder_fn(inputs, hparams):
     images = tf.unstack(inputs['images'], axis=0)
     latent_mean, latent_std = construct_latent_tower(images, hparams)
-    outputs = {'enc_zs_mu': latent_mean, 'enc_zs_log_sigma_sq': latent_std}
+    outputs = {'zs_mu_enc': latent_mean, 'zs_log_sigma_sq_enc': latent_std}
     return outputs
 
 
@@ -269,7 +269,7 @@ def construct_model(images,
         if outputs_enc is None:  # equivalent to inference_time
             latent_mean, latent_std = None, None
         else:
-            latent_mean, latent_std = outputs_enc['enc_zs_mu'], outputs_enc['enc_zs_log_sigma_sq']
+            latent_mean, latent_std = outputs_enc['zs_mu_enc'], outputs_enc['zs_log_sigma_sq_enc']
             assert latent_mean.shape.as_list() == latent_shape
 
         if hparams.multi_latent:
@@ -558,10 +558,11 @@ def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
                              [ground_truth_examps, generated_examps])
 
 
-def generator_fn(inputs, outputs_enc=None, hparams=None):
+def generator_fn(inputs, mode, hparams):
     images = tf.unstack(inputs['images'], axis=0)
     batch_size = images[0].shape[0].value
     action_dim, state_dim = 4, 3
+
     # if not use_state, use zero actions and states to match reference implementation.
     actions = inputs.get('actions', tf.zeros([hparams.sequence_length - 1, batch_size, action_dim]))
     actions = tf.unstack(actions, axis=0)
@@ -569,13 +570,31 @@ def generator_fn(inputs, outputs_enc=None, hparams=None):
     states = tf.unstack(states, axis=0)
     iter_num = tf.to_float(tf.train.get_or_create_global_step())
 
+    schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1
     gen_images, gen_states = \
+        construct_model(images,
+                        actions,
+                        states,
+                        outputs_enc=None,
+                        iter_num=iter_num,
+                        k=schedule_sampling_k,
+                        use_state='actions' in inputs,
+                        num_masks=hparams.num_masks,
+                        cdna=hparams.transformation == 'cdna',
+                        dna=hparams.transformation == 'dna',
+                        stp=hparams.transformation == 'stp',
+                        context_frames=hparams.context_frames,
+                        hparams=hparams)
+
+    outputs_enc = encoder_fn(inputs, hparams)
+    tf.get_variable_scope().reuse_variables()
+    gen_images_enc, gen_states_enc = \
         construct_model(images,
                         actions,
                         states,
                         outputs_enc=outputs_enc,
                         iter_num=iter_num,
-                        k=hparams.schedule_sampling_k,
+                        k=schedule_sampling_k,
                         use_state='actions' in inputs,
                         num_masks=hparams.num_masks,
                         cdna=hparams.transformation == 'cdna',
@@ -583,13 +602,16 @@ def generator_fn(inputs, outputs_enc=None, hparams=None):
                         stp=hparams.transformation == 'stp',
                         context_frames=hparams.context_frames,
                         hparams=hparams)
+
     outputs = {
         'gen_images': tf.stack(gen_images, axis=0),
         'gen_states': tf.stack(gen_states, axis=0),
+        'gen_images_enc': tf.stack(gen_images_enc, axis=0),
+        'gen_states_enc': tf.stack(gen_states_enc, axis=0),
+        'zs_mu_enc': outputs_enc['zs_mu_enc'],
+        'zs_log_sigma_sq_enc': outputs_enc['zs_log_sigma_sq_enc'],
     }
-    outputs = {name: output[hparams.context_frames - 1:] for name, output in outputs.items()}
-    gen_images = outputs['gen_images']
-    return gen_images, outputs
+    return outputs
 
 
 class SV2PVideoPredictionModel(VideoPredictionModel):
@@ -602,9 +624,7 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
     """
     def __init__(self, *args, **kwargs):
         super(SV2PVideoPredictionModel, self).__init__(
-            generator_fn, encoder_fn=encoder_fn, *args, ** kwargs)
-        if self.hparams.schedule_sampling_k == -1:
-            self.encoder_fn = None
+            generator_fn, *args, ** kwargs)
         self.deterministic = not self.hparams.stochastic_model
 
     def get_default_hparams_dict(self):
@@ -634,15 +654,3 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
         # Based on Figure 4 and the Appendix, it seems that in the 3rd stage, the kl_weight is
         # linearly increased for the first 20k iterations of this stage.
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
-
-    def parse_hparams(self, hparams_dict, hparams):
-        hparams = super(SV2PVideoPredictionModel, self).parse_hparams(hparams_dict, hparams)
-        if self.mode == 'test':
-            def override_hparams_maybe(name, value):
-                orig_value = hparams.values()[name]
-                if orig_value != value:
-                    print('Overriding hparams from %s=%r to %r for mode=%s.' %
-                          (name, orig_value, value, self.mode))
-                    hparams.set_hparam(name, value)
-            override_hparams_maybe('schedule_sampling_k', -1)
-        return hparams