Commit 89324962 authored by b.gong

add BasicConvLSTMCell and vanilla_convLSTM_model into workflow

parent bce95957
import logging

import tensorflow as tf
from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple

from .layer_def import *


class ConvRNNCell(object):
    """Abstract object representing a convolutional RNN cell."""

    def __call__(self, inputs, state, scope=None):
        """Run this RNN cell on inputs, starting from the given state."""
        raise NotImplementedError("Abstract method")

    @property
    def state_size(self):
        """Size(s) of state(s) used by this cell."""
        raise NotImplementedError("Abstract method")

    @property
    def output_size(self):
        """Integer or TensorShape: size of outputs produced by this cell."""
        raise NotImplementedError("Abstract method")

    def zero_state(self, input, dtype):
        """Return zero-filled state tensor(s).
        Args:
            input: 4D tensor whose first dimension is the batch size.
            dtype: the data type to use for the state.
        Returns:
            Tensor of shape [batch_size x shape[0] x shape[1] x num_features * 2]
            filled with zeros; the doubled channel depth holds the cell state
            and the hidden state concatenated along the channel axis.
        """
        shape = self.shape
        num_features = self.num_features
        # The batch size is read from the input tensor at graph-construction
        # time, so it may remain dynamic (None).
        zeros = tf.zeros([tf.shape(input)[0], shape[0], shape[1], num_features * 2], dtype=dtype)
        return zeros
class BasicConvLSTMCell(ConvRNNCell):
    """Basic convolutional LSTM recurrent network cell."""

    def __init__(self, shape, filter_size, num_features, forget_bias=1.0, input_size=None,
                 state_is_tuple=False, activation=tf.nn.tanh):
        """Initialize the basic Conv LSTM cell.
        Args:
            shape: int tuple, the height and width of the cell.
            filter_size: int tuple, the height and width of the filter.
            num_features: int, the depth of the cell.
            forget_bias: float, the bias added to forget gates.
            input_size: deprecated and unused.
            state_is_tuple: if True, accepted and returned states are 2-tuples of
                the `c_state` and `m_state`. If False, they are concatenated
                along the channel axis. The latter behavior will soon be deprecated.
            activation: activation function of the inner states.
        """
        if input_size is not None:
            logging.warn("%s: The input_size parameter is deprecated.", self)
        self.shape = shape
        self.filter_size = filter_size
        self.num_features = num_features
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation

    @property
    def state_size(self):
        return (LSTMStateTuple(self.num_features, self.num_features)
                if self._state_is_tuple else 2 * self.num_features)

    @property
    def output_size(self):
        return self.num_features

    def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM)."""
        with tf.variable_scope(scope or type(self).__name__):  # "BasicConvLSTMCell"
            # Parameters of all four gates are computed in one convolution for
            # efficiency and split afterwards.
            if self._state_is_tuple:
                c, h = state
            else:
                c, h = tf.split(axis=3, num_or_size_splits=2, value=state)
            concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True)
            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = tf.split(axis=3, num_or_size_splits=4, value=concat)
            new_c = (c * tf.nn.sigmoid(f + self._forget_bias) +
                     tf.nn.sigmoid(i) * self._activation(j))
            new_h = self._activation(new_c) * tf.nn.sigmoid(o)
            if self._state_is_tuple:
                new_state = LSTMStateTuple(new_c, new_h)
            else:
                new_state = tf.concat(axis=3, values=[new_c, new_h])
            return new_h, new_state
def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=None):
    """Convolutional analogue of a linear layer: conv2d(concat(args)) * W + b.
    Args:
        args: a 4D Tensor or a list of 4D, batch x h x w x c, Tensors.
        filter_size: int tuple of filter height and width.
        num_features: int, number of output features.
        bias: whether to add a bias term.
        bias_start: starting value to initialize the bias; 0 by default.
        scope: VariableScope for the created subgraph; defaults to "Conv".
    Returns:
        A 4D Tensor with shape [batch h w num_features].
    Raises:
        ValueError: if some of the arguments have unspecified or wrong shape.
    """
    # Calculate the total size of the arguments along the channel dimension.
    total_arg_size_depth = 0
    shapes = [a.get_shape().as_list() for a in args]
    for shape in shapes:
        if len(shape) != 4:
            raise ValueError("Linear is expecting 4D arguments: %s" % str(shapes))
        if not shape[3]:
            raise ValueError("Linear expects shape[3] of arguments: %s" % str(shapes))
        else:
            total_arg_size_depth += shape[3]
    dtype = [a.dtype for a in args][0]
    # Now the computation.
    with tf.variable_scope(scope or "Conv"):
        matrix = tf.get_variable(
            "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype=dtype)
        if len(args) == 1:
            res = tf.nn.conv2d(args[0], matrix, strides=[1, 1, 1, 1], padding='SAME')
        else:
            res = tf.nn.conv2d(tf.concat(axis=3, values=args), matrix, strides=[1, 1, 1, 1], padding='SAME')
        if not bias:
            return res
        bias_term = tf.get_variable(
            "Bias", [num_features],
            dtype=dtype,
            initializer=tf.constant_initializer(bias_start, dtype=dtype))
    return res + bias_term
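

# A minimal smoke test (an illustrative addition, not part of the original
# commit): it exercises one step of BasicConvLSTMCell in TF 1.x graph mode,
# reusing the 16x16, 8-feature configuration of the model code below.
if __name__ == "__main__":
    import numpy as np

    inputs = tf.placeholder(tf.float32, [None, 16, 16, 8])
    cell = BasicConvLSTMCell(shape=[16, 16], filter_size=[3, 3], num_features=8)
    state = cell.zero_state(inputs, tf.float32)  # zeros: [batch, 16, 16, 16]
    output, state = cell(inputs, state)          # output: [batch, 16, 16, 8]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(output,
                       feed_dict={inputs: np.random.rand(2, 16, 16, 8).astype(np.float32)})
        print("one-step output shape:", out.shape)  # (2, 16, 16, 8)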
import itertools

import tensorflow as tf

from video_prediction.layers import layer_def as ld
from video_prediction.layers import BasicConvLSTMCell
from video_prediction.models import BaseVideoPredictionModel
class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
    def __init__(self, mode='train', hparams_dict=None,
                 hparams=None, **kwargs):
        super(VanillaVAEVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
        self.mode = mode
        self.hparams = hparams
        self.learning_rate = self.hparams.lr
        self.gen_images_enc = None
        self.recon_loss = None
        self.latent_loss = None
        self.total_loss = None
        self.context_frames = 10
        self.sequence_length = 20
        self.predict_frames = self.sequence_length - self.context_frames
    def get_default_hparams_dict(self):
        """
        The keys of this dict define valid hyperparameters for instances of
        this class. A class inheriting from this one should override this
        method if it has a different set of hyperparameters.
        Returns:
            A dict with the following hyperparameters.
            batch_size: batch size for training.
            lr: learning rate. If decay_steps is non-zero, this is the
                learning rate for steps <= decay_step.
            end_lr: learning rate for steps >= end_decay_step if decay_steps
                is non-zero, ignored otherwise.
            decay_steps: (decay_step, end_decay_step) tuple.
            max_steps: number of training steps.
            beta1: momentum term of Adam.
            beta2: momentum term of Adam.
            context_frames: the number of ground-truth frames to pass in at
                start. Must be specified during instantiation.
            sequence_length: the number of frames in the video sequence,
                including the context frames, so this model predicts
                `sequence_length - context_frames` future frames. Must be
                specified during instantiation.
        """
        default_hparams = super(VanillaVAEVideoPredictionModel, self).get_default_hparams_dict()
        hparams = dict(
            batch_size=16,
            lr=0.001,
            end_lr=0.0,
            decay_steps=(200000, 300000),
            max_steps=350000,
        )
        return dict(itertools.chain(default_hparams.items(), hparams.items()))
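
    # Note (illustrative, not part of the original commit): dict(itertools.chain(...))
    # keeps the *last* value seen for a duplicate key, so this subclass's entries
    # override the parent defaults. A quick standalone check:
    #
    #     parent = dict(lr=0.01, max_steps=100)
    #     child = dict(lr=0.001)
    #     merged = dict(itertools.chain(parent.items(), child.items()))
    #     assert merged == dict(lr=0.001, max_steps=100)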
    def build_graph(self, x):
        self.x = x
        global_step = tf.train.get_or_create_global_step()
        original_global_variables = tf.global_variables()
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.increment_global_step = tf.assign_add(self.global_step, 1, name='increment_global_step')
        # ARCHITECTURE
        self.x_hat_context_frames, self.x_hat_predict_frames = self.convLSTM_network()
        self.x_hat = tf.concat([self.x_hat_context_frames, self.x_hat_predict_frames], 1)
        print("x_hat shape", self.x_hat)
        # Mean-squared reconstruction losses on the first channel, computed
        # separately for the context frames and the predicted frames.
        self.context_frames_loss = tf.reduce_mean(
            tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
        self.predict_frames_loss = tf.reduce_mean(
            tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_predict_frames[:, :, :, :, 0]))
        self.total_loss = self.context_frames_loss + self.predict_frames_loss
        self.train_op = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate).minimize(self.total_loss, global_step=self.global_step)
        # Summary ops
        self.loss_summary = tf.summary.scalar("context_frames_loss", self.context_frames_loss)
        self.loss_summary = tf.summary.scalar("predict_frames_loss", self.predict_frames_loss)
        self.loss_summary = tf.summary.scalar("total_loss", self.total_loss)
        self.summary_op = tf.summary.merge_all()
        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
        self.saveable_variables = [global_step] + global_variables
        return
    @staticmethod
    def convLSTM_cell(inputs, hidden, nz=16):
        print("Inputs shape", inputs.shape)
        # Encoder: two stride-2 convolutions downsample the input by 4x
        # (e.g. 64x64 -> 16x16, matching the cell's hard-coded shape).
        conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1", activate="leaky_relu")
        print("encode_1 shape", conv1.shape)
        conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2", activate="leaky_relu")
        print("encode_2 shape", conv2.shape)
        conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3", activate="leaky_relu")
        print("encode_3 shape", conv3.shape)
        y_0 = conv3
        # ConvLSTM cell
        with tf.variable_scope('conv_lstm', initializer=tf.random_uniform_initializer(-.01, 0.1)):
            cell = BasicConvLSTMCell(shape=[16, 16], filter_size=[3, 3], num_features=8)
            if hidden is None:
                hidden = cell.zero_state(y_0, tf.float32)
                print("hidden zero state shape", hidden.shape)
            output, hidden = cell(y_0, hidden)
        print("output of cell:", output)
        output_shape = output.get_shape().as_list()
        print("output_shape", output_shape)
        z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])
        # Decoder: transpose convolutions upsample back to the input resolution.
        conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, "decode_5", activate="leaky_relu")
        print("decode_5 shape", conv5)
        conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, "decode_6", activate="leaky_relu")
        print("decode_6 shape", conv6)
        # Sigmoid keeps the reconstructed frame in [0, 1].
        x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, "decode_7", activate="sigmoid")
        print("x_hat shape", x_hat)
        return x_hat, hidden
    def convLSTM_network(self):
        # Make a template so that all time steps share the cell's variables.
        network_template = tf.make_template('network', VanillaVAEVideoPredictionModel.convLSTM_cell)
        # create network
        x_hat_context = []
        x_hat_predict = []
        seq_start = 1
        hidden = None
        # Warm-up over the context frames: feed ground truth for the first
        # seq_start frames, then feed back the model's own predictions.
        for i in range(self.context_frames):
            if i < seq_start:
                x_1, hidden = network_template(self.x[:, i, :, :, :], hidden)
            else:
                x_1, hidden = network_template(x_1, hidden)
            x_hat_context.append(x_1)
        # Free-running prediction of the remaining frames.
        for i in range(self.predict_frames):
            x_1, hidden = network_template(x_1, hidden)
            x_hat_predict.append(x_1)
        # Pack them all together and swap the first two dims:
        # [time, batch, h, w, c] -> [batch, time, h, w, c].
        x_hat_context = tf.stack(x_hat_context)
        x_hat_predict = tf.stack(x_hat_predict)
        self.x_hat_context = tf.transpose(x_hat_context, [1, 0, 2, 3, 4])
        self.x_hat_predict = tf.transpose(x_hat_predict, [1, 0, 2, 3, 4])
        return self.x_hat_context, self.x_hat_predict
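

# --- Usage sketch (an illustrative addition, not part of the commit). It
# --- assumes inputs of shape [batch, 20, 64, 64, 3], so that the two
# --- stride-2 encoder convolutions produce the 16x16 maps the ConvLSTM cell
# --- expects, and an `hparams` object exposing at least `lr`:
#
#     model = VanillaVAEVideoPredictionModel(mode='train',
#                                            hparams_dict=hparams_dict,
#                                            hparams=hparams)
#     x = tf.placeholder(tf.float32, [None, 20, 64, 64, 3])
#     model.build_graph(x)
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         _, loss = sess.run([model.train_op, model.total_loss],
#                            feed_dict={x: batch})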