diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
index e78f973b2aef760e1fccc24d296b31ee67f2e0c8..7c506cf452ad21263ab7ff7ea90f56291f6e7f17 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
@@ -22,7 +22,7 @@ from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLST
 from tensorflow.contrib.training import HParams
 
 class VanillaConvLstmVideoPredictionModel(object):
-    def __init__(self, mode='train', hparams_dict=None):
+    def __init__(self, sequence_length, mode='train', hparams_dict=None):
         """
         This is class for building convLSTM architecture by using updated hparameters
         args:
@@ -36,7 +36,8 @@ class VanillaConvLstmVideoPredictionModel(object):
         self.total_loss = None
         self.context_frames = self.hparams.context_frames
         self.sequence_length = self.hparams.sequence_length
-        self.predict_frames = self.sequence_length - self.context_frames
+        self.predict_frames = VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames(self.sequence_length,
+                                                                                            self.context_frames)
         self.max_epochs = self.hparams.max_epochs
         self.loss_fun = self.hparams.loss_fun
 
@@ -67,11 +68,11 @@ class VanillaConvLstmVideoPredictionModel(object):
         hparams = dict(
             context_frames=10,
             sequence_length=20,
-            max_epochs = 20,
-            batch_size = 40,
-            lr = 0.001,
-            loss_fun = "cross_entropy",
-            shuffle_on_val= True,
+            max_epochs=20,
+            batch_size=40,
+            lr=0.001,
+            loss_fun="cross_entropy",
+            shuffle_on_val=True,
         )
         return hparams
 
@@ -112,24 +113,6 @@ class VanillaConvLstmVideoPredictionModel(object):
         self.is_build_graph = True
         return self.is_build_graph
 
-    @staticmethod
-    def convLSTM_cell(inputs, hidden):
-        y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset
-        channels = inputs.get_shape()[-1]
-        # conv lstm cell
-        cell_shape = y_0.get_shape().as_list()
-        channels = cell_shape[-1]
-        with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)):
-            cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64)
-            if hidden is None:
-                hidden = cell.zero_state(y_0, tf.float32)
-            output, hidden = cell(y_0, hidden)
-        output_shape = output.get_shape().as_list()
-        z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])
-        #we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction
-        x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid")
-        return x_hat, hidden
-
     def convLSTM_network(self):
         network_template = tf.make_template('network', VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables
 
@@ -150,3 +133,41 @@ class VanillaConvLstmVideoPredictionModel(object):
         self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim
         self.x_hat_predict_frames = self.x_hat[:,self.context_frames-1:,:,:,:]
 
+    @staticmethod
+    def convLSTM_cell(inputs, hidden):
+        y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset
+        channels = inputs.get_shape()[-1]
+        # conv lstm cell
+        cell_shape = y_0.get_shape().as_list()
+        channels = cell_shape[-1]
+        with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)):
+            cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64)
+            if hidden is None:
+                hidden = cell.zero_state(y_0, tf.float32)
+            output, hidden = cell(y_0, hidden)
+        output_shape = output.get_shape().as_list()
+        z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])
+        #we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction
+        x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid")
+        return x_hat, hidden
+
+    @staticmethod
+    def set_and_check_pred_frames(seq_length, context_frames):
+        """
+        Checks if sequence length and context_frames are set properly and returns number of frames to be predicted.
+        :param seq_length: number of frames/images per sequences
+        :param context_frames: number of context frames/images
+        :return: number of predicted frames
+        """
+
+        method = VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames.__name__
+
+        # sanity checks
+        assert isinstance(seq_length, int), "%{0}: Sequence length (seq_length) must be an integer".format(method)
+        assert isinstance(context_frames, int), "%{0}: Number of context frames must be an integer".format(method)
+
+        if seq_length > context_frames:
+            return seq_length-context_frames
+        else:
+            raise ValueError("%{0}: Sequence length ({1}) must be larger than context frames ({2})."
+                             .format(method, seq_length, context_frames))
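
Usage note (illustrative only, not part of the patch): a minimal sketch of how the new set_and_check_pred_frames helper added in this diff is expected to behave, assuming the module and its TensorFlow 1.x dependencies import cleanly.

    from model_modules.video_prediction.models.vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel

    # valid configuration: 20 frames per sequence, the first 10 used as context
    n_predict = VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames(20, 10)
    print(n_predict)  # 10 frames to be predicted

    # seq_length <= context_frames raises a ValueError;
    # non-integer arguments fail the assertions
    try:
        VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames(10, 10)
    except ValueError as err:
        print(err)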