Commit 1c5658d9 authored by Michael Langguth

Adopt handling of sequence_length in vanilla_convLSTM_model.py

parent 1de49cd4
Pipeline #68245 passed
@@ -22,7 +22,7 @@ from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
 from tensorflow.contrib.training import HParams
 
 class VanillaConvLstmVideoPredictionModel(object):
-    def __init__(self, mode='train', hparams_dict=None):
+    def __init__(self, sequence_length, mode='train', hparams_dict=None):
         """
         This is a class for building the convLSTM architecture using updated hyperparameters
         args:
@@ -36,7 +36,8 @@ class VanillaConvLstmVideoPredictionModel(object):
         self.total_loss = None
         self.context_frames = self.hparams.context_frames
         self.sequence_length = self.hparams.sequence_length
-        self.predict_frames = self.sequence_length - self.context_frames
+        self.predict_frames = VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames(self.sequence_length,
+                                                                                            self.context_frames)
         self.max_epochs = self.hparams.max_epochs
         self.loss_fun = self.hparams.loss_fun
@@ -112,6 +113,26 @@ class VanillaConvLstmVideoPredictionModel(object):
         self.is_build_graph = True
         return self.is_build_graph
 
+    def convLSTM_network(self):
+        network_template = tf.make_template('network',
+                                             VanillaConvLstmVideoPredictionModel.convLSTM_cell)  # make the template to share the variables
+        # create network
+        x_hat = []
+        # This is for training (optimization of the convLSTM layer)
+        hidden_g = None
+        for i in range(self.sequence_length - 1):
+            if i < self.context_frames:
+                x_1_g, hidden_g = network_template(self.x[:, i, :, :, :], hidden_g)
+            else:
+                x_1_g, hidden_g = network_template(x_1_g, hidden_g)
+            x_hat.append(x_1_g)
+        # pack them all together
+        x_hat = tf.stack(x_hat)
+        self.x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4])  # swap the first and second dimensions
+        self.x_hat_predict_frames = self.x_hat[:, self.context_frames - 1:, :, :, :]
+
     @staticmethod
     def convLSTM_cell(inputs, hidden):
         y_0 = inputs  # we only used patch 1; the original paper uses patch 4 for the Moving MNIST case and patch 2 for the Radar Echo dataset
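To illustrate the scheduling used in the relocated convLSTM_network method above (ground-truth frames during the context window, the model's own output afterwards), here is a minimal, standalone NumPy sketch with a dummy cell standing in for the ConvLSTM template; all names, shapes and values are illustrative only:

import numpy as np

def dummy_cell(frame, hidden):
    # Stand-in for the ConvLSTM template: blends the input with the hidden state
    # and returns (predicted next frame, new hidden state).
    hidden = frame if hidden is None else 0.5 * (hidden + frame)
    return hidden, hidden

sequence_length, context_frames = 6, 3
x = np.random.rand(sequence_length, 4, 4, 1)   # one sample: [time, height, width, channels]

x_hat, hidden = [], None
for i in range(sequence_length - 1):
    if i < context_frames:
        # warm-up phase: feed ground-truth frames
        pred, hidden = dummy_cell(x[i], hidden)
    else:
        # prediction phase: feed back the model's own output
        pred, hidden = dummy_cell(pred, hidden)
    x_hat.append(pred)

x_hat = np.stack(x_hat)                        # [time-1, height, width, channels]
predicted_frames = x_hat[context_frames - 1:]  # frames corresponding to the prediction horizon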
@@ -130,23 +151,23 @@ class VanillaConvLstmVideoPredictionModel(object):
         x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid")
         return x_hat, hidden
 
-    def convLSTM_network(self):
-        network_template = tf.make_template('network',
-                                             VanillaConvLstmVideoPredictionModel.convLSTM_cell)  # make the template to share the variables
-        # create network
-        x_hat = []
-        # This is for training (optimization of the convLSTM layer)
-        hidden_g = None
-        for i in range(self.sequence_length - 1):
-            if i < self.context_frames:
-                x_1_g, hidden_g = network_template(self.x[:, i, :, :, :], hidden_g)
-            else:
-                x_1_g, hidden_g = network_template(x_1_g, hidden_g)
-            x_hat.append(x_1_g)
-        # pack them all together
-        x_hat = tf.stack(x_hat)
-        self.x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4])  # swap the first and second dimensions
-        self.x_hat_predict_frames = self.x_hat[:, self.context_frames - 1:, :, :, :]
+    @staticmethod
+    def set_and_check_pred_frames(seq_length, context_frames):
+        """
+        Checks if sequence length and context_frames are set properly and returns the number of frames to be predicted.
+        :param seq_length: number of frames/images per sequence
+        :param context_frames: number of context frames/images
+        :return: number of predicted frames
+        """
+        method = VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames.__name__
+
+        # sanity checks
+        assert isinstance(seq_length, int), "%{0}: Sequence length (seq_length) must be an integer".format(method)
+        assert isinstance(context_frames, int), "%{0}: Number of context frames must be an integer".format(method)
+
+        if seq_length > context_frames:
+            return seq_length - context_frames
+        else:
+            raise ValueError("%{0}: Sequence length ({1}) must be larger than context frames ({2})."
+                             .format(method, seq_length, context_frames))
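For reference, a small usage sketch of the new helper; the import path is only assumed from the package prefix visible in the hunk headers above and may need to be adjusted to the actual module location:

# Assumed import path (inferred from the package layout); adjust if the module lives elsewhere.
from model_modules.video_prediction.vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel

# Valid configuration: 20 frames per sequence, 10 of them used as context
print(VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames(20, 10))   # -> 10

# Misconfiguration: the context covers the whole sequence -> ValueError
try:
    VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames(10, 10)
except ValueError as err:
    print(err)

# Wrong type: sequence length passed as float -> AssertionError
try:
    VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames(20.0, 10)
except AssertionError as err:
    print(err)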