diff --git a/hparams/era5/vae/model_hparams.json b/hparams/era5/vae/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..be9e05e6bb2cee5870f090b3a05edbd401755af7
--- /dev/null
+++ b/hparams/era5/vae/model_hparams.json
@@ -0,0 +1,8 @@
+{
+    "batch_size": 8,
+    "lr": 0.0002,
+    "nz": 32,
+    "max_steps":20
+}
+
+
diff --git a/scripts/train_dummy.py b/scripts/train_dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..b89ca957aa4696f4ba6f4118a83bee10683c16ff
--- /dev/null
+++ b/scripts/train_dummy.py
@@ -0,0 +1,273 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import random
+import time
+import numpy as np
+import tensorflow as tf
+from video_prediction import datasets, models
+
+
+def add_tag_suffix(summary, tag_suffix):
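+    # Illustration: with tag_suffix="_1", a summary tag "model/loss" is rewritten to "model_1/loss".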
+    summary_proto = tf.Summary()
+    summary_proto.ParseFromString(summary)
+    summary = summary_proto
+
+    for value in summary.value:
+        tag_split = value.tag.split('/')
+        value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:])
+    return summary.SerializeToString()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
+    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
+    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
+                                             "default is logs_dir/model_fname, where model_fname consists of "
+                                             "information from model and model_hparams")
+    parser.add_argument("--output_dir_postfix", default="")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+    parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')
+
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
+    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
+
+    parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", type=int)
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        tf.set_random_seed(args.seed)
+        np.random.seed(args.seed)
+        random.seed(args.seed)
+
+    if args.output_dir is None:
+        list_depth = 0
+        model_fname = ''
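+        # Illustration: "--model vae --model_hparams nz=32,lr=0.0002" yields model_fname
+        # "model.vae.nz.32.lr.0.0002"; commas inside [...] lists become ".." and the brackets are dropped.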
+        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
+            if t == '[':
+                list_depth += 1
+            if t == ']':
+                list_depth -= 1
+            if list_depth and t == ',':
+                t = '..'
+            if t in '=,':
+                t = '.'
+            if t in '[]':
+                t = ''
+            model_fname += t
+        args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix
+
+    if args.resume:
+        if args.checkpoint:
+            raise ValueError('resume and checkpoint cannot both be specified')
+        args.checkpoint = args.output_dir
+
+
+    model_hparams_dict = {}
+    if args.model_hparams_dict:
+        with open(args.model_hparams_dict) as f:
+            model_hparams_dict.update(json.loads(f.read()))
+    if args.checkpoint:
+        checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.isdir(args.checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            args.dataset = args.dataset or options['dataset']
+            args.model = args.model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict.update(json.loads(f.read()))
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    VideoDataset = datasets.get_dataset_class(args.dataset)
+    train_dataset = VideoDataset(
+        args.input_dir,
+        mode='train')
+    val_dataset = VideoDataset(
+        args.val_input_dir or args.input_dir,
+        mode='val')
+
+    variable_scope = tf.get_variable_scope()
+    variable_scope.set_use_resource(True)
+
+    VideoPredictionModel = models.get_model_class(args.model)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': train_dataset.hparams.context_frames,
+        'sequence_length': train_dataset.hparams.sequence_length,
+        'repeat': train_dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams,
+        aggregate_nccl=args.aggregate_nccl)
+
+    batch_size = model.hparams.batch_size
+    train_tf_dataset = train_dataset.make_dataset_v2(batch_size)  # Bing: adapt the meteo data preparation here
+    train_iterator = train_tf_dataset.make_one_shot_iterator()  # Bing: for ERA5, problems in sess.run(fetches) likely originate here
+    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
+    # and used to feed the `handle` placeholder.
+    train_handle = train_iterator.string_handle()
+    val_tf_dataset = val_dataset.make_dataset_v2(batch_size)
+    val_iterator = val_tf_dataset.make_one_shot_iterator()
+    val_handle = val_iterator.string_handle()
+    #iterator = tf.data.Iterator.from_string_handle(
+    #    train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
+    inputs = train_iterator.get_next()
+    val = val_iterator.get_next()
+    # inputs comes from the training dataset by default, unless train_handle is remapped to the val_handle
+    model.build_graph(inputs)
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
+    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+
+    with tf.name_scope("parameter_count"):
+        # exclude trainable variables that are replicas (used in multi-gpu setting)
+        trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
+        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables])
+
+    saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)
+
+    # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero
+    summary_writer = tf.summary.FileWriter(args.output_dir)
+
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
+    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+    #global_step = tf.train.get_or_create_global_step()
+    max_steps = model.hparams.max_steps
+    print ("max_steps",max_steps)
+    with tf.Session(config=config) as sess:
+        print("parameter_count =", sess.run(parameter_count))
+        sess.run(tf.global_variables_initializer())
+        sess.run(tf.local_variables_initializer())
+       
+        #coord = tf.train.Coordinator()
+        #threads = tf.train.start_queue_runners(sess = sess, coord = coord)
+        print("Init done: {sess.run(tf.local_variables_initializer())}%")
+        model.restore(sess, args.checkpoint)
+        print("Restore processed finished")
+        #sess.run(model.post_init_ops)
+        print("Model run started")
+        #val_handle_eval = sess.run(val_handle)
+        #print ("val_handle_val",val_handle_eval)
+        #print("val handle done")
+        sess.graph.finalize()
+        start_step = sess.run(model.global_step)
+        print("global step done")
+
+        # start at one step earlier to log everything without doing any training
+        # step is relative to the start_step
+        for step in range(-1, max_steps - start_step):
+            val_handle_eval = sess.run(val_handle)
+            print ("val_handle_val",val_handle_eval)
+            if step == 1:
+                # skip step -1 and 0 for timing purposes (for warmstarting)
+                start_time = time.time()
+
+            fetches = {"global_step": model.global_step}
+            fetches["train_op"] = model.train_op
+            fetches["latent_loss"] = model.latent_loss
+            fetches["total_loss"] = model.total_loss
+
+
+            if isinstance(model.learning_rate, tf.Tensor):
+                fetches["learning_rate"] = model.learning_rate
+
+            fetches["summary"] = model.summary_op
+
+            run_start_time = time.time()
+            # Run the training step
+            #X = inputs["images"].eval(session=sess)  # debugging only: this pulls an extra batch from the train iterator
+            #results = sess.run(fetches, feed_dict={model.x: X})  # fetch the elements in the fetches dictionary
+            results = sess.run(fetches)
+            run_elapsed_time = time.time() - run_start_time
+            if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
+                print('running train_op took too long (%0.1fs)' % run_elapsed_time)
+
+            #Run testing results
+            val_fetches = {"global_step": model.global_step}
+            val_fetches["latent_loss"] = model.latent_loss
+            val_fetches["total_loss"] = model.total_loss
+            val_fetches["summary"] = model.summary_op
+            val_results = sess.run(val_fetches)
+
+            summary_writer.add_summary(results["summary"], results["global_step"])
+            summary_writer.add_summary(val_results["summary"], val_results["global_step"])
+
+
+            val_datasets = [val_dataset]
+            val_models = [model]
+
+            # for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)):
+            #     sess.run(val_model.accum_eval_metrics_reset_op)
+            #     # traverse (roughly up to rounding based on the batch size) all the validation dataset
+            #     accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size
+            #     val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op}
+            #     for update_step in range(accum_eval_summary_num_updates):
+            #         print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates))
+            #         val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
+            #     accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1))
+            #     print("recording accum eval summary")
+            #     summary_writer.add_summary(accum_eval_summary, val_results["global_step"])
+            summary_writer.flush()
+
+            # global_step will have the correct step count if we resume from a checkpoint
+            # global step is read before it's incremented
+            steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
+            train_epoch = results["global_step"] / steps_per_epoch
+            print("progress  global step %d  epoch %0.1f" % (results["global_step"] + 1, train_epoch))
+            if step > 0:
+                elapsed_time = time.time() - start_time
+                average_time = elapsed_time / step
+                images_per_sec = batch_size / average_time
+                remaining_time = (max_steps - (start_step + step + 1)) * average_time
+                print("image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
+                      (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
+
+            # if results['d_losses']:
+            #     print("d_loss", results["d_loss"])
+            # for name, loss in results['d_losses'].items():
+            #     print("  ", name, loss)
+            # if results['g_losses']:
+            #     print("g_loss", results["g_loss"])
+            # for name, loss in results['g_losses'].items():
+            #     print("  ", name, loss)
+            #for name, loss in results['total_loss'].items():
+            print(" Results_total_loss",results["total_loss"])
+            
+            print("saving model to", args.output_dir)
+            saver.save(sess, os.path.join(args.output_dir, "model"), global_step=model.global_step)
+            print("done")
+            #global_step = global_step + 1
+if __name__ == '__main__':
+    main()
diff --git a/video_prediction/layers/layer_def.py b/video_prediction/layers/layer_def.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a7c910e0b3ec12cb9fdc3cbb9ceda3a86922dd
--- /dev/null
+++ b/video_prediction/layers/layer_def.py
@@ -0,0 +1,141 @@
+"""functions used to construct different architectures
+"""
+
+import tensorflow as tf
+import numpy as np
+
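+# Default L2 weight-decay coefficient; passed as `wd` to _variable_with_weight_decay by the layer helpers below.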
+weight_decay = 0.0005
+def _activation_summary(x):
+    """Helper to create summaries for activations.
+    Creates a summary that provides a histogram of activations.
+    Creates a summary that measures the sparsity of activations.
+    Args:
+      x: Tensor
+    Returns:
+      nothing
+    """
+    tensor_name = x.op.name
+    tf.summary.histogram(tensor_name + '/activations', x)
+    tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
+
+def _variable_on_cpu(name, shape, initializer):
+    """Helper to create a Variable stored on CPU memory.
+    Args:
+      name: name of the variable
+      shape: list of ints
+      initializer: initializer for Variable
+    Returns:
+      Variable Tensor
+    """
+    with tf.device('/cpu:0'):
+        var = tf.get_variable(name, shape, initializer = initializer)
+    return var
+
+
+def _variable_with_weight_decay(name, shape, stddev, wd):
+    """Helper to create an initialized Variable with weight decay.
+    Note that the Variable is initialized with a truncated normal distribution.
+    A weight decay is added only if one is specified.
+    Args:
+      name: name of the variable
+      shape: list of ints
+      stddev: standard deviation of a truncated Gaussian
+      wd: add L2Loss weight decay multiplied by this float. If None, weight
+          decay is not added for this Variable.
+    Returns:
+      Variable Tensor
+    """
+    var = _variable_on_cpu(name, shape,tf.truncated_normal_initializer(stddev = stddev))
+    #var = _variable_on_cpu(name, shape,tf.contrib.layers.xavier_initializer())
+    if wd:
+        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name = 'weight_loss')
+        weight_decay.set_shape([])
+        tf.add_to_collection('losses', weight_decay)
+    return var
+
+
+def conv_layer(inputs, kernel_size, stride, num_features, idx, activate="relu"):
+    print("conv_layer activation function",activate)
+    
+    with tf.variable_scope('{0}_conv'.format(idx)) as scope:
+        print ("DEBUG input shape",inputs.get_shape())
+        input_channels = inputs.get_shape()[-1]
+        weights = _variable_with_weight_decay('weights',shape = [kernel_size, kernel_size, 
+                                                                 input_channels, num_features],
+                                              stddev = 0.01, wd = weight_decay)
+        biases = _variable_on_cpu('biases', [num_features], tf.contrib.layers.xavier_initializer())
+        conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding = 'SAME')
+        conv_biased = tf.nn.bias_add(conv, biases)
+        if activate == "linear":
+            return conv_biased
+        elif activate == "relu":
+            conv_rect = tf.nn.relu(conv_biased, name = '{0}_conv'.format(idx))  
+        elif activate == "elu":
+            conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx))   
+        else:
+            raise ("activation function is not correct")
+
+        return conv_rect
+
+
+def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, activate="relu"):
+    with tf.variable_scope('{0}_trans_conv'.format(idx)) as scope:
+        input_channels = inputs.get_shape()[3]
+        input_shape = inputs.get_shape().as_list()
+        print("input_channel",input_channels)
+
+        weights = _variable_with_weight_decay('weights',
+                                              shape = [kernel_size, kernel_size, num_features, input_channels],
+                                              stddev = 0.1, wd = weight_decay)
+        biases = _variable_on_cpu('biases', [num_features], tf.contrib.layers.xavier_initializer())
+        batch_size = tf.shape(inputs)[0]
+#         output_shape = tf.stack(
+#             [tf.shape(inputs)[0], tf.shape(inputs)[1] * stride, tf.shape(inputs)[2] * stride, num_features])
+        output_shape = tf.stack(
+            [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features])
+        print ("output_shape",output_shape)
+        conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides = [1, stride, stride, 1], padding = 'SAME')
+        conv_biased = tf.nn.bias_add(conv, biases)
+        if activate == "linear":
+            return conv_biased
+        elif activate == "elu":
+            return tf.nn.elu(conv_biased, name = '{0}_transpose_conv'.format(idx))       
+        elif activate == "relu":
+            return tf.nn.relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
+        else:
+            raise ValueError("Unknown activation function: {}".format(activate))
+    
+
+def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01):
+    with tf.variable_scope('{0}_fc'.format(idx)) as scope:
+        input_shape = inputs.get_shape().as_list()
+        if flat:
+            dim = input_shape[1] * input_shape[2] * input_shape[3]
+            inputs_processed = tf.reshape(inputs, [-1, dim])
+        else:
+            dim = input_shape[1]
+            inputs_processed = inputs
+
+        weights = _variable_with_weight_decay('weights', shape = [dim, hiddens], stddev = weight_init,
+                                              wd = weight_decay)
+        biases = _variable_on_cpu('biases', [hiddens],tf.contrib.layers.xavier_initializer())
+        if activate == "linear":
+            return tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')
+        elif activate == "sigmoid":
+            return tf.nn.sigmoid(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
+        elif activate == "softmax":
+            return tf.nn.softmax(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
+        elif activate == "relu":
+            return tf.nn.relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
+        else:
+            ip = tf.add(tf.matmul(inputs_processed, weights), biases)
+            return tf.nn.elu(ip, name = str(idx) + '_fc')
+        
+def bn_layers(inputs,idx,epsilon = 1e-3):
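+    # Note: this normalizes with per-batch statistics and applies a sigmoid; unlike standard
+    # batch normalization there are no learned scale/offset parameters and no moving averages.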
+    with tf.variable_scope('{0}_bn'.format(idx)) as scope:
+        # Calculate batch mean and variance
+        batch_mean, batch_var = tf.nn.moments(inputs,[0])
+        tz1_hat = (inputs - batch_mean) / tf.sqrt(batch_var + epsilon)
+        l1_BN = tf.nn.sigmoid(tz1_hat)
+        
+    return l1_BN
\ No newline at end of file
diff --git a/video_prediction/models/vanilla_vae_model.py b/video_prediction/models/vanilla_vae_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2e6d959964aad35db55d677bde87b98bfde5d11
--- /dev/null
+++ b/video_prediction/models/vanilla_vae_model.py
@@ -0,0 +1,222 @@
+import collections
+import functools
+import itertools
+from collections import OrderedDict
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.util import nest
+from video_prediction import ops, flow_ops
+from video_prediction.models import BaseVideoPredictionModel
+from video_prediction.models import networks
+from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from video_prediction.utils import tf_utils
+from datetime import datetime
+from pathlib import Path
+from video_prediction.layers import layer_def as ld
+
+class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
+    def __init__(self, mode='train', aggregate_nccl=None, hparams_dict=None,
+                 hparams=None, **kwargs):
+        super(VanillaVAEVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
+        self.mode = mode
+        self.learning_rate = self.hparams.lr
+        self.aggregate_nccl=aggregate_nccl
+        self.gen_images_enc = None
+        self.g_losses = None
+        self.d_losses = None
+        self.g_loss = None
+        self.d_loss = None
+        self.g_vars = None
+        self.d_vars = None
+        self.train_op = None
+        self.summary_op = None
+        self.image_summary_op = None
+        self.eval_summary_op = None
+        self.accum_eval_summary_op = None
+        self.accum_eval_metrics_reset_op = None
+        self.recon_loss = None
+        self.latent_loss = None
+        self.total_loss = None
+
+    def get_default_hparams_dict(self):
+        """
+        The keys of this dict define valid hyperparameters for instances of
+        this class. A class inheriting from this one should override this
+        method if it has a different set of hyperparameters.
+
+        Returns:
+            A dict with the following hyperparameters.
+
+            batch_size: batch size for training.
+            lr: learning rate. if decay steps is non-zero, this is the
+                learning rate for steps <= decay_step.
+            end_lr: learning rate for steps >= end_decay_step if decay_steps
+                is non-zero, ignored otherwise.
+            decay_steps: (decay_step, end_decay_step) tuple.
+            max_steps: number of training steps.
+            beta1: momentum term of Adam.
+            beta2: momentum term of Adam.
+            context_frames: the number of ground-truth frames to pass in at
+                start. Must be specified during instantiation.
+            sequence_length: the number of frames in the video sequence,
+                including the context frames, so this model predicts
+                `sequence_length - context_frames` future frames. Must be
+                specified during instantiation.
+        """
+        default_hparams = super(VanillaVAEVideoPredictionModel, self).get_default_hparams_dict()
+        hparams = dict(
+            batch_size=16,
+            lr=0.001,
+            end_lr=0.0,
+            decay_steps=(200000, 300000),
+            lr_boundaries=(0,),
+            max_steps=350000,
+            nz=10,
+            beta1=0.9,
+            beta2=0.999,
+            context_frames=-1,
+            sequence_length=-1,
+            clip_length=10,  # Bing: TODO clarify what clip_length is; the original value is 10
+            l1_weight=0.0,
+            l2_weight=1.0,
+            vgg_cdist_weight=0.0,
+            feature_l2_weight=0.0,
+            ae_l2_weight=0.0,
+            state_weight=0.0,
+            tv_weight=0.0,
+            image_sn_gan_weight=0.0,
+            image_sn_vae_gan_weight=0.0,
+            images_sn_gan_weight=0.0,
+            images_sn_vae_gan_weight=0.0,
+            video_sn_gan_weight=0.0,
+            video_sn_vae_gan_weight=0.0,
+            gan_feature_l2_weight=0.0,
+            gan_feature_cdist_weight=0.0,
+            vae_gan_feature_l2_weight=0.0,
+            vae_gan_feature_cdist_weight=0.0,
+            gan_loss_type='LSGAN',
+            joint_gan_optimization=False,
+            kl_weight=0.0,
+            kl_anneal='linear',
+            kl_anneal_k=-1.0,
+            kl_anneal_steps=(50000, 100000),
+            z_l1_weight=0.0,
+        )
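+        # These defaults are intended to be overridden by the hparams_dict / hparams arguments passed at
+        # construction, e.g. hparams/era5/vae/model_hparams.json sets batch_size=8, lr=0.0002, nz=32, max_steps=20.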
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def build_graph(self, x):
+
+        #global_step = tf.train.get_or_create_global_step()
+        #original_global_variables = tf.global_variables()
+        # self.x = x["images"]
+        #print ("self_x:",self.x)
+        #tf.reset_default_graph()
+        #self.x = tf.placeholder(tf.float32, [None,20,64,64,3])
+        self.x = x["images"]
+        self.global_step = tf.train.get_or_create_global_step()
+        original_global_variables = tf.global_variables()
+        #self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
+        #self.increment_global_step = tf.assign_add(self.global_step, 1, name = 'increment_global_step')
+        self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all()
+        # Loss
+        # Reconstruction loss
+        # Minimize the cross-entropy loss
+        #         epsilon = 1e-10
+        #         recon_loss = -tf.reduce_sum(
+        #             self.x[:,1:,:,:,:] * tf.log(epsilon+self.x_hat[:,:-1,:,:,:]) +
+        #             (1-self.x[:,1:,:,:,:]) * tf.log(epsilon+1-self.x_hat[:,:-1,:,:,:]),
+        #             axis=1
+        #         )
+
+        #        self.recon_loss = tf.reduce_mean(recon_loss)
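+        # MSE between input frames 1..T (channel 0 only) and reconstructed frames 0..T-1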
+        self.recon_loss = tf.reduce_mean(tf.square(self.x[:, 1:, :, :, 0] - self.x_hat[:, :-1, :, :, 0]))
+
+        # Latent loss
+        # KL divergence: measure the difference between two distributions
+        # Here we measure the divergence between
+        # the latent distribution and N(0, 1)
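+        # Closed form: KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)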
+        latent_loss = -0.5 * tf.reduce_sum(
+            1 + self.z_log_sigma_sq - tf.square(self.z_mu) -
+            tf.exp(self.z_log_sigma_sq), axis = 1)
+        self.latent_loss = tf.reduce_mean(latent_loss)
+        self.total_loss = self.recon_loss + self.latent_loss
+        self.train_op = tf.train.AdamOptimizer(
+            learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step)
+        # Build a saver
+        #self.saver = tf.train.Saver(tf.global_variables())
+        self.losses = {
+            'recon_loss': self.recon_loss,
+            'latent_loss': self.latent_loss,
+            'total_loss': self.total_loss,
+        }
+
+        # Summary op
+        self.loss_summary = tf.summary.scalar("recon_loss", self.recon_loss)
+        self.loss_summary = tf.summary.scalar("latent_loss", self.latent_loss)
+        self.loss_summary = tf.summary.scalar("total_loss", self.latent_loss)
+        self.summary_op = tf.summary.merge_all()
+        # H(x, x_hat) = -\Sigma x*log(x_hat) + (1-x)*log(1-x_hat)
+        # self.ckpt = tf.train.Checkpoint(model=self.vae_arc2())
+        # self.manager = tf.train.CheckpointManager(self.ckpt,self.checkpoint_dir,max_to_keep=3)
+        self.outputs = {}
+        self.outputs["gen_images"] = self.x_hat
+        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+        self.saveable_variables = [self.global_step] + global_variables
+        return
+
+
+    @staticmethod
+    def vae_arc3(x, l_name=0):
+        seq_name = "sq_" + str(l_name) + "_"
+        print("DBBUG: INPUT", x)
+        conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
+        print("Encode_1_shape", conv1.shape)  # (?,2,2,8)
+        # conv2
+        conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2")  # (?,2,2,8)
+        print("Encode 2_shape,", conv2.shape)
+        # conv3
+        conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3")  # (?,1,1,8)
+        print("Encode 3_shape, ", conv3.shape)
+        # flatten
+        conv4 = tf.layers.Flatten()(conv3)
+        print("Encode 4_shape, ", conv4.shape)
+        conv3_shape = conv3.get_shape().as_list()
+        print("conv4_shape",conv3_shape)
+        # Todo: to conv3 to 
+        z_mu = ld.fc_layer(conv4, hiddens = 16, idx = seq_name + "enc_fc4_m")
+        z_log_sigma_sq = ld.fc_layer(conv4, hiddens = 16, idx = seq_name + "enc_fc4_sigma")
+        eps = tf.random_normal(shape = tf.shape(z_log_sigma_sq), mean = 0, stddev = 1, dtype = tf.float32)
+        z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps
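+        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, 1), so the sample stays differentiable w.r.t. mu and sigma.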
+        print("latend variables z ", z)
+        z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "deenc_fc1")
+        print("latend variables z2 ", z2)
+        z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])
+        print("latend variables z3 ", z3)
+        # conv5
+        conv5 = ld.transpose_conv_layer(z3, 3, 2, 8,
+                                        seq_name + "decode_5")  # (16,1,1,8)inputs, kernel_size, stride, num_features
+        print("Decode 5 shape", conv5.shape)
+        conv6  = ld.transpose_conv_layer(conv5, 3, 1, 8,
+                                        seq_name + "decode_6")  # (16,1,1,8)inputs, kernel_size, stride, num_features
+        
+        # x_1
+        x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8")  # note: uses the default "relu" activation here; a linear output may have been intended
+        print("X_hat", x_hat.shape)
+        return x_hat, z_mu, z_log_sigma_sq, z
+
+    def vae_arc_all(self):
+        X = []
+        z_log_sigma_sq_all = []
+        z_mu_all = []
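+        # Note: the number of frames is hard-coded to 20 here rather than taken from the sequence_length hparam.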
+        for i in range(20):
+            q, z_mu, z_log_sigma_sq, z = VanillaVAEVideoPredictionModel.vae_arc3(self.x[:, i, :, :, :], l_name = i)
+            X.append(q)
+            z_log_sigma_sq_all.append(z_log_sigma_sq)
+            z_mu_all.append(z_mu)
+        x_hat = tf.stack(X, axis = 1)
+        z_log_sigma_sq_all = tf.stack(z_log_sigma_sq_all, axis = 1)
+        z_mu_all = tf.stack(z_mu_all, axis = 1)
+        print("X_hat", x_hat.shape)
+        print("zlog_sigma_sq_all", z_log_sigma_sq_all.shape)
+        return x_hat, z_log_sigma_sq_all, z_mu_all