Skip to content
Snippets Groups Projects
Commit ebd530ee authored by Alex Lee's avatar Alex Lee
Browse files

Fix sv2p to work with new abstractions.

parent 4ad00192
Branches
Tags
No related merge requests found
...@@ -191,10 +191,10 @@ def construct_latent_tower(images, hparams): ...@@ -191,10 +191,10 @@ def construct_latent_tower(images, hparams):
return latent_mean, latent_std return latent_mean, latent_std
def encoder_fn(inputs, hparams=None): def encoder_fn(inputs, hparams):
images = tf.unstack(inputs['images'], axis=0) images = tf.unstack(inputs['images'], axis=0)
latent_mean, latent_std = construct_latent_tower(images, hparams) latent_mean, latent_std = construct_latent_tower(images, hparams)
outputs = {'enc_zs_mu': latent_mean, 'enc_zs_log_sigma_sq': latent_std} outputs = {'zs_mu_enc': latent_mean, 'zs_log_sigma_sq_enc': latent_std}
return outputs return outputs
...@@ -269,7 +269,7 @@ def construct_model(images, ...@@ -269,7 +269,7 @@ def construct_model(images,
if outputs_enc is None: # equivalent to inference_time if outputs_enc is None: # equivalent to inference_time
latent_mean, latent_std = None, None latent_mean, latent_std = None, None
else: else:
latent_mean, latent_std = outputs_enc['enc_zs_mu'], outputs_enc['enc_zs_log_sigma_sq'] latent_mean, latent_std = outputs_enc['zs_mu_enc'], outputs_enc['zs_log_sigma_sq_enc']
assert latent_mean.shape.as_list() == latent_shape assert latent_mean.shape.as_list() == latent_shape
if hparams.multi_latent: if hparams.multi_latent:
...@@ -558,10 +558,11 @@ def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth): ...@@ -558,10 +558,11 @@ def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
[ground_truth_examps, generated_examps]) [ground_truth_examps, generated_examps])
def generator_fn(inputs, outputs_enc=None, hparams=None): def generator_fn(inputs, mode, hparams):
images = tf.unstack(inputs['images'], axis=0) images = tf.unstack(inputs['images'], axis=0)
batch_size = images[0].shape[0].value batch_size = images[0].shape[0].value
action_dim, state_dim = 4, 3 action_dim, state_dim = 4, 3
# if not use_state, use zero actions and states to match reference implementation. # if not use_state, use zero actions and states to match reference implementation.
actions = inputs.get('actions', tf.zeros([hparams.sequence_length - 1, batch_size, action_dim])) actions = inputs.get('actions', tf.zeros([hparams.sequence_length - 1, batch_size, action_dim]))
actions = tf.unstack(actions, axis=0) actions = tf.unstack(actions, axis=0)
...@@ -569,13 +570,31 @@ def generator_fn(inputs, outputs_enc=None, hparams=None): ...@@ -569,13 +570,31 @@ def generator_fn(inputs, outputs_enc=None, hparams=None):
states = tf.unstack(states, axis=0) states = tf.unstack(states, axis=0)
iter_num = tf.to_float(tf.train.get_or_create_global_step()) iter_num = tf.to_float(tf.train.get_or_create_global_step())
schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1
gen_images, gen_states = \ gen_images, gen_states = \
construct_model(images,
actions,
states,
outputs_enc=None,
iter_num=iter_num,
k=schedule_sampling_k,
use_state='actions' in inputs,
num_masks=hparams.num_masks,
cdna=hparams.transformation == 'cdna',
dna=hparams.transformation == 'dna',
stp=hparams.transformation == 'stp',
context_frames=hparams.context_frames,
hparams=hparams)
outputs_enc = encoder_fn(inputs, hparams)
tf.get_variable_scope().reuse_variables()
gen_images_enc, gen_states_enc = \
construct_model(images, construct_model(images,
actions, actions,
states, states,
outputs_enc=outputs_enc, outputs_enc=outputs_enc,
iter_num=iter_num, iter_num=iter_num,
k=hparams.schedule_sampling_k, k=schedule_sampling_k,
use_state='actions' in inputs, use_state='actions' in inputs,
num_masks=hparams.num_masks, num_masks=hparams.num_masks,
cdna=hparams.transformation == 'cdna', cdna=hparams.transformation == 'cdna',
...@@ -583,13 +602,16 @@ def generator_fn(inputs, outputs_enc=None, hparams=None): ...@@ -583,13 +602,16 @@ def generator_fn(inputs, outputs_enc=None, hparams=None):
stp=hparams.transformation == 'stp', stp=hparams.transformation == 'stp',
context_frames=hparams.context_frames, context_frames=hparams.context_frames,
hparams=hparams) hparams=hparams)
outputs = { outputs = {
'gen_images': tf.stack(gen_images, axis=0), 'gen_images': tf.stack(gen_images, axis=0),
'gen_states': tf.stack(gen_states, axis=0), 'gen_states': tf.stack(gen_states, axis=0),
'gen_images_enc': tf.stack(gen_images_enc, axis=0),
'gen_states_enc': tf.stack(gen_states_enc, axis=0),
'zs_mu_enc': outputs_enc['zs_mu_enc'],
'zs_log_sigma_sq_enc': outputs_enc['zs_log_sigma_sq_enc'],
} }
outputs = {name: output[hparams.context_frames - 1:] for name, output in outputs.items()} return outputs
gen_images = outputs['gen_images']
return gen_images, outputs
class SV2PVideoPredictionModel(VideoPredictionModel): class SV2PVideoPredictionModel(VideoPredictionModel):
...@@ -602,9 +624,7 @@ class SV2PVideoPredictionModel(VideoPredictionModel): ...@@ -602,9 +624,7 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(SV2PVideoPredictionModel, self).__init__( super(SV2PVideoPredictionModel, self).__init__(
generator_fn, encoder_fn=encoder_fn, *args, ** kwargs) generator_fn, *args, ** kwargs)
if self.hparams.schedule_sampling_k == -1:
self.encoder_fn = None
self.deterministic = not self.hparams.stochastic_model self.deterministic = not self.hparams.stochastic_model
def get_default_hparams_dict(self): def get_default_hparams_dict(self):
...@@ -634,15 +654,3 @@ class SV2PVideoPredictionModel(VideoPredictionModel): ...@@ -634,15 +654,3 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
# Based on Figure 4 and the Appendix, it seems that in the 3rd stage, the kl_weight is # Based on Figure 4 and the Appendix, it seems that in the 3rd stage, the kl_weight is
# linearly increased for the first 20k iterations of this stage. # linearly increased for the first 20k iterations of this stage.
return dict(itertools.chain(default_hparams.items(), hparams.items())) return dict(itertools.chain(default_hparams.items(), hparams.items()))
def parse_hparams(self, hparams_dict, hparams):
hparams = super(SV2PVideoPredictionModel, self).parse_hparams(hparams_dict, hparams)
if self.mode == 'test':
def override_hparams_maybe(name, value):
orig_value = hparams.values()[name]
if orig_value != value:
print('Overriding hparams from %s=%r to %r for mode=%s.' %
(name, orig_value, value, self.mode))
hparams.set_hparam(name, value)
override_hparams_maybe('schedule_sampling_k', -1)
return hparams
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment