Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
AMBS
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
AMBS
Commits
1c5658d9
Commit
1c5658d9
authored
4 years ago
by
Michael Langguth
Browse files
Options
Downloads
Patches
Plain Diff
Adopt handling of sequence_length in vanilla_convLSTM_model.py
.
parent
1de49cd4
No related branches found
No related tags found
No related merge requests found
Pipeline
#68245
passed
4 years ago
Stage: build
Stage: test
Stage: deploy
Changes
1
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
+46
-25
46 additions, 25 deletions
...modules/video_prediction/models/vanilla_convLSTM_model.py
with
46 additions
and
25 deletions
video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
+
46
−
25
View file @
1c5658d9
...
@@ -22,7 +22,7 @@ from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLST
...
@@ -22,7 +22,7 @@ from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLST
from
tensorflow.contrib.training
import
HParams
from
tensorflow.contrib.training
import
HParams
class
VanillaConvLstmVideoPredictionModel
(
object
):
class
VanillaConvLstmVideoPredictionModel
(
object
):
def
__init__
(
self
,
mode
=
'
train
'
,
hparams_dict
=
None
):
def
__init__
(
self
,
sequence_length
,
mode
=
'
train
'
,
hparams_dict
=
None
):
"""
"""
This is class for building convLSTM architecture by using updated hparameters
This is class for building convLSTM architecture by using updated hparameters
args:
args:
...
@@ -36,7 +36,8 @@ class VanillaConvLstmVideoPredictionModel(object):
...
@@ -36,7 +36,8 @@ class VanillaConvLstmVideoPredictionModel(object):
self
.
total_loss
=
None
self
.
total_loss
=
None
self
.
context_frames
=
self
.
hparams
.
context_frames
self
.
context_frames
=
self
.
hparams
.
context_frames
self
.
sequence_length
=
self
.
hparams
.
sequence_length
self
.
sequence_length
=
self
.
hparams
.
sequence_length
self
.
predict_frames
=
self
.
sequence_length
-
self
.
context_frames
self
.
predict_frames
=
VanillaConvLstmVideoPredictionModel
.
set_and_check_pred_frames
(
self
.
sequence_length
,
self
.
context_frames
)
self
.
max_epochs
=
self
.
hparams
.
max_epochs
self
.
max_epochs
=
self
.
hparams
.
max_epochs
self
.
loss_fun
=
self
.
hparams
.
loss_fun
self
.
loss_fun
=
self
.
hparams
.
loss_fun
...
@@ -112,6 +113,26 @@ class VanillaConvLstmVideoPredictionModel(object):
...
@@ -112,6 +113,26 @@ class VanillaConvLstmVideoPredictionModel(object):
self
.
is_build_graph
=
True
self
.
is_build_graph
=
True
return
self
.
is_build_graph
return
self
.
is_build_graph
def convLSTM_network(self):
    """
    Build the unrolled ConvLSTM prediction graph.

    Feeds ground-truth frames from self.x during the context phase and the
    model's own previous prediction afterwards (closed-loop rollout), then
    stores the stacked predictions in self.x_hat and the predicted-only part
    in self.x_hat_predict_frames.
    """
    # Template so every unrolled step shares the ConvLSTM cell variables.
    network_template = tf.make_template('network', VanillaConvLstmVideoPredictionModel.convLSTM_cell)

    frames = []           # one predicted frame per unrolled step
    hidden_state = None   # cell state; the cell initializes it on first call

    for step in range(self.sequence_length - 1):
        # Context phase: condition on ground truth. Prediction phase: feed
        # back the previous output.
        if step < self.context_frames:
            cell_input = self.x[:, step, :, :, :]
        else:
            cell_input = next_frame
        next_frame, hidden_state = network_template(cell_input, hidden_state)
        frames.append(next_frame)

    # Stack along a new leading (time) axis, then swap time and batch so the
    # result is batch-major like self.x.
    stacked = tf.stack(frames)
    self.x_hat = tf.transpose(stacked, [1, 0, 2, 3, 4])
    # Everything from the last context frame onwards is a model prediction.
    self.x_hat_predict_frames = self.x_hat[:, self.context_frames - 1:, :, :, :]
@staticmethod
@staticmethod
def
convLSTM_cell
(
inputs
,
hidden
):
def
convLSTM_cell
(
inputs
,
hidden
):
y_0
=
inputs
#we only used patch 1, but the original paper uses patch 4 for the moving MNIST case and patch 2 for the Radar Echo dataset
y_0
=
inputs
#we only used patch 1, but the original paper uses patch 4 for the moving MNIST case and patch 2 for the Radar Echo dataset
...
@@ -130,23 +151,23 @@ class VanillaConvLstmVideoPredictionModel(object):
...
@@ -130,23 +151,23 @@ class VanillaConvLstmVideoPredictionModel(object):
x_hat
=
ld
.
conv_layer
(
z3
,
1
,
1
,
channels
,
"
decode_1
"
,
activate
=
"
sigmoid
"
)
x_hat
=
ld
.
conv_layer
(
z3
,
1
,
1
,
channels
,
"
decode_1
"
,
activate
=
"
sigmoid
"
)
return
x_hat
,
hidden
return
x_hat
,
hidden
def
convLSTM_network
(
self
):
@staticmethod
network_template
=
tf
.
make_template
(
'
network
'
,
def
set_and_check_pred_frames
(
seq_length
,
context_frames
):
VanillaConvLstmVideoPredictionModel
.
convLSTM_cell
)
# make the template to share the variables
"""
# create network
Checks if sequence length and context_frames are set properly and returns number of frames to be predicted.
x_hat
=
[]
:param seq_length: number of frames/images per sequences
:param context_frames: number of context frames/images
:return: number of predicted frames
"""
#This is for training (optimization of convLSTM layer)
method
=
VanillaConvLstmVideoPredictionModel
.
set_and_check_pred_frames
.
__name__
hidden_g
=
None
for
i
in
range
(
self
.
sequence_length
-
1
):
if
i
<
self
.
context_frames
:
x_1_g
,
hidden_g
=
network_template
(
self
.
x
[:,
i
,
:,
:,
:],
hidden_g
)
else
:
x_1_g
,
hidden_g
=
network_template
(
x_1_g
,
hidden_g
)
x_hat
.
append
(
x_1_g
)
# pack them all together
# sanity checks
x_hat
=
tf
.
stack
(
x_hat
)
assert
isinstance
(
seq_length
,
int
),
"
%{0}: Sequence length (seq_length) must be an integer
"
.
format
(
method
)
self
.
x_hat
=
tf
.
transpose
(
x_hat
,
[
1
,
0
,
2
,
3
,
4
])
# change first dim with sec dim
assert
isinstance
(
context_frames
,
int
),
"
%{0}: Number of context frames must be an integer
"
.
format
(
method
)
self
.
x_hat_predict_frames
=
self
.
x_hat
[:,
self
.
context_frames
-
1
:,:,:,:]
if
seq_length
>
context_frames
:
return
seq_length
-
context_frames
else
:
raise
ValueError
(
"
%{0}: Sequence length ({1}) must be larger than context frames ({2}).
"
.
format
(
method
,
seq_length
,
context_frames
))
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment