diff --git a/video_prediction/models/sna_model.py b/video_prediction/models/sna_model.py
index 11d89e5ad9e0f9853903f73f50010a5aba0116fe..ddb04deafc73f49d0466acd74ac4a43d94ac72f0 100644
--- a/video_prediction/models/sna_model.py
+++ b/video_prediction/models/sna_model.py
@@ -603,7 +603,7 @@ def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
                              [ground_truth_examps, generated_examps])
 
 
-def generator_fn(inputs, hparams=None):
+def generator_fn(inputs, mode, hparams):
     images = tf.unstack(inputs['images'], axis=0)
     actions = tf.unstack(inputs['actions'], axis=0)
     states = tf.unstack(inputs['states'], axis=0)
@@ -617,13 +617,14 @@ def generator_fn(inputs, hparams=None):
     else:
         kern_size = hparams.kernel_size
 
+    schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1
     conf = {
         'context_frames': hparams.context_frames,  # of frames before predictions.' ,
         'use_state': 1,  # 'Whether or not to give the state+action to the model' ,
         'ngf': hparams.ngf,
         'model': hparams.transformation.upper(),  # 'model architecture to use - CDNA, DNA, or STP' ,
         'num_masks': hparams.num_masks,  # 'number of masks, usually 1 for DNA, 10 for CDNA, STN.' ,
-        'schedsamp_k': hparams.schedule_sampling_k,  # 'The k hyperparameter for scheduled sampling -1 for no scheduled sampling.' ,
+        'schedsamp_k': schedule_sampling_k,  # 'The k hyperparameter for scheduled sampling -1 for no scheduled sampling.' ,
         'kern_size': kern_size,  # size of DNA kerns
     }
     if hparams.first_image_background:
@@ -641,9 +642,7 @@ def generator_fn(inputs, hparams=None):
     }
     if 'pix_distribs' in inputs:
         outputs['gen_pix_distribs'] = tf.stack(m.gen_distrib1, axis=0)
-    outputs = {name: output[hparams.context_frames - 1:] for name, output in outputs.items()}
-    gen_images = outputs['gen_images'][hparams.context_frames - 1:]
-    return gen_images, outputs
+    return outputs
 
 
 class SNAVideoPredictionModel(VideoPredictionModel):
@@ -666,15 +665,3 @@ class SNAVideoPredictionModel(VideoPredictionModel):
             schedule_sampling_k=900.0,
         )
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
-
-    def parse_hparams(self, hparams_dict, hparams):
-        hparams = super(SNAVideoPredictionModel, self).parse_hparams(hparams_dict, hparams)
-        if self.mode == 'test':
-            def override_hparams_maybe(name, value):
-                orig_value = hparams.values()[name]
-                if orig_value != value:
-                    print('Overriding hparams from %s=%r to %r for mode=%s.' %
-                          (name, orig_value, value, self.mode))
-                    hparams.set_hparam(name, value)
-            override_hparams_maybe('schedule_sampling_k', -1)
-        return hparams
diff --git a/video_prediction/models/sv2p_model.py b/video_prediction/models/sv2p_model.py
index c2134e413cfd44c5db181feb8c240bd6283725cc..e7a06364178dc380456e47569787ab693ad121de 100644
--- a/video_prediction/models/sv2p_model.py
+++ b/video_prediction/models/sv2p_model.py
@@ -585,32 +585,34 @@ def generator_fn(inputs, mode, hparams):
                         stp=hparams.transformation == 'stp',
                         context_frames=hparams.context_frames,
                         hparams=hparams)
-
-    outputs_enc = encoder_fn(inputs, hparams)
-    tf.get_variable_scope().reuse_variables()
-    gen_images_enc, gen_states_enc = \
-        construct_model(images,
-                        actions,
-                        states,
-                        outputs_enc=outputs_enc,
-                        iter_num=iter_num,
-                        k=schedule_sampling_k,
-                        use_state='actions' in inputs,
-                        num_masks=hparams.num_masks,
-                        cdna=hparams.transformation == 'cdna',
-                        dna=hparams.transformation == 'dna',
-                        stp=hparams.transformation == 'stp',
-                        context_frames=hparams.context_frames,
-                        hparams=hparams)
-
     outputs = {
         'gen_images': tf.stack(gen_images, axis=0),
         'gen_states': tf.stack(gen_states, axis=0),
-        'gen_images_enc': tf.stack(gen_images_enc, axis=0),
-        'gen_states_enc': tf.stack(gen_states_enc, axis=0),
-        'zs_mu_enc': outputs_enc['zs_mu_enc'],
-        'zs_log_sigma_sq_enc': outputs_enc['zs_log_sigma_sq_enc'],
     }
+
+    if mode == 'train':
+        outputs_enc = encoder_fn(inputs, hparams)
+        tf.get_variable_scope().reuse_variables()
+        gen_images_enc, gen_states_enc = \
+            construct_model(images,
+                            actions,
+                            states,
+                            outputs_enc=outputs_enc,
+                            iter_num=iter_num,
+                            k=schedule_sampling_k,
+                            use_state='actions' in inputs,
+                            num_masks=hparams.num_masks,
+                            cdna=hparams.transformation == 'cdna',
+                            dna=hparams.transformation == 'dna',
+                            stp=hparams.transformation == 'stp',
+                            context_frames=hparams.context_frames,
+                            hparams=hparams)
+        outputs.update({
+            'gen_images_enc': tf.stack(gen_images_enc, axis=0),
+            'gen_states_enc': tf.stack(gen_states_enc, axis=0),
+            'zs_mu_enc': outputs_enc['zs_mu_enc'],
+            'zs_log_sigma_sq_enc': outputs_enc['zs_log_sigma_sq_enc'],
+        })
     return outputs
 
 
@@ -655,3 +657,22 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
         # Based on Figure 4 and the Appendix, it seems that in the 3rd stage, the kl_weight is
         # linearly increased for the first 20k iterations of this stage.
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def parse_hparams(self, hparams_dict, hparams):
+        # backwards compatibility
+        deprecated_hparams_keys = [
+            'num_gpus',
+            'acvideo_gan_weight',
+            'acvideo_vae_gan_weight',
+            'image_gan_weight',
+            'image_vae_gan_weight',
+            'tuple_gan_weight',
+            'tuple_vae_gan_weight',
+            'gan_weight',
+            'vae_gan_weight',
+            'video_gan_weight',
+            'video_vae_gan_weight',
+        ]
+        for deprecated_hparams_key in deprecated_hparams_keys:
+            hparams_dict.pop(deprecated_hparams_key, None)
+        return super(SV2PVideoPredictionModel, self).parse_hparams(hparams_dict, hparams)