diff --git a/hparams/era5/vae/model_hparams.json b/hparams/era5/vae/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..be9e05e6bb2cee5870f090b3a05edbd401755af7
--- /dev/null
+++ b/hparams/era5/vae/model_hparams.json
@@ -0,0 +1,8 @@
+{
+    "batch_size": 8,
+    "lr": 0.0002,
+    "nz": 32,
+    "max_steps": 20
+}
+
+
diff --git a/scripts/train_dummy.py b/scripts/train_dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..b89ca957aa4696f4ba6f4118a83bee10683c16ff
--- /dev/null
+++ b/scripts/train_dummy.py
@@ -0,0 +1,273 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import random
+import time
+import numpy as np
+import tensorflow as tf
+from video_prediction import datasets, models
+
+
+def add_tag_suffix(summary, tag_suffix):
+    summary_proto = tf.Summary()
+    summary_proto.ParseFromString(summary)
+    summary = summary_proto
+
+    for value in summary.value:
+        tag_split = value.tag.split('/')
+        value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:])
+    return summary.SerializeToString()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc., or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
+    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
+    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc. are saved. "
+                                             "default is logs_dir/model_fname, where model_fname consists of "
+                                             "information from model and model_hparams")
+    parser.add_argument("--output_dir_postfix", default="")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+    parser.add_argument("--resume", action='store_true', help='resume from latest checkpoint in output_dir.')
+
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
+    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
+
+    parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", type=int)
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        tf.set_random_seed(args.seed)
+        np.random.seed(args.seed)
+        random.seed(args.seed)
+
+    if args.output_dir is None:
+        list_depth = 0
+        model_fname = ''
+        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
+            if t == '[':
+                list_depth += 1
+            if t == ']':
+                list_depth -= 1
+            if list_depth and t == ',':
+                t = '..'
+            if t in '=,':
+                t = '.'
+            if t in '[]':
+                t = ''
+            model_fname += t
+        args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix
+
+    if args.resume:
+        if args.checkpoint:
+            raise ValueError('resume and checkpoint cannot both be specified')
+        args.checkpoint = args.output_dir
+
+
+    model_hparams_dict = {}
+    if args.model_hparams_dict:
+        with open(args.model_hparams_dict) as f:
+            model_hparams_dict.update(json.loads(f.read()))
+    if args.checkpoint:
+        checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.isdir(args.checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            args.dataset = args.dataset or options['dataset']
+            args.model = args.model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict.update(json.loads(f.read()))
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    VideoDataset = datasets.get_dataset_class(args.dataset)
+    train_dataset = VideoDataset(
+        args.input_dir,
+        mode='train')
+    val_dataset = VideoDataset(
+        args.val_input_dir or args.input_dir,
+        mode='val')
+
+    variable_scope = tf.get_variable_scope()
+    variable_scope.set_use_resource(True)
+
+    VideoPredictionModel = models.get_model_class(args.model)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': train_dataset.hparams.context_frames,
+        'sequence_length': train_dataset.hparams.sequence_length,
+        'repeat': train_dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams,
+        aggregate_nccl=args.aggregate_nccl)
+
+    batch_size = model.hparams.batch_size
+    train_tf_dataset = train_dataset.make_dataset_v2(batch_size)  # Bing: adapt the meteo data preparation here
+    train_iterator = train_tf_dataset.make_one_shot_iterator()  # Bing: for era5, failures in sess.run(fetches) would originate here
+    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
+    # and used to feed the `handle` placeholder.
+    train_handle = train_iterator.string_handle()
+    val_tf_dataset = val_dataset.make_dataset_v2(batch_size)
+    val_iterator = val_tf_dataset.make_one_shot_iterator()
+    val_handle = val_iterator.string_handle()
+    #iterator = tf.data.Iterator.from_string_handle(
+    #    train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
+    inputs = train_iterator.get_next()
+    val = val_iterator.get_next()
+    # inputs comes from the training dataset by default, unless train_handle is remapped to the val_handles
+    model.build_graph(inputs)
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
+    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+
+    with tf.name_scope("parameter_count"):
+        # exclude trainable variables that are replicas (used in multi-gpu setting)
+        trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
+        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables])
+
+    saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)
+
+    # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero
+    summary_writer = tf.summary.FileWriter(args.output_dir)
+
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
+    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+    #global_step = tf.train.get_or_create_global_step()
+    max_steps = model.hparams.max_steps
+    print("max_steps", max_steps)
+    with tf.Session(config=config) as sess:
+        print("parameter_count =", sess.run(parameter_count))
+        sess.run(tf.global_variables_initializer())
+        sess.run(tf.local_variables_initializer())
+
+        #coord = tf.train.Coordinator()
+        #threads = tf.train.start_queue_runners(sess = sess, coord = coord)
+        print("Init done")
+        model.restore(sess, args.checkpoint)
+        print("Restore finished")
+        #sess.run(model.post_init_ops)
+        print("Model run started")
+        #val_handle_eval = sess.run(val_handle)
+        #print("val_handle_eval", val_handle_eval)
+        #print("val handle done")
+        sess.graph.finalize()
+        start_step = sess.run(model.global_step)
+        print("global step done")
+
+        # start at one step earlier to log everything without doing any training
+        # step is relative to the start_step
+        for step in range(-1, max_steps - start_step):
+            val_handle_eval = sess.run(val_handle)
+            print("val_handle_eval", val_handle_eval)
+            if step == 1:
+                # skip step -1 and 0 for timing purposes (for warmstarting)
+                start_time = time.time()
+
+            fetches = {"global_step": model.global_step}
+            fetches["train_op"] = model.train_op
+            fetches["latent_loss"] = model.latent_loss
+            fetches["total_loss"] = model.total_loss
+
+
+            if isinstance(model.learning_rate, tf.Tensor):
+                fetches["learning_rate"] = model.learning_rate
+
+            fetches["summary"] = model.summary_op
+
+            run_start_time = time.time()
+            # Run training results
+            #X = inputs["images"].eval(session=sess)  # note: evaluating inputs here would consume an extra batch per step
+            #results = sess.run(fetches, feed_dict={model.x: X})  # fetch the elements in the fetches dictionary
+            results = sess.run(fetches)
+            run_elapsed_time = time.time() - run_start_time
+            if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
+                print('running train_op took too long (%0.1fs)' % run_elapsed_time)
+
+            # Run validation results
+            val_fetches = {"global_step": model.global_step}
+            val_fetches["latent_loss"] = model.latent_loss
+            val_fetches["total_loss"] = model.total_loss
+            val_fetches["summary"] = model.summary_op
+            val_results = sess.run(val_fetches)
+
+            summary_writer.add_summary(results["summary"], results["global_step"])
+            summary_writer.add_summary(val_results["summary"], val_results["global_step"])
+
+
+            val_datasets = [val_dataset]
+            val_models = [model]
+
+            # for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)):
+            #     sess.run(val_model.accum_eval_metrics_reset_op)
+            #     # traverse (roughly up to rounding based on the batch size) all the validation dataset
+            #     accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size
+            #     val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op}
+            #     for update_step in range(accum_eval_summary_num_updates):
+            #         print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates))
+            #         val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
+            #         accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1))
+            #         print("recording accum eval summary")
+            #         summary_writer.add_summary(accum_eval_summary, val_results["global_step"])
+            summary_writer.flush()
+
+            # global_step will have the correct step count if we resume from a checkpoint
+            # global step is read before it's incremented
+            steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
+            train_epoch = results["global_step"] / steps_per_epoch
+            print("progress global step %d epoch %0.1f" % (results["global_step"] + 1, train_epoch))
+            if step > 0:
+                elapsed_time = time.time() - start_time
+                average_time = elapsed_time / step
+                images_per_sec = batch_size / average_time
+                remaining_time = (max_steps - (start_step + step + 1)) * average_time
+                print("image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" %
+                      (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
+
+            # if results['d_losses']:
+            #     print("d_loss", results["d_loss"])
+            # for name, loss in results['d_losses'].items():
+            #     print("  ", name, loss)
+            # if results['g_losses']:
+            #     print("g_loss", results["g_loss"])
+            # for name, loss in results['g_losses'].items():
+            #     print("  ", name, loss)
+            #for name, loss in results['total_loss'].items():
+            print("Results_total_loss", results["total_loss"])
+
+            print("saving model to", args.output_dir)
+            saver.save(sess, os.path.join(args.output_dir, "model"), global_step=model.global_step)
+        print("done")
+        #global_step = global_step + 1
+if __name__ == '__main__':
+    main()
diff --git a/video_prediction/layers/layer_def.py b/video_prediction/layers/layer_def.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a7c910e0b3ec12cb9fdc3cbb9ceda3a86922dd
--- /dev/null
+++ b/video_prediction/layers/layer_def.py
@@ -0,0 +1,141 @@
+"""functions used to construct different architectures
+"""
+
+import tensorflow as tf
+import numpy as np
+
+weight_decay = 0.0005
+def _activation_summary(x):
+    """Helper to create summaries for activations.
+    Creates a summary that provides a histogram of activations.
+    Creates a summary that measures the sparsity of activations.
+    Args:
+        x: Tensor
+    Returns:
+        nothing
+    """
+    tensor_name = x.op.name
+    tf.summary.histogram(tensor_name + '/activations', x)
+    tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
+
+def _variable_on_cpu(name, shape, initializer):
+    """Helper to create a Variable stored on CPU memory.
+    Args:
+        name: name of the variable
+        shape: list of ints
+        initializer: initializer for Variable
+    Returns:
+        Variable Tensor
+    """
+    with tf.device('/cpu:0'):
+        var = tf.get_variable(name, shape, initializer = initializer)
+    return var
+
+
+def _variable_with_weight_decay(name, shape, stddev, wd):
+    """Helper to create an initialized Variable with weight decay.
+    Note that the Variable is initialized with a truncated normal distribution.
+    A weight decay is added only if one is specified.
+    Args:
+        name: name of the variable
+        shape: list of ints
+        stddev: standard deviation of a truncated Gaussian
+        wd: add L2Loss weight decay multiplied by this float. If None, weight
+            decay is not added for this Variable.
+    Returns:
+        Variable Tensor
+    """
+    var = _variable_on_cpu(name, shape, tf.truncated_normal_initializer(stddev = stddev))
+    #var = _variable_on_cpu(name, shape, tf.contrib.layers.xavier_initializer())
+    if wd:
+        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name = 'weight_loss')
+        weight_decay.set_shape([])
+        tf.add_to_collection('losses', weight_decay)
+    return var
+
+
+def conv_layer(inputs, kernel_size, stride, num_features, idx, activate="relu"):
+    print("conv_layer activation function", activate)
+
+    with tf.variable_scope('{0}_conv'.format(idx)) as scope:
+        print("DEBUG input shape", inputs.get_shape())
+        input_channels = inputs.get_shape()[-1]
+        weights = _variable_with_weight_decay('weights', shape = [kernel_size, kernel_size,
+                                                                  input_channels, num_features],
+                                              stddev = 0.01, wd = weight_decay)
+        biases = _variable_on_cpu('biases', [num_features], tf.contrib.layers.xavier_initializer())
+        conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding = 'SAME')
+        conv_biased = tf.nn.bias_add(conv, biases)
+        if activate == "linear":
+            return conv_biased
+        elif activate == "relu":
+            conv_rect = tf.nn.relu(conv_biased, name = '{0}_conv'.format(idx))
+        elif activate == "elu":
+            conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx))
+        else:
+            raise ValueError("activation function is not correct")
+
+        return conv_rect
+
+
+def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, activate="relu"):
+    with tf.variable_scope('{0}_trans_conv'.format(idx)) as scope:
+        input_channels = inputs.get_shape()[3]
+        input_shape = inputs.get_shape().as_list()
+        print("input_channels", input_channels)
+
+        weights = _variable_with_weight_decay('weights',
+                                              shape = [kernel_size, kernel_size, num_features, input_channels],
+                                              stddev = 0.1, wd = weight_decay)
+        biases = _variable_on_cpu('biases', [num_features], tf.contrib.layers.xavier_initializer())
+        batch_size = tf.shape(inputs)[0]
+#        output_shape = tf.stack(
+#            [tf.shape(inputs)[0], tf.shape(inputs)[1] * stride, tf.shape(inputs)[2] * stride, num_features])
+        output_shape = tf.stack(
+            [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features])
+        print("output_shape", output_shape)
+        conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides = [1, stride, stride, 1], padding = 'SAME')
+        conv_biased = tf.nn.bias_add(conv, biases)
+        if activate == "linear":
+            return conv_biased
+        elif activate == "elu":
+            return tf.nn.elu(conv_biased, name = '{0}_transpose_conv'.format(idx))
activate == "relu": + return tf.nn.relu(conv_biased, name = '{0}_transpose_conv'.format(idx)) + else: + return None + + +def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01): + with tf.variable_scope('{0}_fc'.format(idx)) as scope: + input_shape = inputs.get_shape().as_list() + if flat: + dim = input_shape[1] * input_shape[2] * input_shape[3] + inputs_processed = tf.reshape(inputs, [-1, dim]) + else: + dim = input_shape[1] + inputs_processed = inputs + + weights = _variable_with_weight_decay('weights', shape = [dim, hiddens], stddev = weight_init, + wd = weight_decay) + biases = _variable_on_cpu('biases', [hiddens],tf.contrib.layers.xavier_initializer()) + if activate == "linear": + return tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc') + elif activate == "sigmoid": + return tf.nn.sigmoid(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + elif activate == "softmax": + return tf.nn.softmax(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + elif activate == "relu": + return tf.nn.relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + else: + ip = tf.add(tf.matmul(inputs_processed, weights), biases) + return tf.nn.elu(ip, name = str(idx) + '_fc') + +def bn_layers(inputs,idx,epsilon = 1e-3): + with tf.variable_scope('{0}_bn'.format(idx)) as scope: + # Calculate batch mean and variance + batch_mean, batch_var = tf.nn.moments(inputs,[0]) + tz1_hat = (inputs - batch_mean) / tf.sqrt(batch_var + epsilon) + l1_BN = tf.nn.sigmoid(tz1_hat) + + return l1_BN \ No newline at end of file diff --git a/video_prediction/models/vanilla_vae_model.py b/video_prediction/models/vanilla_vae_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c2e6d959964aad35db55d677bde87b98bfde5d11 --- /dev/null +++ b/video_prediction/models/vanilla_vae_model.py @@ -0,0 +1,222 @@ +import collections +import functools +import itertools +from collections import OrderedDict +import numpy as np +import tensorflow as tf +from tensorflow.python.util import nest +from video_prediction import ops, flow_ops +from video_prediction.models import BaseVideoPredictionModel +from video_prediction.models import networks +from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat +from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell +from video_prediction.utils import tf_utils +from datetime import datetime +from pathlib import Path +from video_prediction.layers import layer_def as ld + +class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): + def __init__(self, mode='train', aggregate_nccl=None,hparams_dict=None, + hparams=None,**kwargs): + super(VanillaVAEVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) + self.mode = mode + self.learning_rate = self.hparams.lr + self.aggregate_nccl=aggregate_nccl + self.gen_images_enc = None + self.g_losses = None + self.d_losses = None + self.g_loss = None + self.d_loss = None + self.g_vars = None + self.d_vars = None + self.train_op = None + self.summary_op = None + self.image_summary_op = None + self.eval_summary_op = None + self.accum_eval_summary_op = None + self.accum_eval_metrics_reset_op = None + self.recon_loss = None + self.latent_loss = None + self.total_loss = None + + def get_default_hparams_dict(self): + """ + The keys of this dict define valid hyperparameters for instances of + this class. 
+        this class. A class inheriting from this one should override this
+        method if it has a different set of hyperparameters.
+
+        Returns:
+            A dict with the following hyperparameters.
+
+            batch_size: batch size for training.
+            lr: learning rate. if decay steps is non-zero, this is the
+                learning rate for steps <= decay_step.
+            end_lr: learning rate for steps >= end_decay_step if decay_steps
+                is non-zero, ignored otherwise.
+            decay_steps: (decay_step, end_decay_step) tuple.
+            max_steps: number of training steps.
+            beta1: momentum term of Adam.
+            beta2: momentum term of Adam.
+            context_frames: the number of ground-truth frames to pass in at
+                start. Must be specified during instantiation.
+            sequence_length: the number of frames in the video sequence,
+                including the context frames, so this model predicts
+                `sequence_length - context_frames` future frames. Must be
+                specified during instantiation.
+        """
+        default_hparams = super(VanillaVAEVideoPredictionModel, self).get_default_hparams_dict()
+        hparams = dict(
+            batch_size=16,
+            lr=0.001,
+            end_lr=0.0,
+            decay_steps=(200000, 300000),
+            lr_boundaries=(0,),
+            max_steps=350000,
+            nz=10,
+            beta1=0.9,
+            beta2=0.999,
+            context_frames=-1,
+            sequence_length=-1,
+            clip_length=10,  # Bing: TODO what is the clip_length? original is 10
+            l1_weight=0.0,
+            l2_weight=1.0,
+            vgg_cdist_weight=0.0,
+            feature_l2_weight=0.0,
+            ae_l2_weight=0.0,
+            state_weight=0.0,
+            tv_weight=0.0,
+            image_sn_gan_weight=0.0,
+            image_sn_vae_gan_weight=0.0,
+            images_sn_gan_weight=0.0,
+            images_sn_vae_gan_weight=0.0,
+            video_sn_gan_weight=0.0,
+            video_sn_vae_gan_weight=0.0,
+            gan_feature_l2_weight=0.0,
+            gan_feature_cdist_weight=0.0,
+            vae_gan_feature_l2_weight=0.0,
+            vae_gan_feature_cdist_weight=0.0,
+            gan_loss_type='LSGAN',
+            joint_gan_optimization=False,
+            kl_weight=0.0,
+            kl_anneal='linear',
+            kl_anneal_k=-1.0,
+            kl_anneal_steps=(50000, 100000),
+            z_l1_weight=0.0,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def build_graph(self, x):
+
+        #global_step = tf.train.get_or_create_global_step()
+        #original_global_variables = tf.global_variables()
+        # self.x = x["images"]
+        #print("self_x:", self.x)
+        #tf.reset_default_graph()
+        #self.x = tf.placeholder(tf.float32, [None, 20, 64, 64, 3])
+        self.x = x["images"]
+        self.global_step = tf.train.get_or_create_global_step()
+        original_global_variables = tf.global_variables()
+        #self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
+        #self.increment_global_step = tf.assign_add(self.global_step, 1, name = 'increment_global_step')
+        self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all()
+        # Loss
+        # Reconstruction loss
+        # Minimize the cross-entropy loss
+        # epsilon = 1e-10
+        # recon_loss = -tf.reduce_sum(
+        #     self.x[:, 1:, :, :, :] * tf.log(epsilon + self.x_hat[:, :-1, :, :, :]) +
+        #     (1 - self.x[:, 1:, :, :, :]) * tf.log(epsilon + 1 - self.x_hat[:, :-1, :, :, :]),
+        #     axis=1
+        # )
+
+        # self.recon_loss = tf.reduce_mean(recon_loss)
+        self.recon_loss = tf.reduce_mean(tf.square(self.x[:, 1:, :, :, 0] - self.x_hat[:, :-1, :, :, 0]))
+
+        # Latent loss
+        # KL divergence: measure the difference between two distributions
+        # Here we measure the divergence between
+        # the latent distribution and N(0, 1)
+        latent_loss = -0.5 * tf.reduce_sum(
+            1 + self.z_log_sigma_sq - tf.square(self.z_mu) -
+            tf.exp(self.z_log_sigma_sq), axis = 1)
+        self.latent_loss = tf.reduce_mean(latent_loss)
+        self.total_loss = self.recon_loss + self.latent_loss
+        self.train_op = tf.train.AdamOptimizer(
+            learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step)
+        # Build a saver
+        #self.saver = tf.train.Saver(tf.global_variables())
+        self.losses = {
+            'recon_loss': self.recon_loss,
+            'latent_loss': self.latent_loss,
+            'total_loss': self.total_loss,
+        }
+
+        # Summary op
+        self.loss_summary = tf.summary.scalar("recon_loss", self.recon_loss)
+        self.loss_summary = tf.summary.scalar("latent_loss", self.latent_loss)
+        self.loss_summary = tf.summary.scalar("total_loss", self.total_loss)
+        self.summary_op = tf.summary.merge_all()
+        # H(x, x_hat) = -\Sigma x*log(x_hat) + (1-x)*log(1-x_hat)
+        # self.ckpt = tf.train.Checkpoint(model=self.vae_arc2())
+        # self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=3)
+        self.outputs = {}
+        self.outputs["gen_images"] = self.x_hat
+        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+        self.saveable_variables = [self.global_step] + global_variables
+        return
+
+
+    @staticmethod
+    def vae_arc3(x, l_name=0):
+        seq_name = "sq_" + str(l_name) + "_"
+        print("DEBUG: INPUT", x)
+        conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
+        print("Encode_1_shape", conv1.shape)  # (?,2,2,8)
+        # conv2
+        conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2")  # (?,2,2,8)
+        print("Encode 2_shape", conv2.shape)
+        # conv3
+        conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3")  # (?,1,1,8)
+        print("Encode 3_shape", conv3.shape)
+        # flatten
+        conv4 = tf.layers.Flatten()(conv3)
+        print("Encode 4_shape", conv4.shape)
+        conv3_shape = conv3.get_shape().as_list()
+        print("conv3_shape", conv3_shape)
+        # Todo: to conv3 to
+        z_mu = ld.fc_layer(conv4, hiddens = 16, idx = seq_name + "enc_fc4_m")
+        z_log_sigma_sq = ld.fc_layer(conv4, hiddens = 16, idx = seq_name + "enc_fc4_sigma")
+        eps = tf.random_normal(shape = tf.shape(z_log_sigma_sq), mean = 0, stddev = 1, dtype = tf.float32)
+        z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps
+        print("latent variable z ", z)
+        z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "deenc_fc1")
+        print("latent variable z2 ", z2)
+        z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])
+        print("latent variable z3 ", z3)
+        # conv5
+        conv5 = ld.transpose_conv_layer(z3, 3, 2, 8,
+                                        seq_name + "decode_5")  # inputs, kernel_size, stride, num_features
+        print("Decode 5 shape", conv5.shape)
+        conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8,
+                                        seq_name + "decode_6")  # inputs, kernel_size, stride, num_features
+
+        # x_1
+        x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8", activate = "linear")  # linear activation for the output layer
+        print("X_hat", x_hat.shape)
+        return x_hat, z_mu, z_log_sigma_sq, z
+
+    def vae_arc_all(self):
+        X = []
+        z_log_sigma_sq_all = []
+        z_mu_all = []
+        for i in range(20):
+            q, z_mu, z_log_sigma_sq, z = VanillaVAEVideoPredictionModel.vae_arc3(self.x[:, i, :, :, :], l_name = i)
+            X.append(q)
+            z_log_sigma_sq_all.append(z_log_sigma_sq)
+            z_mu_all.append(z_mu)
+        x_hat = tf.stack(X, axis = 1)
+        z_log_sigma_sq_all = tf.stack(z_log_sigma_sq_all, axis = 1)
+        z_mu_all = tf.stack(z_mu_all, axis = 1)
+        print("X_hat", x_hat.shape)
+        print("zlog_sigma_sq_all", z_log_sigma_sq_all.shape)
+        return x_hat, z_log_sigma_sq_all, z_mu_all
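
Reviewer note (not part of the diff): a minimal NumPy sketch of the per-frame reparameterization trick and KL term that vae_arc3 and build_graph implement above. The batch size and latent dimension below are illustrative assumptions (nz=16 mirrors the hiddens=16 of the enc_fc4 layers); it only restates the math, not the TensorFlow graph.

    # Sketch of z = mu + sigma * eps and KL(q(z|x) || N(0, I)), as used in the VAE model above.
    import numpy as np

    rng = np.random.default_rng(0)
    batch, nz = 8, 16                              # assumed sizes for illustration

    z_mu = rng.normal(size=(batch, nz))            # stands in for the encoder mean output
    z_log_sigma_sq = rng.normal(size=(batch, nz))  # stands in for the encoder log-variance output

    # reparameterization: sample eps ~ N(0, 1), then z = mu + sqrt(exp(log_sigma_sq)) * eps
    eps = rng.standard_normal((batch, nz))
    z = z_mu + np.sqrt(np.exp(z_log_sigma_sq)) * eps

    # KL divergence to N(0, I), summed over the latent dimension and averaged over the batch,
    # matching the latent_loss expression in build_graph
    latent_loss = -0.5 * np.sum(1 + z_log_sigma_sq - z_mu**2 - np.exp(z_log_sigma_sq), axis=1)
    latent_loss = latent_loss.mean()
    print(z.shape, latent_loss)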