Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
AMBS
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
AMBS
Commits
ebd530ee
Commit
ebd530ee
authored
6 years ago
by
Alex Lee
Browse files
Options
Downloads
Patches
Plain Diff
Fix sv2p to work with new abstractions.
parent
4ad00192
Branches
Branches containing commit
Tags
Tags containing commit
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
video_prediction/models/sv2p_model.py
+31
-23
31 additions, 23 deletions
video_prediction/models/sv2p_model.py
with
31 additions
and
23 deletions
video_prediction/models/sv2p_model.py
+
31
−
23
View file @
ebd530ee
...
@@ -191,10 +191,10 @@ def construct_latent_tower(images, hparams):
...
@@ -191,10 +191,10 @@ def construct_latent_tower(images, hparams):
return
latent_mean
,
latent_std
return
latent_mean
,
latent_std
def
encoder_fn
(
inputs
,
hparams
=
None
):
def
encoder_fn
(
inputs
,
hparams
):
images
=
tf
.
unstack
(
inputs
[
'
images
'
],
axis
=
0
)
images
=
tf
.
unstack
(
inputs
[
'
images
'
],
axis
=
0
)
latent_mean
,
latent_std
=
construct_latent_tower
(
images
,
hparams
)
latent_mean
,
latent_std
=
construct_latent_tower
(
images
,
hparams
)
outputs
=
{
'
enc_
zs_mu
'
:
latent_mean
,
'
enc_
zs_log_sigma_sq
'
:
latent_std
}
outputs
=
{
'
zs_mu
_enc
'
:
latent_mean
,
'
zs_log_sigma_sq
_enc
'
:
latent_std
}
return
outputs
return
outputs
...
@@ -269,7 +269,7 @@ def construct_model(images,
...
@@ -269,7 +269,7 @@ def construct_model(images,
if
outputs_enc
is
None
:
# equivalent to inference_time
if
outputs_enc
is
None
:
# equivalent to inference_time
latent_mean
,
latent_std
=
None
,
None
latent_mean
,
latent_std
=
None
,
None
else
:
else
:
latent_mean
,
latent_std
=
outputs_enc
[
'
enc_
zs_mu
'
],
outputs_enc
[
'
enc_
zs_log_sigma_sq
'
]
latent_mean
,
latent_std
=
outputs_enc
[
'
zs_mu
_enc
'
],
outputs_enc
[
'
zs_log_sigma_sq
_enc
'
]
assert
latent_mean
.
shape
.
as_list
()
==
latent_shape
assert
latent_mean
.
shape
.
as_list
()
==
latent_shape
if
hparams
.
multi_latent
:
if
hparams
.
multi_latent
:
...
@@ -558,10 +558,11 @@ def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
...
@@ -558,10 +558,11 @@ def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
[
ground_truth_examps
,
generated_examps
])
[
ground_truth_examps
,
generated_examps
])
def
generator_fn
(
inputs
,
outputs_enc
=
Non
e
,
hparams
=
None
):
def
generator_fn
(
inputs
,
mod
e
,
hparams
):
images
=
tf
.
unstack
(
inputs
[
'
images
'
],
axis
=
0
)
images
=
tf
.
unstack
(
inputs
[
'
images
'
],
axis
=
0
)
batch_size
=
images
[
0
].
shape
[
0
].
value
batch_size
=
images
[
0
].
shape
[
0
].
value
action_dim
,
state_dim
=
4
,
3
action_dim
,
state_dim
=
4
,
3
# if not use_state, use zero actions and states to match reference implementation.
# if not use_state, use zero actions and states to match reference implementation.
actions
=
inputs
.
get
(
'
actions
'
,
tf
.
zeros
([
hparams
.
sequence_length
-
1
,
batch_size
,
action_dim
]))
actions
=
inputs
.
get
(
'
actions
'
,
tf
.
zeros
([
hparams
.
sequence_length
-
1
,
batch_size
,
action_dim
]))
actions
=
tf
.
unstack
(
actions
,
axis
=
0
)
actions
=
tf
.
unstack
(
actions
,
axis
=
0
)
...
@@ -569,13 +570,31 @@ def generator_fn(inputs, outputs_enc=None, hparams=None):
...
@@ -569,13 +570,31 @@ def generator_fn(inputs, outputs_enc=None, hparams=None):
states
=
tf
.
unstack
(
states
,
axis
=
0
)
states
=
tf
.
unstack
(
states
,
axis
=
0
)
iter_num
=
tf
.
to_float
(
tf
.
train
.
get_or_create_global_step
())
iter_num
=
tf
.
to_float
(
tf
.
train
.
get_or_create_global_step
())
schedule_sampling_k
=
hparams
.
schedule_sampling_k
if
mode
==
'
train
'
else
-
1
gen_images
,
gen_states
=
\
gen_images
,
gen_states
=
\
construct_model
(
images
,
actions
,
states
,
outputs_enc
=
None
,
iter_num
=
iter_num
,
k
=
schedule_sampling_k
,
use_state
=
'
actions
'
in
inputs
,
num_masks
=
hparams
.
num_masks
,
cdna
=
hparams
.
transformation
==
'
cdna
'
,
dna
=
hparams
.
transformation
==
'
dna
'
,
stp
=
hparams
.
transformation
==
'
stp
'
,
context_frames
=
hparams
.
context_frames
,
hparams
=
hparams
)
outputs_enc
=
encoder_fn
(
inputs
,
hparams
)
tf
.
get_variable_scope
().
reuse_variables
()
gen_images_enc
,
gen_states_enc
=
\
construct_model
(
images
,
construct_model
(
images
,
actions
,
actions
,
states
,
states
,
outputs_enc
=
outputs_enc
,
outputs_enc
=
outputs_enc
,
iter_num
=
iter_num
,
iter_num
=
iter_num
,
k
=
hparams
.
schedule_sampling_k
,
k
=
schedule_sampling_k
,
use_state
=
'
actions
'
in
inputs
,
use_state
=
'
actions
'
in
inputs
,
num_masks
=
hparams
.
num_masks
,
num_masks
=
hparams
.
num_masks
,
cdna
=
hparams
.
transformation
==
'
cdna
'
,
cdna
=
hparams
.
transformation
==
'
cdna
'
,
...
@@ -583,13 +602,16 @@ def generator_fn(inputs, outputs_enc=None, hparams=None):
...
@@ -583,13 +602,16 @@ def generator_fn(inputs, outputs_enc=None, hparams=None):
stp
=
hparams
.
transformation
==
'
stp
'
,
stp
=
hparams
.
transformation
==
'
stp
'
,
context_frames
=
hparams
.
context_frames
,
context_frames
=
hparams
.
context_frames
,
hparams
=
hparams
)
hparams
=
hparams
)
outputs
=
{
outputs
=
{
'
gen_images
'
:
tf
.
stack
(
gen_images
,
axis
=
0
),
'
gen_images
'
:
tf
.
stack
(
gen_images
,
axis
=
0
),
'
gen_states
'
:
tf
.
stack
(
gen_states
,
axis
=
0
),
'
gen_states
'
:
tf
.
stack
(
gen_states
,
axis
=
0
),
'
gen_images_enc
'
:
tf
.
stack
(
gen_images_enc
,
axis
=
0
),
'
gen_states_enc
'
:
tf
.
stack
(
gen_states_enc
,
axis
=
0
),
'
zs_mu_enc
'
:
outputs_enc
[
'
zs_mu_enc
'
],
'
zs_log_sigma_sq_enc
'
:
outputs_enc
[
'
zs_log_sigma_sq_enc
'
],
}
}
outputs
=
{
name
:
output
[
hparams
.
context_frames
-
1
:]
for
name
,
output
in
outputs
.
items
()}
return
outputs
gen_images
=
outputs
[
'
gen_images
'
]
return
gen_images
,
outputs
class
SV2PVideoPredictionModel
(
VideoPredictionModel
):
class
SV2PVideoPredictionModel
(
VideoPredictionModel
):
...
@@ -602,9 +624,7 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
...
@@ -602,9 +624,7 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
"""
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
SV2PVideoPredictionModel
,
self
).
__init__
(
super
(
SV2PVideoPredictionModel
,
self
).
__init__
(
generator_fn
,
encoder_fn
=
encoder_fn
,
*
args
,
**
kwargs
)
generator_fn
,
*
args
,
**
kwargs
)
if
self
.
hparams
.
schedule_sampling_k
==
-
1
:
self
.
encoder_fn
=
None
self
.
deterministic
=
not
self
.
hparams
.
stochastic_model
self
.
deterministic
=
not
self
.
hparams
.
stochastic_model
def
get_default_hparams_dict
(
self
):
def
get_default_hparams_dict
(
self
):
...
@@ -634,15 +654,3 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
...
@@ -634,15 +654,3 @@ class SV2PVideoPredictionModel(VideoPredictionModel):
# Based on Figure 4 and the Appendix, it seems that in the 3rd stage, the kl_weight is
# Based on Figure 4 and the Appendix, it seems that in the 3rd stage, the kl_weight is
# linearly increased for the first 20k iterations of this stage.
# linearly increased for the first 20k iterations of this stage.
return
dict
(
itertools
.
chain
(
default_hparams
.
items
(),
hparams
.
items
()))
return
dict
(
itertools
.
chain
(
default_hparams
.
items
(),
hparams
.
items
()))
def
parse_hparams
(
self
,
hparams_dict
,
hparams
):
hparams
=
super
(
SV2PVideoPredictionModel
,
self
).
parse_hparams
(
hparams_dict
,
hparams
)
if
self
.
mode
==
'
test
'
:
def
override_hparams_maybe
(
name
,
value
):
orig_value
=
hparams
.
values
()[
name
]
if
orig_value
!=
value
:
print
(
'
Overriding hparams from %s=%r to %r for mode=%s.
'
%
(
name
,
orig_value
,
value
,
self
.
mode
))
hparams
.
set_hparam
(
name
,
value
)
override_hparams_maybe
(
'
schedule_sampling_k
'
,
-
1
)
return
hparams
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment