Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
AMBS
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
AMBS
Commits
dd78d680
Commit
dd78d680
authored
4 years ago
by
gong1
Browse files
Options
Downloads
Patches
Plain Diff
update vae model
parent
b4fb67da
No related branches found
No related tags found
No related merge requests found
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
video_prediction_savp/scripts/train_dummy.py
+7
-2
7 additions, 2 deletions
video_prediction_savp/scripts/train_dummy.py
video_prediction_savp/video_prediction/models/vanilla_vae_model.py
+8
-9
8 additions, 9 deletions
...diction_savp/video_prediction/models/vanilla_vae_model.py
with
15 additions
and
11 deletions
video_prediction_savp/scripts/train_dummy.py
+
7
−
2
View file @
dd78d680
...
@@ -287,7 +287,7 @@ def main():
...
@@ -287,7 +287,7 @@ def main():
#fetches["latent_loss"] = model.latent_loss
#fetches["latent_loss"] = model.latent_loss
fetches
[
"
summary
"
]
=
model
.
summary_op
fetches
[
"
summary
"
]
=
model
.
summary_op
if
model
.
__class__
.
__name__
==
"
McNetVideoPredictionModel
"
or
model
.
__class__
.
__name__
==
"
VanillaConvLstmVideoPredictionModel
"
:
if
model
.
__class__
.
__name__
==
"
McNetVideoPredictionModel
"
or
model
.
__class__
.
__name__
==
"
VanillaConvLstmVideoPredictionModel
"
or
model
.
__class__
.
__name__
==
"
VanillaVAEVideoPredictionModel
"
:
fetches
[
"
global_step
"
]
=
model
.
global_step
fetches
[
"
global_step
"
]
=
model
.
global_step
fetches
[
"
total_loss
"
]
=
model
.
total_loss
fetches
[
"
total_loss
"
]
=
model
.
total_loss
#fetch the specific loss function only for mcnet
#fetch the specific loss function only for mcnet
...
@@ -295,6 +295,9 @@ def main():
...
@@ -295,6 +295,9 @@ def main():
fetches
[
"
L_p
"
]
=
model
.
L_p
fetches
[
"
L_p
"
]
=
model
.
L_p
fetches
[
"
L_gdl
"
]
=
model
.
L_gdl
fetches
[
"
L_gdl
"
]
=
model
.
L_gdl
fetches
[
"
L_GAN
"
]
=
model
.
L_GAN
fetches
[
"
L_GAN
"
]
=
model
.
L_GAN
if
model
.
__class__
.
__name__
==
"
VanillaVAEVideoPredictionModel
"
:
fetches
[
"
latent_loss
"
]
=
model
.
latent_loss
fetches
[
"
recon_loss
"
]
=
model
.
recon_loss
results
=
sess
.
run
(
fetches
)
results
=
sess
.
run
(
fetches
)
train_losses
.
append
(
results
[
"
total_loss
"
])
train_losses
.
append
(
results
[
"
total_loss
"
])
#Fetch losses for validation data
#Fetch losses for validation data
...
@@ -333,6 +336,8 @@ def main():
...
@@ -333,6 +336,8 @@ def main():
print
(
"
Total_loss:{}
"
.
format
(
results
[
"
total_loss
"
]))
print
(
"
Total_loss:{}
"
.
format
(
results
[
"
total_loss
"
]))
elif
model
.
__class__
.
__name__
==
"
SAVPVideoPredictionModel
"
:
elif
model
.
__class__
.
__name__
==
"
SAVPVideoPredictionModel
"
:
print
(
"
Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}
"
.
format
(
results
[
"
g_losses
"
],
results
[
"
d_losses
"
],
results
[
"
g_loss
"
],
results
[
"
d_loss
"
]))
print
(
"
Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}
"
.
format
(
results
[
"
g_losses
"
],
results
[
"
d_losses
"
],
results
[
"
g_loss
"
],
results
[
"
d_loss
"
]))
elif
model
.
__class__
.
__name__
==
"
VanillaVAEVideoPredictionModel
"
:
print
(
"
Total_loss:{}; latent_losses:{}; reconst_loss:{}
"
.
format
(
results
[
"
total_loss
"
],
results
[
"
latent_loss
"
],
results
[
"
recon_loss
"
]))
else
:
else
:
print
(
"
The model name does not exist
"
)
print
(
"
The model name does not exist
"
)
...
...
This diff is collapsed.
Click to expand it.
video_prediction_savp/video_prediction/models/vanilla_vae_model.py
+
8
−
9
View file @
dd78d680
...
@@ -21,6 +21,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
...
@@ -21,6 +21,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
super
(
VanillaVAEVideoPredictionModel
,
self
).
__init__
(
mode
,
hparams_dict
,
hparams
,
**
kwargs
)
super
(
VanillaVAEVideoPredictionModel
,
self
).
__init__
(
mode
,
hparams_dict
,
hparams
,
**
kwargs
)
self
.
mode
=
mode
self
.
mode
=
mode
self
.
learning_rate
=
self
.
hparams
.
lr
self
.
learning_rate
=
self
.
hparams
.
lr
self
.
weight_recon
=
self
.
hparams
.
weight_recon
self
.
nz
=
self
.
hparams
.
nz
self
.
nz
=
self
.
hparams
.
nz
self
.
aggregate_nccl
=
aggregate_nccl
self
.
aggregate_nccl
=
aggregate_nccl
self
.
gen_images_enc
=
None
self
.
gen_images_enc
=
None
...
@@ -30,6 +31,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
...
@@ -30,6 +31,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
self
.
latent_loss
=
None
self
.
latent_loss
=
None
self
.
total_loss
=
None
self
.
total_loss
=
None
def
get_default_hparams_dict
(
self
):
def
get_default_hparams_dict
(
self
):
"""
"""
The keys of this dict define valid hyperparameters for instances of
The keys of this dict define valid hyperparameters for instances of
...
@@ -67,11 +69,12 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
...
@@ -67,11 +69,12 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
nz
=
10
,
nz
=
10
,
context_frames
=-
1
,
context_frames
=-
1
,
sequence_length
=-
1
,
sequence_length
=-
1
,
clip_length
=
10
,
#Bing: TODO What is the clip_length, original is 10,
weight_recon
=
0.4
)
)
return
dict
(
itertools
.
chain
(
default_hparams
.
items
(),
hparams
.
items
()))
return
dict
(
itertools
.
chain
(
default_hparams
.
items
(),
hparams
.
items
()))
def
build_graph
(
self
,
x
)
def
build_graph
(
self
,
x
)
:
self
.
x
=
x
[
"
images
"
]
self
.
x
=
x
[
"
images
"
]
#self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
#self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
self
.
global_step
=
tf
.
train
.
get_or_create_global_step
()
self
.
global_step
=
tf
.
train
.
get_or_create_global_step
()
...
@@ -82,7 +85,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
...
@@ -82,7 +85,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
1
+
self
.
z_log_sigma_sq
-
tf
.
square
(
self
.
z_mu
)
-
1
+
self
.
z_log_sigma_sq
-
tf
.
square
(
self
.
z_mu
)
-
tf
.
exp
(
self
.
z_log_sigma_sq
),
axis
=
1
)
tf
.
exp
(
self
.
z_log_sigma_sq
),
axis
=
1
)
self
.
latent_loss
=
tf
.
reduce_mean
(
latent_loss
)
self
.
latent_loss
=
tf
.
reduce_mean
(
latent_loss
)
self
.
total_loss
=
self
.
recon_loss
+
self
.
latent_loss
self
.
total_loss
=
self
.
weight_recon
*
self
.
recon_loss
+
self
.
latent_loss
self
.
train_op
=
tf
.
train
.
AdamOptimizer
(
self
.
train_op
=
tf
.
train
.
AdamOptimizer
(
learning_rate
=
self
.
learning_rate
).
minimize
(
self
.
total_loss
,
global_step
=
self
.
global_step
)
learning_rate
=
self
.
learning_rate
).
minimize
(
self
.
total_loss
,
global_step
=
self
.
global_step
)
# Build a saver
# Build a saver
...
@@ -109,7 +112,6 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
...
@@ -109,7 +112,6 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
@staticmethod
@staticmethod
def
vae_arc3
(
x
,
l_name
=
0
,
nz
=
16
):
def
vae_arc3
(
x
,
l_name
=
0
,
nz
=
16
):
seq_name
=
"
sq_
"
+
str
(
l_name
)
+
"
_
"
seq_name
=
"
sq_
"
+
str
(
l_name
)
+
"
_
"
conv1
=
ld
.
conv_layer
(
x
,
3
,
2
,
8
,
seq_name
+
"
encode_1
"
)
conv1
=
ld
.
conv_layer
(
x
,
3
,
2
,
8
,
seq_name
+
"
encode_1
"
)
conv2
=
ld
.
conv_layer
(
conv1
,
3
,
1
,
8
,
seq_name
+
"
encode_2
"
)
conv2
=
ld
.
conv_layer
(
conv1
,
3
,
1
,
8
,
seq_name
+
"
encode_2
"
)
conv3
=
ld
.
conv_layer
(
conv2
,
3
,
2
,
8
,
seq_name
+
"
encode_3
"
)
conv3
=
ld
.
conv_layer
(
conv2
,
3
,
2
,
8
,
seq_name
+
"
encode_3
"
)
...
@@ -121,12 +123,9 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
...
@@ -121,12 +123,9 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
z
=
z_mu
+
tf
.
sqrt
(
tf
.
exp
(
z_log_sigma_sq
))
*
eps
z
=
z_mu
+
tf
.
sqrt
(
tf
.
exp
(
z_log_sigma_sq
))
*
eps
z2
=
ld
.
fc_layer
(
z
,
hiddens
=
conv3_shape
[
1
]
*
conv3_shape
[
2
]
*
conv3_shape
[
3
],
idx
=
seq_name
+
"
decode_fc1
"
)
z2
=
ld
.
fc_layer
(
z
,
hiddens
=
conv3_shape
[
1
]
*
conv3_shape
[
2
]
*
conv3_shape
[
3
],
idx
=
seq_name
+
"
decode_fc1
"
)
z3
=
tf
.
reshape
(
z2
,
[
-
1
,
conv3_shape
[
1
],
conv3_shape
[
2
],
conv3_shape
[
3
]])
z3
=
tf
.
reshape
(
z2
,
[
-
1
,
conv3_shape
[
1
],
conv3_shape
[
2
],
conv3_shape
[
3
]])
conv5
=
ld
.
transpose_conv_layer
(
z3
,
3
,
2
,
8
,
conv5
=
ld
.
transpose_conv_layer
(
z3
,
3
,
2
,
8
,
seq_name
+
"
decode_5
"
)
seq_name
+
"
decode_5
"
)
conv6
=
ld
.
transpose_conv_layer
(
conv5
,
3
,
1
,
8
,
seq_name
+
"
decode_6
"
)
conv6
=
ld
.
transpose_conv_layer
(
conv5
,
3
,
1
,
8
,
seq_name
+
"
decode_6
"
)
x_hat
=
ld
.
transpose_conv_layer
(
conv6
,
3
,
2
,
3
,
seq_name
+
"
decode_8
"
)
x_hat
=
ld
.
transpose_conv_layer
(
conv6
,
3
,
2
,
3
,
seq_name
+
"
decode_8
"
)
return
x_hat
,
z_mu
,
z_log_sigma_sq
,
z
return
x_hat
,
z_mu
,
z_log_sigma_sq
,
z
def
vae_arc_all
(
self
):
def
vae_arc_all
(
self
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment