diff --git a/video_prediction/models/sv2p_model.py b/video_prediction/models/sv2p_model.py
index f3205f33e0164cb66aaf8681b476db076a087b4b..6cbcafd7c4ac488096cd3f20ab1ff440940557e9 100644
--- a/video_prediction/models/sv2p_model.py
+++ b/video_prediction/models/sv2p_model.py
@@ -191,10 +191,10 @@ def construct_latent_tower(images, hparams):
     return latent_mean, latent_std
 
 
-def encoder_fn(inputs, hparams=None):
+def encoder_fn(inputs, hparams):
     images = tf.unstack(inputs['images'], axis=0)
     latent_mean, latent_std = construct_latent_tower(images, hparams)
-    outputs = {'enc_zs_mu': latent_mean, 'enc_zs_log_sigma_sq': latent_std}
+    outputs = {'zs_mu_enc': latent_mean, 'zs_log_sigma_sq_enc': latent_std}
     return outputs
 
 
@@ -269,7 +269,7 @@ def construct_model(images,
     if outputs_enc is None:  # equivalent to inference_time
         latent_mean, latent_std = None, None
     else:
-        latent_mean, latent_std = outputs_enc['enc_zs_mu'], outputs_enc['enc_zs_log_sigma_sq']
+        latent_mean, latent_std = outputs_enc['zs_mu_enc'], outputs_enc['zs_log_sigma_sq_enc']
         assert latent_mean.shape.as_list() == latent_shape
 
     if hparams.multi_latent:
@@ -558,10 +558,11 @@ def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
                              [ground_truth_examps, generated_examps])
 
 
-def generator_fn(inputs, outputs_enc=None, hparams=None):
+def generator_fn(inputs, mode, hparams):
     images = tf.unstack(inputs['images'], axis=0)
     batch_size = images[0].shape[0].value
     action_dim, state_dim = 4, 3
+    # if not use_state, use zero actions and states to match reference implementation.
     actions = inputs.get('actions', tf.zeros([hparams.sequence_length - 1, batch_size, action_dim]))
     actions = tf.unstack(actions, axis=0)
     states = inputs.get('states', tf.zeros([hparams.sequence_length, batch_size, state_dim]))
@@ -569,13 +570,31 @@
     states = tf.unstack(states, axis=0)
     iter_num = tf.to_float(tf.train.get_or_create_global_step())
+    schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1
     gen_images, gen_states = \
+        construct_model(images,
+                        actions,
+                        states,
+                        outputs_enc=None,
+                        iter_num=iter_num,
+                        k=schedule_sampling_k,
+                        use_state='actions' in inputs,
+                        num_masks=hparams.num_masks,
+                        cdna=hparams.transformation == 'cdna',
+                        dna=hparams.transformation == 'dna',
+                        stp=hparams.transformation == 'stp',
+                        context_frames=hparams.context_frames,
+                        hparams=hparams)
+
+    outputs_enc = encoder_fn(inputs, hparams)
+    tf.get_variable_scope().reuse_variables()
+    gen_images_enc, gen_states_enc = \
         construct_model(images,
                         actions,
                         states,
                         outputs_enc=outputs_enc,
                         iter_num=iter_num,
-                        k=hparams.schedule_sampling_k,
+                        k=schedule_sampling_k,
                         use_state='actions' in inputs,
                         num_masks=hparams.num_masks,
                         cdna=hparams.transformation == 'cdna',
                         dna=hparams.transformation == 'dna',
@@ -583,13 +602,16 @@
                         stp=hparams.transformation == 'stp',
                         context_frames=hparams.context_frames,
                         hparams=hparams)
+
     outputs = {
         'gen_images': tf.stack(gen_images, axis=0),
         'gen_states': tf.stack(gen_states, axis=0),
+        'gen_images_enc': tf.stack(gen_images_enc, axis=0),
+        'gen_states_enc': tf.stack(gen_states_enc, axis=0),
+        'zs_mu_enc': outputs_enc['zs_mu_enc'],
+        'zs_log_sigma_sq_enc': outputs_enc['zs_log_sigma_sq_enc'],
     }
-    outputs = {name: output[hparams.context_frames - 1:] for name, output in outputs.items()}
-    gen_images = outputs['gen_images']
-    return gen_images, outputs
+    return outputs
 
 
 class SV2PVideoPredictionModel(VideoPredictionModel):
@@ -602,9 +624,7 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
     """
     def __init__(self, *args, **kwargs):
         super(SV2PVideoPredictionModel, self).__init__(
-            generator_fn, encoder_fn=encoder_fn, *args, **kwargs)
-        if self.hparams.schedule_sampling_k == -1:
-            self.encoder_fn = None
+            generator_fn, *args, **kwargs)
         self.deterministic = not self.hparams.stochastic_model
 
     def get_default_hparams_dict(self):
@@ -634,15 +654,3 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
         # Based on Figure 4 and the Appendix, it seems that in the 3rd stage, the kl_weight is
         # linearly increased for the first 20k iterations of this stage.
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
-
-    def parse_hparams(self, hparams_dict, hparams):
-        hparams = super(SV2PVideoPredictionModel, self).parse_hparams(hparams_dict, hparams)
-        if self.mode == 'test':
-            def override_hparams_maybe(name, value):
-                orig_value = hparams.values()[name]
-                if orig_value != value:
-                    print('Overriding hparams from %s=%r to %r for mode=%s.' %
-                          (name, orig_value, value, self.mode))
-                    hparams.set_hparam(name, value)
-            override_hparams_maybe('schedule_sampling_k', -1)
-        return hparams
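Note on the new `mode == 'train'` guard in `generator_fn`: forcing `k = -1` outside of training disables scheduled sampling, so after the context frames the model always feeds back its own predictions. A minimal sketch of the inverse-sigmoid decay that a `schedule_sampling_k`-style hyperparameter typically controls; the helper name and the NumPy framing are illustrative, not taken from this patch:

```python
import numpy as np

def ground_truth_fraction(iter_num, k):
    """Illustrative inverse-sigmoid decay for scheduled sampling (a sketch).

    k > 0: fraction of ground-truth frames fed back, decaying as iter_num grows.
    k < 0: treated here as "no scheduled sampling" (always feed back the model's
    own predictions), mirroring the mode != 'train' branch in generator_fn.
    """
    if k < 0:
        return 0.0
    return k / (k + np.exp(iter_num / k))
```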
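Since `zs_mu_enc` and `zs_log_sigma_sq_enc` now surface in the generator outputs, here is a sketch of the KL term a training loss could build from them; the function name, `kl_weight` argument, and reduction axes are assumptions (the actual loss lives outside this file):

```python
import tensorflow as tf

def kl_loss_sketch(outputs, kl_weight):
    # Hypothetical helper, not part of this patch: closed-form
    # KL(q(z|x) || N(0, I)) computed from the renamed generator outputs.
    # kl_weight would follow the linear ramp-up described in the comment
    # above get_default_hparams_dict's return statement.
    mu = outputs['zs_mu_enc']
    log_sigma_sq = outputs['zs_log_sigma_sq_enc']
    # Standard Gaussian KL: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl = -0.5 * tf.reduce_sum(
        1.0 + log_sigma_sq - tf.square(mu) - tf.exp(log_sigma_sq), axis=-1)
    return kl_weight * tf.reduce_mean(kl)
```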