diff --git a/Zam347_scripts/generate_era5.sh b/Zam347_scripts/generate_era5.sh
index 507ccf9cb2c30b50a5cb768125ed3e472b989a95..72046611bc0e35aa297b73266aa9c2e89c0101b8 100755
--- a/Zam347_scripts/generate_era5.sh
+++ b/Zam347_scripts/generate_era5.sh
@@ -3,7 +3,7 @@
 python -u ../scripts/generate_transfer_learning_finetune.py \
 --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords \
---dataset_hparams sequence_length=20 --checkpoint /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/vae \
+--dataset_hparams sequence_length=20 --checkpoint /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/convLSTM \
 --mode test --results_dir /home/${USER}/results/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500 \
 --batch_size 2 --dataset era5 > generate_era5-out.out
diff --git a/Zam347_scripts/train_era5.sh b/Zam347_scripts/train_era5.sh
index b93c7bea814d41b2255f1201b430c37a2022db4e..1f037f6fc21ac0e21a1e16ba5b6dc62438dda13a 100755
--- a/Zam347_scripts/train_era5.sh
+++ b/Zam347_scripts/train_era5.sh
@@ -2,5 +2,5 @@
 
-python ../scripts/train_dummy.py --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5 --model vae --model_hparams_dict ../hparams/era5/vae/model_hparams.json --output_dir /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/vae
+python ../scripts/train_dummy.py --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5 --model convLSTM --model_hparams_dict ../hparams/era5/vae/model_hparams.json --output_dir /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/convLSTM
 
 #srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/hparams/era5/vae/model_hparams.json b/hparams/era5/vae/model_hparams.json
index 2e9406148e140054ced5e0c4311f3885aa47f728..75e66a11a15fa462abbc113ef76253fb6d15eca6 100644
--- a/hparams/era5/vae/model_hparams.json
+++ b/hparams/era5/vae/model_hparams.json
@@ -1,8 +1,8 @@
 {
     "batch_size": 8,
-    "lr": 0.0002,
+    "lr": 0.001,
     "nz": 16,
-    "max_steps":20
+    "max_steps":500
 }
diff --git a/scripts/train_dummy.py b/scripts/train_dummy.py
index b89ca957aa4696f4ba6f4118a83bee10683c16ff..6ebdb70bf24ecc53fd9611a7af948842600cd0db 100644
--- a/scripts/train_dummy.py
+++ b/scripts/train_dummy.py
@@ -164,13 +164,13 @@ def main():
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
     config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
     #global_step = tf.train.get_or_create_global_step()
+    #global_step = tf.Variable(0, name = 'global_step', trainable = False)
     max_steps = model.hparams.max_steps
     print ("max_steps",max_steps)
     with tf.Session(config=config) as sess:
         print("parameter_count =", sess.run(parameter_count))
         sess.run(tf.global_variables_initializer())
         sess.run(tf.local_variables_initializer())
-        #coord = tf.train.Coordinator()
         #threads = tf.train.start_queue_runners(sess = sess, coord = coord)
         print("Init done: {sess.run(tf.local_variables_initializer())}%")
@@ -188,20 +188,23 @@ def main():
         # start at one step earlier to log everything without doing any training
         # step is relative to the start_step
         for step in range(-1, max_steps - start_step):
+            global_step = sess.run(model.global_step)
+            print ("global_step:", global_step)
             val_handle_eval = sess.run(val_handle)
             print ("val_handle_val",val_handle_eval)
             if step == 1:
                 # skip step -1 and 0 for timing purposes (for warmstarting)
                 start_time = time.time()
-
-            fetches = {"global_step": model.global_step}
+
+            fetches = {"global_step":model.global_step}
             fetches["train_op"] = model.train_op
-            fetches["latent_loss"] = model.latent_loss
+
+            # fetches["latent_loss"] = model.latent_loss
             fetches["total_loss"] = model.total_loss
-            if isinstance(model.learning_rate, tf.Tensor):
-                fetches["learning_rate"] = model.learning_rate
+            # if isinstance(model.learning_rate, tf.Tensor):
+            #     fetches["learning_rate"] = model.learning_rate
             fetches["summary"] = model.summary_op
@@ -210,21 +213,25 @@ def main():
             X = inputs["images"].eval(session=sess)
             #results = sess.run(fetches,feed_dict={model.x:X}) #fetch the elements in dictinoary fetch
             results = sess.run(fetches)
+            print ("results global step:",results["global_step"])
             run_elapsed_time = time.time() - run_start_time
             if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
                 print('running train_op took too long (%0.1fs)' % run_elapsed_time)
 
             #Run testing results
-            val_fetches = {"global_step": model.global_step}
-            val_fetches["latent_loss"] = model.latent_loss
-            val_fetches["total_loss"] = model.total_loss
+            #val_fetches = {"global_step":global_step}
+            val_fetches = {}
+            #val_fetches["latent_loss"] = model.latent_loss
+            #val_fetches["total_loss"] = model.total_loss
             val_fetches["summary"] = model.summary_op
-            val_results = sess.run(val_fetches)
-
-            summary_writer.add_summary(results["summary"], results["global_step"])
-            summary_writer.add_summary(val_results["summary"], val_results["global_step"])
-
-
+            val_results = sess.run(val_fetches,feed_dict={train_handle: val_handle_eval})
+
+            summary_writer.add_summary(results["summary"])
+            summary_writer.add_summary(val_results["summary"])
+
+
+            #print("results_global_step", results["global_step"])
+            #print("Val_results_global_step", val_results["global_step"])
+
             val_datasets = [val_dataset]
             val_models = [model]
@@ -244,8 +251,9 @@ def main():
             # global_step will have the correct step count if we resume from a checkpoint
            # global step is read before it's incremented
            steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
-            train_epoch = results["global_step"] / steps_per_epoch
-            print("progress global step %d epoch %0.1f" % (results["global_step"] + 1, train_epoch))
+            #train_epoch = results["global_step"] / steps_per_epoch
+            train_epoch = global_step/steps_per_epoch
+            print("progress global step %d epoch %0.1f" % (global_step + 1, train_epoch))
             if step > 0:
                 elapsed_time = time.time() - start_time
                 average_time = elapsed_time / step
@@ -266,7 +274,7 @@ def main():
             print(" Results_total_loss",results["total_loss"])
 
         print("saving model to", args.output_dir)
-        saver.save(sess, os.path.join(args.output_dir, "model"), global_step=model.global_step)
+        saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)##Bing: cheat here a little bit because of the global step issue
     print("done")
     #global_step = global_step + 1
 if __name__ == '__main__':
diff --git a/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction/layers/BasicConvLSTMCell.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d8defc2874ba29f177e4512bfea78a5f4298518
--- /dev/null
+++ b/video_prediction/layers/BasicConvLSTMCell.py
@@ -0,0 +1,148 @@
+
+import tensorflow as tf
+from .layer_def import *
+
+class ConvRNNCell(object):
+    """Abstract object representing a Convolutional RNN cell.
+    """
+
+    def __call__(self, inputs, state, scope=None):
+        """Run this RNN cell on inputs, starting from the given state.
+        """
+        raise NotImplementedError("Abstract method")
+
+    @property
+    def state_size(self):
+        """size(s) of state(s) used by this cell.
+        """
+        raise NotImplementedError("Abstract method")
+
+    @property
+    def output_size(self):
+        """Integer or TensorShape: size of outputs produced by this cell."""
+        raise NotImplementedError("Abstract method")
+
+    def zero_state(self,input, dtype):
+        """Return zero-filled state tensor(s).
+        Args:
+            batch_size: int, float, or unit Tensor representing the batch size.
+            dtype: the data type to use for the state.
+        Returns:
+            tensor of shape '[batch_size x shape[0] x shape[1] x num_features]
+            filled with zeros
+        """
+
+        shape = self.shape
+        num_features = self.num_features
+        #x= tf.placeholder(tf.float32, shape=[input.shape[0], shape[0], shape[1], num_features * 2])#Bing: add this to
+        zeros = tf.zeros([tf.shape(input)[0], shape[0], shape[1], num_features * 2])
+        #zeros = tf.zeros_like(x)
+        return zeros
+
+
+class BasicConvLSTMCell(ConvRNNCell):
+    """Basic Conv LSTM recurrent network cell.
+    """
+
+    def __init__(self, shape, filter_size, num_features, forget_bias=1.0, input_size=None,
+                 state_is_tuple=False, activation=tf.nn.tanh):
+        """Initialize the basic Conv LSTM cell.
+        Args:
+            shape: int tuple that is the height and width of the cell
+            filter_size: int tuple that is the height and width of the filter
+            num_features: int that is the depth of the cell
+            forget_bias: float, The bias added to forget gates (see above).
+            input_size: Deprecated and unused.
+            state_is_tuple: If True, accepted and returned states are 2-tuples of
+                the `c_state` and `m_state`. If False, they are concatenated
+                along the column axis. The latter behavior will soon be deprecated.
+            activation: Activation function of the inner states.
+        """
+        # if not state_is_tuple:
+        #     logging.warn("%s: Using a concatenated state is slower and will soon be "
+        #                  "deprecated. Use state_is_tuple=True.", self)
+        if input_size is not None:
+            logging.warn("%s: The input_size parameter is deprecated.", self)
+        self.shape = shape
+        self.filter_size = filter_size
+        self.num_features = num_features
+        self._forget_bias = forget_bias
+        self._state_is_tuple = state_is_tuple
+        self._activation = activation
+
+    @property
+    def state_size(self):
+        return (LSTMStateTuple(self._num_units, self._num_units)
+                if self._state_is_tuple else 2 * self._num_units)
+
+    @property
+    def output_size(self):
+        return self._num_units
+
+    def __call__(self, inputs, state, scope=None):
+        """Long short-term memory cell (LSTM)."""
+        with tf.variable_scope(scope or type(self).__name__):  # "BasicLSTMCell"
+            # Parameters of gates are concatenated into one multiply for efficiency.
+            if self._state_is_tuple:
+                c, h = state
+            else:
+                c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state)
+            concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True)
+
+            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+            i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat)
+
+            new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) *
+                     self._activation(j))
+            new_h = self._activation(new_c) * tf.nn.sigmoid(o)
+
+            if self._state_is_tuple:
+                new_state = LSTMStateTuple(new_c, new_h)
+            else:
+                new_state = tf.concat(axis = 3, values = [new_c, new_h])
+            return new_h, new_state
+
+
+def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=None):
+    """convolution:
+    Args:
+        args: a 4D Tensor or a list of 4D, batch x n, Tensors.
+        filter_size: int tuple of filter height and width.
+        num_features: int, number of features.
+        bias_start: starting value to initialize the bias; 0 by default.
+        scope: VariableScope for the created subgraph; defaults to "Linear".
+    Returns:
+        A 4D Tensor with shape [batch h w num_features]
+    Raises:
+        ValueError: if some of the arguments has unspecified or wrong shape.
+    """
+
+    # Calculate the total size of arguments on dimension 1.
+    total_arg_size_depth = 0
+    shapes = [a.get_shape().as_list() for a in args]
+    for shape in shapes:
+        if len(shape) != 4:
+            raise ValueError("Linear is expecting 4D arguments: %s" % str(shapes))
+        if not shape[3]:
+            raise ValueError("Linear expects shape[4] of arguments: %s" % str(shapes))
+        else:
+            total_arg_size_depth += shape[3]
+
+    dtype = [a.dtype for a in args][0]
+
+    # Now the computation.
+    with tf.variable_scope(scope or "Conv"):
+        matrix = tf.get_variable(
+            "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype = dtype)
+        if len(args) == 1:
+            res = tf.nn.conv2d(args[0], matrix, strides = [1, 1, 1, 1], padding = 'SAME')
+        else:
+            res = tf.nn.conv2d(tf.concat(axis = 3, values = args), matrix, strides = [1, 1, 1, 1], padding = 'SAME')
+        if not bias:
+            return res
+        bias_term = tf.get_variable(
+            "Bias", [num_features],
+            dtype = dtype,
+            initializer = tf.constant_initializer(
+                bias_start, dtype = dtype))
+        return res + bias_term
diff --git a/video_prediction/layers/layer_def.py b/video_prediction/layers/layer_def.py
index 35a7c910e0b3ec12cb9fdc3cbb9ceda3a86922dd..6b7f4387001c9318507ad809d7176071312742d0 100644
--- a/video_prediction/layers/layer_def.py
+++ b/video_prediction/layers/layer_def.py
@@ -28,11 +28,11 @@ def _variable_on_cpu(name, shape, initializer):
       Variable Tensor
     """
     with tf.device('/cpu:0'):
-        var = tf.get_variable(name, shape, initializer = initializer)
+        var = tf.get_variable(name, shape, initializer=initializer)
     return var
 
 
-def _variable_with_weight_decay(name, shape, stddev, wd):
+def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.layers.xavier_initializer()):
     """Helper to create an initialized Variable with weight decay.
     Note that the Variable is initialized with a truncated normal distribution.
     A weight decay is added only if one is specified.
@@ -45,8 +45,8 @@ def _variable_with_weight_decay(name, shape, stddev, wd):
     Returns:
       Variable Tensor
     """
-    var = _variable_on_cpu(name, shape,tf.truncated_normal_initializer(stddev = stddev))
-    #var = _variable_on_cpu(name, shape,tf.contrib.layers.xavier_initializer())
+    #var = _variable_on_cpu(name, shape,tf.truncated_normal_initializer(stddev = stddev))
+    var = _variable_on_cpu(name, shape, initializer)
     if wd:
         weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name = 'weight_loss')
         weight_decay.set_shape([])
@@ -54,16 +54,16 @@ def _variable_with_weight_decay(name, shape, stddev, wd):
     return var
 
 
-def conv_layer(inputs, kernel_size, stride, num_features, idx, activate="relu"):
+def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer() , activate="relu"):
     print("conv_layer activation function",activate)
     with tf.variable_scope('{0}_conv'.format(idx)) as scope:
-        print ("DEBUG input shape",inputs.get_shape())
+        input_channels = inputs.get_shape()[-1]
         weights = _variable_with_weight_decay('weights',shape = [kernel_size, kernel_size, input_channels, num_features], stddev = 0.01, wd = weight_decay)
-        biases = _variable_on_cpu('biases', [num_features], tf.contrib.layers.xavier_initializer())
+        biases = _variable_on_cpu('biases', [num_features], initializer)
         conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding = 'SAME')
         conv_biased = tf.nn.bias_add(conv, biases)
         if activate == "linear":
@@ -72,25 +72,25 @@ def conv_layer(inputs, kernel_size, stride, num_features, idx, activate="relu"):
            conv_rect = tf.nn.relu(conv_biased, name = '{0}_conv'.format(idx))
         elif activate == "elu":
             conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx))
+        elif activate == "leaky_relu":
+            conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx))
         else:
             raise ("activation function is not correct")
-
         return conv_rect
 
 
-def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, activate="relu"):
+def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer(),activate="relu"):
     with tf.variable_scope('{0}_trans_conv'.format(idx)) as scope:
         input_channels = inputs.get_shape()[3]
         input_shape = inputs.get_shape().as_list()
-        print("input_channel",input_channels)
+
         weights = _variable_with_weight_decay('weights', shape = [kernel_size, kernel_size, num_features, input_channels], stddev = 0.1, wd = weight_decay)
-        biases = _variable_on_cpu('biases', [num_features], tf.contrib.layers.xavier_initializer())
+        biases = _variable_on_cpu('biases', [num_features],initializer)
         batch_size = tf.shape(inputs)[0]
-#        output_shape = tf.stack(
-#            [tf.shape(inputs)[0], tf.shape(inputs)[1] * stride, tf.shape(inputs)[2] * stride, num_features])
+        output_shape = tf.stack( [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features])
         print ("output_shape",output_shape)
@@ -102,11 +102,15 @@ def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, activat
             return tf.nn.elu(conv_biased, name = '{0}_transpose_conv'.format(idx))
         elif activate == "relu":
             return tf.nn.relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
+        elif activate == "leaky_relu":
+            return tf.nn.leaky_relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
+        elif activate == "sigmoid":
+            return tf.nn.sigmoid(conv_biased, name ='sigmoid')
         else:
-            return None
+            return conv_biased
 
 
-def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01):
+def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01,initializer=tf.contrib.layers.xavier_initializer()):
     with tf.variable_scope('{0}_fc'.format(idx)) as scope:
         input_shape = inputs.get_shape().as_list()
         if flat:
@@ -118,7 +122,7 @@ def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01)
         weights = _variable_with_weight_decay('weights', shape = [dim, hiddens], stddev = weight_init, wd = weight_decay)
-        biases = _variable_on_cpu('biases', [hiddens],tf.contrib.layers.xavier_initializer())
+        biases = _variable_on_cpu('biases', [hiddens],initializer)
         if activate == "linear":
             return tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')
         elif activate == "sigmoid":
@@ -127,15 +131,30 @@ def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01)
             return tf.nn.softmax(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
         elif activate == "relu":
             return tf.nn.relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
+        elif activate == "leaky_relu":
+            return tf.nn.leaky_relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
         else:
             ip = tf.add(tf.matmul(inputs_processed, weights), biases)
             return tf.nn.elu(ip, name = str(idx) + '_fc')
 
 
-def bn_layers(inputs,idx,epsilon = 1e-3):
+def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None):
     with tf.variable_scope('{0}_bn'.format(idx)) as scope:
-        # Calculate batch mean and variance
-        batch_mean, batch_var = tf.nn.moments(inputs,[0])
-        tz1_hat = (inputs - batch_mean) / tf.sqrt(batch_var + epsilon)
-        l1_BN = tf.nn.sigmoid(tz1_hat)
+        #Calculate batch mean and variance
+        shape = inputs.get_shape().as_list()
+        scale = tf.get_variable("gamma", shape[-1], initializer=tf.constant_initializer(1.0), trainable=is_training)
+        beta = tf.get_variable("beta", shape[-1], initializer=tf.constant_initializer(0.0), trainable=is_training)
+        pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
+        pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)
 
-    return l1_BN
\ No newline at end of file
+        if is_training:
+            batch_mean, batch_var = tf.nn.moments(inputs,[0])
+            train_mean = tf.assign(pop_mean,pop_mean * decay + batch_mean * (1 - decay))
+            train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
+            with tf.control_dependencies([train_mean,train_var]):
+                return tf.nn.batch_normalization(inputs,batch_mean,batch_var,beta,scale,epsilon)
+        else:
+            return tf.nn.batch_normalization(inputs,pop_mean,pop_var,beta,scale,epsilon)
+
+def bn_layers_wrapper(inputs, is_training):
+    pass
+
\ No newline at end of file
diff --git a/video_prediction/models/__init__.py b/video_prediction/models/__init__.py
index ea1fa77f821827b61c3c2cbfa362014c1da20faf..4103a236ab6430d701bae28ee9b6ff6670b110fa 100644
--- a/video_prediction/models/__init__.py
+++ b/video_prediction/models/__init__.py
@@ -8,6 +8,7 @@ from .dna_model import DNAVideoPredictionModel
 from .sna_model import SNAVideoPredictionModel
 from .sv2p_model import SV2PVideoPredictionModel
 from .vanilla_vae_model import VanillaVAEVideoPredictionModel
+from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel
 
 def get_model_class(model):
     model_mappings = {
@@ -18,6 +19,7 @@ def get_model_class(model):
         'sna': 'SNAVideoPredictionModel',
         'sv2p': 'SV2PVideoPredictionModel',
         'vae': 'VanillaVAEVideoPredictionModel',
+        'convLSTM': 'VanillaConvLstmVideoPredictionModel'
     }
     model_class = model_mappings.get(model, model)
     model_class = globals().get(model_class)
diff --git a/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction/models/vanilla_convLSTM_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cd2ad3f2b99e9a88c9471db2c0dc6f4ccb89913
--- /dev/null
+++ b/video_prediction/models/vanilla_convLSTM_model.py
@@ -0,0 +1,162 @@
+import collections
+import functools
+import itertools
+from collections import OrderedDict
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.util import nest
+from video_prediction import ops, flow_ops
+from video_prediction.models import BaseVideoPredictionModel
+from video_prediction.models import networks
+from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from video_prediction.utils import tf_utils
+from datetime import datetime
+from pathlib import Path
+from video_prediction.layers import layer_def as ld
+from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
+
+class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
+    def __init__(self, mode='train',aggregate_nccl=None, hparams_dict=None,
+                 hparams=None, **kwargs):
+        super(VanillaConvLstmVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
+        print ("Hparams_dict",self.hparams)
+        self.mode = mode
+        self.learning_rate = self.hparams.lr
+        self.gen_images_enc = None
+        self.recon_loss = None
+        self.latent_loss = None
+        self.total_loss = None
+        self.context_frames = 10
+        self.sequence_length = 20
+        self.predict_frames = self.sequence_length - self.context_frames
+        self.aggregate_nccl=aggregate_nccl
+
+    def get_default_hparams_dict(self):
+        """
+        The keys of this dict define valid hyperparameters for instances of
+        this class. A class inheriting from this one should override this
+        method if it has a different set of hyperparameters.
+
+        Returns:
+            A dict with the following hyperparameters.
+
+            batch_size: batch size for training.
+            lr: learning rate. if decay steps is non-zero, this is the
+                learning rate for steps <= decay_step.
+            end_lr: learning rate for steps >= end_decay_step if decay_steps
+                is non-zero, ignored otherwise.
+            decay_steps: (decay_step, end_decay_step) tuple.
+            max_steps: number of training steps.
+            beta1: momentum term of Adam.
+            beta2: momentum term of Adam.
+            context_frames: the number of ground-truth frames to pass in at
+                start. Must be specified during instantiation.
+            sequence_length: the number of frames in the video sequence,
+                including the context frames, so this model predicts
+                `sequence_length - context_frames` future frames. Must be
+                specified during instantiation.
+ """ + default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict() + print ("default hparams",default_hparams) + hparams = dict( + batch_size=16, + lr=0.001, + end_lr=0.0, + nz=16, + decay_steps=(200000, 300000), + max_steps=350000, + ) + + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + def build_graph(self, x): + self.x = x["images"] + #self.global_step = tf.train.get_or_create_global_step() + self.global_step = tf.Variable(0, name = 'global_step', trainable = False) + original_global_variables = tf.global_variables() + # ARCHITECTURE + self.x_hat_context_frames, self.x_hat_predict_frames = self.convLSTM_network() + self.x_hat = tf.concat([self.x_hat_context_frames, self.x_hat_predict_frames], 1) + print("x_hat,shape", self.x_hat) + + self.context_frames_loss = tf.reduce_mean( + tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0])) + self.predict_frames_loss = tf.reduce_mean( + tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_predict_frames[:, :, :, :, 0])) + self.total_loss = self.context_frames_loss + self.predict_frames_loss + + self.train_op = tf.train.AdamOptimizer( + learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) + self.outputs = {} + self.outputs["gen_images"] = self.x_hat + # Summary op + self.loss_summary = tf.summary.scalar("recon_loss", self.context_frames_loss) + self.loss_summary = tf.summary.scalar("latent_loss", self.predict_frames_loss) + self.loss_summary = tf.summary.scalar("total_loss", self.total_loss) + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + return + + + @staticmethod + def convLSTM_cell(inputs, hidden, nz=16): + print("Inputs shape", inputs.shape) + conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1", activate = "leaky_relu") + print("Encode_1_shape", conv1.shape) + conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2", activate = "leaky_relu") + print("Encode 2_shape,", conv2.shape) + conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3", activate = "leaky_relu") + print("Encode 3_shape, ", conv3.shape) + y_0 = conv3 + # conv lstm cell + cell_shape = y_0.get_shape().as_list() + with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)): + cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size = [3, 3], num_features = 8) + if hidden is None: + hidden = cell.zero_state(y_0, tf.float32) + print("hidden zero layer", hidden.shape) + output, hidden = cell(y_0, hidden) + print("output for cell:", output) + + output_shape = output.get_shape().as_list() + print("output_shape,", output_shape) + + z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) + + conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, "decode_5", activate = "leaky_relu") + print("conv5 shape", conv5) + + conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, "decode_6", activate = "leaky_relu") + print("conv6 shape", conv6) + + x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, "decode_7", activate = "sigmoid") # set activation to linear + print("x hat shape", x_hat) + return x_hat, hidden + + def convLSTM_network(self): + network_template = tf.make_template('network', + VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables + # create network + x_hat_context = [] + x_hat_predict = [] 
+        seq_start = 1
+        hidden = None
+        for i in range(self.context_frames):
+            if i < seq_start:
+                x_1, hidden = network_template(self.x[:, i, :, :, :], hidden)
+            else:
+                x_1, hidden = network_template(x_1, hidden)
+            x_hat_context.append(x_1)
+
+        for i in range(self.predict_frames):
+            x_1, hidden = network_template(x_1, hidden)
+            x_hat_predict.append(x_1)
+
+        # pack them all together
+        x_hat_context = tf.stack(x_hat_context)
+        x_hat_predict = tf.stack(x_hat_predict)
+        self.x_hat_context = tf.transpose(x_hat_context, [1, 0, 2, 3, 4])  # change first dim with sec dim
+        self.x_hat_predict = tf.transpose(x_hat_predict, [1, 0, 2, 3, 4])  # change first dim with sec dim
+        return self.x_hat_context, self.x_hat_predict
diff --git a/video_prediction/models/vanilla_vae_model.py b/video_prediction/models/vanilla_vae_model.py
index 1ae1eb06351dd11fa2fd8269f4966a682c9341fd..74280896dca61007c1b361ec4caff9ad5f718d26 100644
--- a/video_prediction/models/vanilla_vae_model.py
+++ b/video_prediction/models/vanilla_vae_model.py
@@ -79,11 +79,14 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         #print ("self_x:",self.x)
         #tf.reset_default_graph()
         #self.x = tf.placeholder(tf.float32, [None,20,64,64,3])
+        tf.set_random_seed(12345)
         self.x = x["images"]
-        self.global_step = tf.train.get_or_create_global_step()
+
+        #self.global_step = tf.train.get_or_create_global_step()
+        self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
         original_global_variables = tf.global_variables()
-        #self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
-        #self.increment_global_step = tf.assign_add(self.global_step, 1, name = 'increment_global_step')
+        self.increment_global_step = tf.assign_add(self.global_step, 1, name = 'increment_global_step')
+
         self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all()
         # Loss
         # Reconstruction loss
@@ -129,6 +132,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         self.outputs["gen_images"] = self.x_hat
         global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
         self.saveable_variables = [self.global_step] + global_variables
+        #train_op = tf.assign_add(global_step, 1)
         return
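
Note (not part of the patch): a minimal TF 1.x sketch for sanity-checking the shapes produced by the new BasicConvLSTMCell, assuming the encoder path added in convLSTM_cell above (a 128x160 frame downsampled twice to 32x40 with 8 channels). The placeholder shape and scope name here are illustrative assumptions, not code from the repository.

    import tensorflow as tf
    from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell

    # Encoded frame, as produced by the three conv_layer calls in convLSTM_cell.
    x = tf.placeholder(tf.float32, [None, 32, 40, 8])
    cell = BasicConvLSTMCell(shape=[32, 40], filter_size=[3, 3], num_features=8)
    # zero_state concatenates c and h along the channel axis: [batch, 32, 40, 16].
    state = cell.zero_state(x, tf.float32)
    with tf.variable_scope("conv_lstm"):
        out, state = cell(x, state)   # out: [batch, 32, 40, 8]
    print(out.shape, state.shape)

With state_is_tuple=False (the default used here), the cell splits the concatenated state into c and h internally and returns the new state in the same concatenated layout, which is why zero_state allocates num_features * 2 channels.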