diff --git a/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh
new file mode 100755
index 0000000000000000000000000000000000000000..a81e9a1499ce2619c6d934d32396c7128bd6b565
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh
@@ -0,0 +1,37 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+##SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --output=DataPreprocess_to_tf-out.%j
+#SBATCH --error=DataPreprocess_to_tf-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --partition=devel
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+
+
+# Name of virtual environment 
+VIRT_ENV_NAME="vp"
+
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z "${VIRTUAL_ENV}" ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+# declare directory variables for the preprocessing step (tfrecord generation)
+
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist 
+destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
+
+# run Preprocessing (step 2 where Tf-records are generated)
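+# moving_mnist.py is expected to read the data from ${source_dir} and to write the tfrecord files into ${destination_dir}/tfrecords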
+srun python ../video_prediction/datasets/moving_mnist.py ${source_dir} ${destination_dir}/tfrecords
diff --git a/video_prediction_savp/HPC_scripts/generate_movingmnist.sh b/video_prediction_savp/HPC_scripts/generate_movingmnist.sh
new file mode 100755
index 0000000000000000000000000000000000000000..1de81d2543d255a160ff811ff391a963ef712bde
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/generate_movingmnist.sh
@@ -0,0 +1,44 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+##SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --output=generate_movingmnist-out.%j
+#SBATCH --error=generate_movingmnist-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --gres=gpu:1
+#SBATCH --partition=develgpus
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=s.stadtler@fz-juelich.de
+##jutil env activate -p cjjsc42
+
+# Name of virtual environment 
+VIRT_ENV_NAME="vp"
+
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z "${VIRTUAL_ENV}" ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+# declare directory variables for the input data, the model checkpoints and the results
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
+checkpoint_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist
+results_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist
+# name of model
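+# (should match the model the checkpoint was trained with, e.g. convLSTM, savp, mcnet or vae)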
+model=convLSTM
+
+# run postprocessing/generation of model results including evaluation metrics
+srun python -u ../scripts/generate_movingmnist.py \
+--input_dir ${source_dir}/ --dataset_hparams sequence_length=20 --checkpoint  ${checkpoint_dir}/${model} \
+--mode test --model ${model} --results_dir ${results_dir}/${model} --batch_size 2 --dataset era5   > generate_era5-out.out
+
+#srun  python scripts/train.py --input_dir data/era5 --dataset era5  --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/video_prediction_savp/HPC_scripts/hyperparam_setup.sh b/video_prediction_savp/HPC_scripts/hyperparam_setup.sh
index a6c24a062ca30b06879641806d771beacc4b34f8..34894da8c3e345955c05b69994f4f3cf431174ff 100644
--- a/video_prediction_savp/HPC_scripts/hyperparam_setup.sh
+++ b/video_prediction_savp/HPC_scripts/hyperparam_setup.sh
@@ -1,6 +1,6 @@
 #!/usr/bin/env bash
 
-# for choosing the model
+# for choosing the model: convLSTM, vae, mcnet, savp
 export model=convLSTM
 export model_hparams=../hparams/era5/${model}/model_hparams.json
 
diff --git a/video_prediction_savp/HPC_scripts/train_movingmnist.sh b/video_prediction_savp/HPC_scripts/train_movingmnist.sh
new file mode 100755
index 0000000000000000000000000000000000000000..006ff73c30c4a53c80aef9371bfbe29fac39f973
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/train_movingmnist.sh
@@ -0,0 +1,47 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+##SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --output=train_moving_mnist-out.%j
+#SBATCH --error=train_moving_mnist-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --gres=gpu:1
+#SBATCH --partition=gpus
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+##jutil env activate -p cjjsc42
+
+
+# Name of virtual environment 
+VIRT_ENV_NAME="vp"
+
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z "${VIRTUAL_ENV}" ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+
+# declare directory variables for the training data and the model output
+
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
+destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist
+
+# for choosing the model: convLSTM, savp, mcnet, vae
+model=convLSTM
+dataset=moving_mnist
+model_hparams=../hparams/${dataset}/${model}/model_hparams.json
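+# each run gets a timestamped output directory, e.g. <destination_dir>/convLSTM/20200822T1530_<user>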
+destination_dir=${destination_dir}/${model}/"$(date +"%Y%m%dT%H%M")_${USER}"
+
+# run training
+
+srun python ../scripts/train_dummy.py --input_dir  ${source_dir}/tfrecords/ --dataset moving_mnist  --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/
diff --git a/video_prediction_savp/env_setup/requirements.txt b/video_prediction_savp/env_setup/requirements.txt
index 4bf2f0b25d082c4c503bbd56f46d28360a48df43..173b8a10c8dec1d8186adc84c144b79863406d3f 100644
--- a/video_prediction_savp/env_setup/requirements.txt
+++ b/video_prediction_savp/env_setup/requirements.txt
@@ -1,4 +1,4 @@
-opencv-python
+opencv-python==4.2.0.34
 scipy
 scikit-image
 pandas
diff --git a/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json b/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
index c2edaad9f9ac158f6e7b8d94bb81db16d55d05e8..fde951edd2e6b41965fbdce6ce831c1e154cbd0e 100644
--- a/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
+++ b/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
@@ -1,10 +1,11 @@
 
 {
-    "batch_size": 10,
+    "batch_size": 4,
     "lr": 0.001,
-    "max_epochs":2,
+    "max_epochs":20,
     "context_frames":10,
-    "sequence_length":20
+    "sequence_length":20,
+    "loss_fun":"rmse"
 
 }
 
diff --git a/video_prediction_savp/hparams/moving_mnist/convLSTM/model_hparams.json b/video_prediction_savp/hparams/moving_mnist/convLSTM/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..b59f6cb2ee96162b2eb6014d7ca6bd37f54d4218
--- /dev/null
+++ b/video_prediction_savp/hparams/moving_mnist/convLSTM/model_hparams.json
@@ -0,0 +1,12 @@
+
+{
+    "batch_size": 10,
+    "lr": 0.001,
+    "max_epochs":20,
+    "context_frames":10,
+    "sequence_length":20,
+    "loss_fun":"cross_entropy"
+}
+
+
+
diff --git a/video_prediction_savp/scripts/generate_movingmnist.py b/video_prediction_savp/scripts/generate_movingmnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4fbf5eb5d8d8f4cad87ae26d15bc2787d9e6c0a
--- /dev/null
+++ b/video_prediction_savp/scripts/generate_movingmnist.py
@@ -0,0 +1,822 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import math
+import random
+import cv2
+import numpy as np
+import tensorflow as tf
+import pickle
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import matplotlib.animation as animation
+import pandas as pd
+import re
+from video_prediction import datasets, models
+from matplotlib.colors import LinearSegmentedColormap
+#from matplotlib.ticker import MaxNLocator
+#from video_prediction.utils.ffmpeg_gif import save_gif
+from skimage.metrics import structural_similarity as ssim
+import datetime
+# Scarlet 2020/05/28: access to statistical values in json file 
+from os import path
+import sys
+sys.path.append(path.abspath('../video_prediction/datasets/'))
+from era5_dataset_v2 import Norm_data
+from os.path import dirname
+from netCDF4 import Dataset,date2num
+from metadata import MetaData as MetaData
+
+def set_seed(seed):
+    if seed is not None:
+        tf.set_random_seed(seed)
+        np.random.seed(seed)
+        random.seed(seed) 
+
+def get_coordinates(metadata_fname):
+    """
+    Retrieves the latitudes and longitudes read from the metadata json file.
+    """
+    md = MetaData(json_file=metadata_fname)
+    md.get_metadata_from_file(metadata_fname)
+    
+    try:
+        print("lat:",md.lat)
+        print("lon:",md.lon)
+        return md.lat, md.lon
+    except Exception:
+        raise ValueError("Error when handling: '"+metadata_fname+"'")
+    
+
+def load_checkpoints_and_create_output_dirs(checkpoint,dataset,model):
+    # initialize defaults so that all returned values are defined even if the optional
+    # hparams files are missing or no checkpoint is given
+    options = {}
+    checkpoint_dir = None
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+    if checkpoint:
+        checkpoint_dir = os.path.normpath(checkpoint)
+        if not os.path.isdir(checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % checkpoint)
+            options = json.loads(f.read())
+            dataset = dataset or options['dataset']
+            model = model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+    else:
+        if not dataset:
+            raise ValueError('dataset is required when checkpoint is not specified')
+        if not model:
+            raise ValueError('model is required when checkpoint is not specified')
+
+    return options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict
+
+
+    
+def setup_dataset(dataset,input_dir,mode,seed,num_epochs,dataset_hparams,dataset_hparams_dict):
+    VideoDataset = datasets.get_dataset_class(dataset)
+    dataset = VideoDataset(
+        input_dir,
+        mode = mode,
+        num_epochs = num_epochs,
+        seed = seed,
+        hparams_dict = dataset_hparams_dict,
+        hparams = dataset_hparams)
+    return dataset
+
+
+def setup_dirs(input_dir,results_png_dir):
+    temporal_dir = os.path.split(input_dir)[0] + "/hickle/splits/"
+    print("temporal_dir:",temporal_dir)
+
+
+def update_hparams_dict(model_hparams_dict,dataset):
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    return hparams_dict
+
+
+def psnr(img1, img2):
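+    # peak signal-to-noise ratio in dB; assumes pixel values are normalized to [0, 1] (hence PIXEL_MAX = 1)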
+    mse = np.mean((img1 - img2) ** 2)
+    if mse == 0: return 100
+    PIXEL_MAX = 1
+    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
+
+
+def setup_num_samples_per_epoch(num_samples, dataset):
+    if num_samples:
+        if num_samples > dataset.num_examples_per_epoch():
+            raise ValueError('num_samples cannot be larger than the dataset')
+        num_examples_per_epoch = num_samples
+    else:
+        num_examples_per_epoch = dataset.num_examples_per_epoch()
+    #if num_examples_per_epoch % args.batch_size != 0:
+    #    raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+    return num_examples_per_epoch
+
+
+def init_save_data():
+    sample_ind = 0
+    gen_images_all = []
+    #Bing:20200410
+    persistent_images_all = []
+    input_images_all = []
+    return sample_ind, gen_images_all,persistent_images_all, input_images_all
+
+
+def write_params_to_results_dir(args,output_dir,dataset,model):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys = True, indent = 4))
+    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4))
+    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4))
+    return None
+
+def get_one_seq_and_time(input_images,i):
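+    # extract the image sequence of sample i (unlike the era5 variant, no start time is returned)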
+    assert (len(np.array(input_images).shape)==5)
+    input_images_ = input_images[i,:,:,:,:]
+    return input_images_
+
+
+def denorm_images_all_channels(input_images_):
+    # rescale pixel values from the normalized [0, 1] range back to 0-255
+    input_images_ = np.array(input_images_)
+    input_images_denorm = input_images_ * 255.0
+    return input_images_denorm
+
+def plot_seq_imgs(imgs,output_png_dir,idx,label="Ground Truth"):
+    """
+    Plot the seq images 
+    """
+
+    if len(np.array(imgs).shape) != 3: raise ValueError("img dims should be three: (seq_len,lat,lon)")
+    img_len = imgs.shape[0]
+    fig = plt.figure(figsize=(18,6))
+    gs = gridspec.GridSpec(1, img_len)
+    gs.update(wspace = 0., hspace = 0.)
+    for i in range(img_len):      
+        ax1 = plt.subplot(gs[i])
+        plt.imshow(imgs[i] ,cmap = 'jet')
+        plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+    plt.savefig(os.path.join(output_png_dir, label + "_" +   str(idx) +  ".jpg"))
+    print("images_saved")
+    plt.clf()
+ 
+
+    
+def get_persistence(ts):
+    pass
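+    # TODO: not implemented yet. A minimal sketch (an assumption, not part of this patch): a persistence
+    # forecast repeats the last observed frame over the forecast horizon, e.g. for a sequence `seq` with
+    # `context_frames` inputs and `future_length` forecast steps:
+    #   persistence_seq = np.repeat(seq[context_frames - 1][np.newaxis, ...], future_length, axis=0)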
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type = str, required = True,
+                        help = "either a directory containing subdirectories "
+                               "train, val, test, etc, or a directory containing "
+                               "the tfrecords")
+    parser.add_argument("--results_dir", type = str, default = 'results',
+                        help = "ignored if output_gif_dir is specified")
+    parser.add_argument("--checkpoint",
+                        help = "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+    parser.add_argument("--mode", type = str, choices = ['train','val', 'test'], default = 'val',
+                        help = 'mode for dataset, val or test.')
+    parser.add_argument("--dataset", type = str, help = "dataset class name")
+    parser.add_argument("--dataset_hparams", type = str,
+                        help = "a string of comma separated list of dataset hyperparameters")
+    parser.add_argument("--model", type = str, help = "model class name")
+    parser.add_argument("--model_hparams", type = str,
+                        help = "a string of comma separated list of model hyperparameters")
+    parser.add_argument("--batch_size", type = int, default = 8, help = "number of samples in batch")
+    parser.add_argument("--num_samples", type = int, help = "number of samples in total (all of them by default)")
+    parser.add_argument("--num_epochs", type = int, default = 1)
+    parser.add_argument("--num_stochastic_samples", type = int, default = 1)
+    parser.add_argument("--gif_length", type = int, help = "default is sequence_length")
+    parser.add_argument("--fps", type = int, default = 4)
+    parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use")
+    parser.add_argument("--seed", type = int, default = 7)
+    args = parser.parse_args()
+    set_seed(args.seed)
+
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+
+    options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict = load_checkpoints_and_create_output_dirs(args.checkpoint,args.dataset,args.model)
+    print("Step 1 finished")
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    #setup dataset and model object
+    input_dir_tf = os.path.join(args.input_dir, "tfrecords") # where tensorflow records are stored
+    dataset = setup_dataset(dataset,input_dir_tf,args.mode,args.seed,args.num_epochs,args.dataset_hparams,dataset_hparams_dict)
+    
+    print("Step 2 finished")
+    VideoPredictionModel = models.get_model_class(model)
+    
+    hparams_dict = update_hparams_dict(model_hparams_dict, dataset)
+    
+    model = VideoPredictionModel(
+        mode = args.mode,
+        hparams_dict = hparams_dict,
+        hparams = args.model_hparams)
+
+    sequence_length = model.hparams.sequence_length
+    context_frames = model.hparams.context_frames
+    future_length = sequence_length - context_frames #context_frames is the number of input frames
+
+    num_examples_per_epoch = setup_num_samples_per_epoch(args.num_samples,dataset)
+    
+    inputs = dataset.make_batch(args.batch_size)
+    print("inputs",inputs)
+    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
+    print("input_phs",input_phs)
+    
+    
+    # Build graph
+    with tf.variable_scope(''):
+        model.build_graph(input_phs)
+
+    #Write the updated hyperparameters into results_dir
+    write_params_to_results_dir(args=args,output_dir=args.results_dir,dataset=dataset,model=model)
+        
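+    # limit the per-process GPU memory fraction and allow TensorFlow to place ops on other
+    # devices when no GPU kernel is available (allow_soft_placement)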
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_mem_frac)
+    config = tf.ConfigProto(gpu_options = gpu_options, allow_soft_placement = True)
+    sess = tf.Session(config = config)
+    sess.graph.as_default()
+    sess.run(tf.global_variables_initializer())
+    sess.run(tf.local_variables_initializer())
+    model.restore(sess, args.checkpoint)
+    
+    #model.restore(sess, args.checkpoint)#Bing: Todo: 20200728 Let's only focus on true and persistent data
+    sample_ind, gen_images_all, persistent_images_all, input_images_all = init_save_data()
+    
+    is_first=True
+    #loop over the samples (only the first few batches are processed for visualization)
+    while sample_ind < 5:
+        gen_images_stochastic = []
+        if args.num_samples and sample_ind >= args.num_samples:
+            break
+        try:
+            input_results = sess.run(inputs)
+            input_images = input_results["images"]
+            #get the initial times
+            t_starts = input_results["T_start"]
+        except tf.errors.OutOfRangeError:
+            break
+            
+        #Get prediction values 
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel]
+        print("gen_images 20200822:",np.array(gen_images).shape)       
+        #Loop in batch size
+        for i in range(args.batch_size):
+            
+            #get one seq and the corresponding start time point
+            input_images_ = get_one_seq_and_time(input_images,i)
+            
+            #Denormalize the input images
+            input_images_denorm = denorm_images_all_channels(input_images_)  
+            print("input_images_denorm",input_images_denorm[0][0])
+                                                             
+            #Denormalize the generated images
+            gen_images_ = gen_images[i]
+            gen_images_denorm = denorm_images_all_channels(gen_images_)
+            print("gene_images_denorm:",gen_images_denorm[0][0])
+            
+            #Plot the ground truth input images
+            plot_seq_imgs(imgs=input_images_denorm[context_frames+1:,:,:,0],idx = sample_ind + i, label="Ground Truth",output_png_dir=args.results_dir)  
+                                                             
+            #Plot the forecast images generated by the model
+            plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],idx = sample_ind + i,label="Forecast by Model " + args.model,output_png_dir=args.results_dir) 
+            
+            #TODO: Scarlet plot persistence image
+            #implement get_persistence() function
+
+            #instead of generating images for all inputs, only the first 5 samples are plotted for visualization
+
+        sample_ind += args.batch_size
+
+
+        #for input_image in input_images_:
+
+#             for stochastic_sample_ind in range(args.num_stochastic_samples):
+#                 input_images_all.extend(input_images)
+#                 with open(os.path.join(args.output_png_dir, "input_images_all.pkl"), "wb") as input_files:
+#                     pickle.dump(list(input_images_all), input_files)
+
+
+#                 gen_images_stochastic.append(gen_images)
+#                 #print("Stochastic_sample,", stochastic_sample_ind)
+#                 for i in range(args.batch_size):
+#                     #bing:20200417
+#                     t_stampe = test_temporal_pkl[sample_ind+i]
+#                     print("timestamp:",type(t_stampe))
+#                     persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1)
+#                     print ("persistent ts",persistent_ts)
+#                     persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts))
+#                     persistent_X = X_test[persistent_idx:persistent_idx+context_frames + future_length]
+#                     print("persistent index in test set:", persistent_idx)
+#                     print("persistent_X.shape",persistent_X.shape)
+#                     persistent_images_all.append(persistent_X)
+
+#                     cmap_name = 'my_list'
+#                     if sample_ind < 100:
+#                         #name = '_Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
+#                         #    sample_ind) + " + Sample_" + str(i)
+#                         name = '_Stochastic_id_' + str(stochastic_sample_ind) + "_Time_"+ t_stampe[0].strftime("%Y%m%d-%H%M%S")
+#                         print ("name",name)
+#                         gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
+#                         #gen_images_ =  gen_images[i, :]
+#                         input_images_ = input_images[i, :]
+#                         #Bing:20200417
+#                         #persistent_images = ?
+#                         #+++Scarlet:20200528   
+#                         #print('Scarlet1')
+#                         input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm)
+#                         persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm)
+#                         #---Scarlet:20200528    
+#                         gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
+#                                         range(sequence_length)]  # return the list with 10 (sequence) mse
+#                         persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in
+#                                         range(sequence_length)]  # return the list with 10 (sequence) mse
+
+#                         fig = plt.figure(figsize=(18,6))
+#                         gs = gridspec.GridSpec(1, 10)
+#                         gs.update(wspace = 0., hspace = 0.)
+#                         ts = list(range(10,20)) #[10,11,12,..]
+#                         xlables = [round(i,2) for i  in list(np.linspace(np.min(lon),np.max(lon),5))]
+#                         ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
+
+#                         for t in ts:
+
+#                             #if t==0 : ax1=plt.subplot(gs[t])
+#                             ax1 = plt.subplot(gs[ts.index(t)])
+#                             #+++Scarlet:20200528
+#                             #print('Scarlet2')
+#                             input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm)
+#                             #---Scarlet:20200528
+#                             plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
+#                             ax1.title.set_text("t = " + str(t+1-10))
+#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+#                             if t == 0:
+#                                 plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
+#                                 plt.ylabel("Ground Truth", fontsize=10)
+#                         plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
+#                         plt.clf()
+
+#                         fig = plt.figure(figsize=(12,6))
+#                         gs = gridspec.GridSpec(1, 10)
+#                         gs.update(wspace = 0., hspace = 0.)
+
+#                         for t in ts:
+#                             #if t==0 : ax1=plt.subplot(gs[t])
+#                             ax1 = plt.subplot(gs[ts.index(t)])
+#                             #+++Scarlet:20200528
+#                             #print('Scarlet3')
+#                             gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm)
+#                             #---Scarlet:20200528
+#                             plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
+#                             ax1.title.set_text("t = " + str(t+1-10))
+#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+
+#                         plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
+#                         plt.clf()
+
+
+#                         fig = plt.figure(figsize=(12,6))
+#                         gs = gridspec.GridSpec(1, 10)
+#                         gs.update(wspace = 0., hspace = 0.)
+#                         for t in ts:
+#                             #if t==0 : ax1=plt.subplot(gs[t])
+#                             ax1 = plt.subplot(gs[ts.index(t)])
+#                             #persistent_image = persistent_X[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+#                             plt.imshow(persistent_X[t, :, :, 0], cmap = 'jet', vmin=270, vmax=300)
+#                             ax1.title.set_text("t = " + str(t+1-10))
+#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+
+#                         plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg"))
+#                         plt.clf()
+
+                        
+#                 with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files:
+#                     pickle.dump(list(persistent_images_all), input_files)
+#                     print ("Save persistent all")
+#                 if is_first:
+#                     gen_images_all = gen_images_stochastic
+#                     is_first = False
+#                 else:
+#                     gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1)
+
+#                 if args.num_stochastic_samples == 1:
+#                     with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files:
+#                         pickle.dump(list(gen_images_all[0]), gen_files)
+#                         print ("Save generate all")
+#                 else:
+#                     with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files:
+#                         pickle.dump(list(gen_images_stochastic), gen_files)
+#                     with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files:
+#                         pickle.dump(list(gen_images_all), gen_files)
+
+#         sample_ind += args.batch_size
+
+
+#     with open(os.path.join(args.output_png_dir, "input_images_all.pkl"),"rb") as input_files:
+#         input_images_all = pickle.load(input_files)
+
+#     with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"),"rb") as gen_files:
+#         gen_images_all = pickle.load(gen_files)
+
+#     with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files:
+#         persistent_images_all = pickle.load(gen_files)
+
+#     #+++Scarlet:20200528
+#     #print('Scarlet4')
+#     input_images_all = np.array(input_images_all)
+#     input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm)
+#     #---Scarlet:20200528
+#     persistent_images_all = np.array(persistent_images_all)
+#     if len(np.array(gen_images_all).shape) == 6:
+#         for i in range(len(gen_images_all)):
+#             #+++Scarlet:20200528
+#             #print('Scarlet5')
+#             gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:]
+#             gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm)
+#             #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+#             #---Scarlet:20200528
+#             mse_all = []
+#             psnr_all = []
+#             ssim_all = []
+#             f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w')
+#             for i in range(future_length):
+#                 mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all_stochastic[:, i + 9, :, :,
+#                                                                             0]) ** 2)  # look at all timesteps except the first
+#                 psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0])
+#                 ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0],
+#                                   data_range = max(gen_images_all_stochastic[:, i + 9, :, :, 0].flatten()) - min(
+#                                       input_images_all[:, i + 10, :, :, 0].flatten()))
+#                 mse_all.extend([mse_model])
+#                 psnr_all.extend([psnr_model])
+#                 ssim_all.extend([ssim_model])
+#                 results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
+#                 f.write("##########Predicted Frame {}\n".format(str(i + 1)))
+#                 f.write("Model MSE: %f\n" % mse_model)
+#                 # f.write("Previous Frame MSE: %f\n" % mse_prev)
+#                 f.write("Model PSNR: %f\n" % psnr_model)
+#                 f.write("Model SSIM: %f\n" % ssim_model)
+
+
+#             pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb"))
+#             # f.write("Previous frame PSNR: %f\n" % psnr_prev)
+#             f.write("Shape of X_test: " + str(input_images_all.shape))
+#             f.write("")
+#             f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape))
+
+#     else:
+#         #+++Scarlet:20200528
+#         #print('Scarlet6')
+#         gen_images_all = np.array(gen_images_all)
+#         gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm)
+#         #---Scarlet:20200528
+        
+#         # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
+#         # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)
+#         # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 )
+#         mse_all = []
+#         psnr_all = []
+#         ssim_all = []
+#         persistent_mse_all = []
+#         persistent_psnr_all = []
+#         persistent_ssim_all = []
+#         f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w')
+#         for i in range(future_length):
+#             mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :,
+#                                                                         0]) ** 2)  # look at all timesteps except the first
+#             persistent_mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - persistent_images_all[:, i + 9, :, :,
+#                                                                         0]) ** 2)  # look at all timesteps except the first
+            
+#             psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0])
+#             ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0],
+#                               data_range = max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min(
+#                                   input_images_all[:, i + 10, :, :, 0].flatten()))
+#             persistent_psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0])
+#             persistent_ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0],
+#                               data_range = max(gen_images_all[:1268, i + 9, :, :, 0].flatten()) - min(input_images_all[:1268, i + 10, :, :, 0].flatten()))
+#             mse_all.extend([mse_model])
+#             psnr_all.extend([psnr_model])
+#             ssim_all.extend([ssim_model])
+#             persistent_mse_all.extend([persistent_mse_model])
+#             persistent_psnr_all.extend([persistent_psnr_model])
+#             persistent_ssim_all.extend([persistent_ssim_model])
+#             results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
+
+#             persistent_results = {"mse": persistent_mse_all, "psnr": persistent_psnr_all, "ssim": persistent_ssim_all}
+#             f.write("##########Predicted Frame {}\n".format(str(i + 1)))
+#             f.write("Model MSE: %f\n" % mse_model)
+#             # f.write("Previous Frame MSE: %f\n" % mse_prev)
+#             f.write("Model PSNR: %f\n" % psnr_model)
+#             f.write("Model SSIM: %f\n" % ssim_model)
+
+#         pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb"))
+#         pickle.dump(persistent_results, open(os.path.join(args.output_png_dir, "persistent_results.pkl"), "wb"))
+#         # f.write("Previous frame PSNR: %f\n" % psnr_prev)
+#         f.write("Shape of X_test: " + str(input_images_all.shape))
+#         f.write("")
+#         f.write("Shape of X_hat: " + str(gen_images_all.shape)      
+
+if __name__ == '__main__':
+    main()        
+
+    #psnr_model = psnr(input_images_all[:, :10, :, :, 0],  gen_images_all[:, :10, :, :, 0])
+    #psnr_model_last = psnr(input_images_all[:, 10, :, :, 0],  gen_images_all[:,10, :, :, 0])
+    #psnr_prev = psnr(input_images_all[:, :, :, :, 0],  input_images_all[:, 1:10, :, :, 0])
+
+    # ims = []
+    # fig = plt.figure()
+    # for frame in range(20):
+    #     input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel)
+    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
+    #     #pix_std = np.std(input_gen_diff, axis=0)
+    #     im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu')
+    #     if frame == 0:
+    #         fig.colorbar(im)
+    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame +1))
+    #     ims.append([im, ttl])
+    # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000)
+    # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4"))
+    # plt.close("all")
+
+    # ims = []
+    # fig = plt.figure()
+    # for frame in range(19):
+    #     pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # Get the first prediction frame (batch,height, width, channel)
+    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
+    #     #pix_std = np.std(input_gen_diff, axis=0)
+    #     im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu')
+    #     if frame == 0:
+    #         fig.colorbar(im)
+    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame+1))
+    #     ims.append([im, ttl])
+    # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4"))
+
+    # seed(1)
+    # s = random.sample(range(len(gen_images_all)), 100)
+    # print("******KDP******")
+    # #kernel density plot for checking the model collapse
+    # fig = plt.figure()
+    # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images")
+    # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True")
+    # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability')
+    # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400)
+    # plt.clf()
+
+    #line plot for evaluating the prediction and groud-truth
+    # for i in [0,3,6,9,12,15,18]:
+    #     fig = plt.figure()
+    #     plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3)
+    #     #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3)
+    #     plt.xlabel("Prediction")
+    #     plt.ylabel("Real values")
+    #     plt.title("Frame_{}".format(i+1))
+    #     plt.plot([250,300], [250,300],color="black")
+    #     plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i))))
+    #     plt.clf()
+    #
+    # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence)
+    # x = [str(i+1) for i in list(range(19))]
+    # fig,axis = plt.subplots()
+    # mean_f = np.mean(mse_model_by_frames, axis = 0)
+    # median = np.median(mse_model_by_frames, axis=0)
+    # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0)
+    # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0)
+    # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0)
+    # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0)
+    # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range")
+    # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range")
+    # plt.plot(x, median, color="grey", linewidth=0.6, label="Median")
+    # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean")
+    # plt.title(f'MSE percentile')
+    # plt.xlabel("Frames")
+    # plt.legend(loc=2, fontsize=8)
+    # plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png"))
+
+
+##                
+##
+##                    # fig = plt.figure()
+##                    # gs = gridspec.GridSpec(4,6)
+##                    # gs.update(wspace = 0.7,hspace=0.8)
+##                    # ax1 = plt.subplot(gs[0:2,0:3])
+##                    # ax2 = plt.subplot(gs[0:2,3:],sharey=ax1)
+##                    # ax3 = plt.subplot(gs[2:4,0:3])
+##                    # ax4 = plt.subplot(gs[2:4,3:])
+##                    # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
+##                    # ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
+##                    # plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels)
+##                    # ax1.title.set_text("(a) Ground Truth")
+##                    # ax2.title.set_text("(b) SAVP")
+##                    # ax3.title.set_text("(c) Diff.")
+##                    # ax4.title.set_text("(d) MSE")
+##                    #
+##                    # ax1.xaxis.set_tick_params(labelsize=7)
+##                    # ax1.yaxis.set_tick_params(labelsize = 7)
+##                    # ax2.xaxis.set_tick_params(labelsize=7)
+##                    # ax2.yaxis.set_tick_params(labelsize = 7)
+##                    # ax3.xaxis.set_tick_params(labelsize=7)
+##                    # ax3.yaxis.set_tick_params(labelsize = 7)
+##                    #
+##                    # init_images = np.zeros((input_images_.shape[1], input_images_.shape[2]))
+##                    # print("inti images shape", init_images.shape)
+##                    # xdata, ydata = [], []
+##                    # #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
+##                    # #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
+##                    # plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+##                    # plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+##                    # #x = np.linspace(0, 64, 64)
+##                    # #y = np.linspace(0, 64, 64)
+##                    # #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+##                    # #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+##                    # fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7)
+##                    # fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7)
+##                    #
+##                    # cm = LinearSegmentedColormap.from_list(
+##                    #     cmap_name, "bwr", N = 5)
+##                    #
+##                    # plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r',
+##                    # #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm)  # cmap = 'PuBu_r',
+##                    # plot4, = ax4.plot([], [], color = "r")
+##                    # ax4.set_xlim(0, future_length-1)
+##                    # ax4.set_ylim(0, 20)
+##                    # #ax4.set_ylim(0, 0.5)
+##                    # ax4.set_xlabel("Frames", fontsize=10)
+##                    # #ax4.set_ylabel("MSE", fontsize=10)
+##                    # ax4.xaxis.set_tick_params(labelsize=7)
+##                    # ax4.yaxis.set_tick_params(labelsize=7)
+##                    #
+##                    #
+##                    # plots = [plot1, plot2, plot3, plot4]
+##                    #
+##                    # #fig.colorbar(plots[1], ax = [ax1, ax2])
+##                    #
+##                    # fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7)
+##                    # #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7)
+##                    # #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7)
+##                    #
+##                    # def animation_sample(t):
+##                    #     input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
+##                    #     gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
+##                    #     diff_image = input_gen_diff[t,:,:]
+##                    #     # p = sns.lineplot(x=x,y=data,color="b")
+##                    #     # p.tick_params(labelsize=17)
+##                    #     # plt.setp(p.lines, linewidth=6)
+##                    #     plots[0].set_data(input_image)
+##                    #     plots[1].set_data(gen_image)
+##                    #     #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+##                    #     #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+##                    #     plots[2].set_data(diff_image)
+##                    #
+##                    #     if t >= future_length:
+##                    #         #data = gen_mse_avg_[:t + 1]
+##                    #         # x = list(range(len(gen_mse_avg_)))[:t+1]
+##                    #         xdata.append(t-future_length)
+##                    #         print("xdata", xdata)
+##                    #         ydata.append(gen_mse_avg_[t])
+##                    #         print("ydata", ydata)
+##                    #         plots[3].set_data(xdata, ydata)
+##                    #         fig.suptitle("Predicted Frame " + str(t-future_length))
+##                    #     else:
+##                    #         #plots[3].set_data(xdata, ydata)
+##                    #         fig.suptitle("Context Frame " + str(t))
+##                    #     return plots
+##                    #
+##                    # ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000,
+##                    #                               repeat_delay=2000)
+##                    # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
+##
+####                else:
+####                    pass
+##
+
+
+
+
+
+    #         # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     plt.xlim(0,len(gen_mse_avg_))
+    #         #     plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg))
+    #         #     plt.xlabel("Frames")
+    #         #     plt.ylabel("MSE_AVG")
+    #         #     #X = list(range(len(gen_mse_avg_)))
+    #         #     #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
+    #         #     def animate_metric(j):
+    #         #         data = gen_mse_avg_[:(j+1)]
+    #         #         x = list(range(len(gen_mse_avg_)))[:(j+1)]
+    #         #         p = sns.lineplot(x=x,y=data,color="b")
+    #         #         p.tick_params(labelsize=17)
+    #         #         plt.setp(p.lines, linewidth=6)
+    #         #     ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000)
+    #         #     ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif"))
+    #         #
+    #         #
+    #         # for i, input_images_ in enumerate(input_images):
+    #         #     #context_images_ = (input_results['images'][i])
+    #         #     #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     for t, input_image in enumerate(input_images_):
+    #         #         im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')
+    #         #         ttl = plt.text(1.5, 2,"Frame_" + str(t))
+    #         #         ims.append([im,ttl])
+    #         #     ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000)
+    #         #     ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif"))
+    #         #     #plt.show()
+    #         #
+    #         # for i,gen_images_ in enumerate(gen_images):
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     for t, gen_image in enumerate(gen_images_):
+    #         #         im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
+    #         #         ttl = plt.text(1.5, 2, "Frame_" + str(t))
+    #         #         ims.append([im, ttl])
+    #         #     ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    #         #     ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif"))
+    #
+    #
+    #             # for i, gen_images_ in enumerate(gen_images):
+    #             #     #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
+    #             #     #gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
+    #             #     #bing
+    #             #     context_images_ = (input_results['images'][i])
+    #             #     gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+    #             #     context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
+    #             #     plt.figure(figsize = (10,2))
+    #             #     gs = gridspec.GridSpec(2,10)
+    #             #     gs.update(wspace=0.,hspace=0.)
+    #             #     for t, gen_image in enumerate(gen_images_):
+    #             #         gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1)))
+    #             #         gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
+    #             #         plt.subplot(gs[t])
+    #             #         plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')  # the last index sets the channel. 0 = t2
+    #             #         # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
+    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
+    #             #                         right = False, labelbottom = False, labelleft = False)
+    #             #         if t == 0: plt.ylabel('Actual', fontsize = 10)
+    #             #
+    #             #         plt.subplot(gs[t + 10])
+    #             #         plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
+    #             #         # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
+    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
+    #             #                         right = False, labelbottom = False, labelleft = False)
+    #             #         if t == 0: plt.ylabel('Predicted', fontsize = 10)
+    #             #     plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png')
+    #             #     plt.clf()
+    #
+    #             # if args.gif_length:
+    #             #     context_and_gen_images = context_and_gen_images[:args.gif_length]
+    #             # save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
+    #             #          context_and_gen_images, fps=args.fps)
+    #             #
+    #             # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
+    #             # for t, gen_image in enumerate(gen_images_):
+    #             #     gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
+    #             #     if gen_image.shape[-1] == 1:
+    #             #       gen_image = np.tile(gen_image, (1, 1, 3))
+    #             #     else:
+    #             #       gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
+    #             #     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
index 9a62f2f29346ae1edda8b50315c0e81daa9b2c11..7948d650cc270c9ee33dcd32c2e03c70f9216225 100644
--- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
+++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
@@ -201,11 +201,6 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t
     ts_len = len(ts)
     ts_input = ts[:context_frames]
     ts_forecast = ts[context_frames:]
-    #print("context_frame:",context_frames)
-    #print("future_frame",future_length)
-    #print("length of ts input:",len(ts_input))
-
-    print("input_images_ shape in netcdf,",input_images_.shape)
     gen_images_ = np.array(gen_images_)
 
     output_file = os.path.join(output_dir,fl_name)
@@ -288,18 +283,19 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t
         #Temperature:
         t2 = nc_file.createVariable("/forecast/{}/T2".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True)
         t2.units = 'K'
-        t2[:,:,:] = gen_images_[context_frames:,:,:,0]
+        print ("gen_images_ 20200822:",np.array(gen_images_).shape)
+        t2[:,:,:] = gen_images_[context_frames-1:,:,:,0]
         print("NetCDF created")
 
         #mean sea level pressure
         msl = nc_file.createVariable("/forecast/{}/MSL".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True)
         msl.units = 'Pa'
-        msl[:,:,:] = gen_images_[context_frames:,:,:,1]
+        msl[:,:,:] = gen_images_[context_frames-1:,:,:,1]
 
         #Geopotential at 500 
         gph500 = nc_file.createVariable("/forecast/{}/GPH500".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True)
         gph500.units = 'm'
-        gph500[:,:,:] = gen_images_[context_frames:,:,:,2]        
+        gph500[:,:,:] = gen_images_[context_frames-1:,:,:,2]        
 
         print("{} created".format(output_file)) 
 
@@ -450,7 +446,8 @@ def main():
         #Get prediction values 
         feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
         gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel]
-        
+        assert gen_images.shape[1] == sequence_length-1 #The generated images' seq_len should be sequence_length - 1, since the last one is not used for comparison with the ground truth
+        print("gen_images 20200822:",np.array(gen_images).shape)       
         #Loop in batch size
         for i in range(args.batch_size):
             
@@ -458,26 +455,26 @@ def main():
             input_images_,t_start = get_one_seq_and_time(input_images,t_starts,i)
             #generate time stamps for sequences
             ts = generate_seq_timestamps(t_start,len_seq=sequence_length)
-            
+             
             #Renormalized data for inputs
             stat_fl = os.path.join(args.input_dir,"pickle/statistics.json")
             input_images_denorm = denorm_images_all_channels(stat_fl,input_images_,["T2","MSL","gph500"])  
-            print("input_images_denorm",input_images_denorm[0][0])
+            print("input_images_denorm shape",np.array(input_images_denorm).shape)
                                                              
             #Renormalized data for inputs
             gen_images_ = gen_images[i]
             gen_images_denorm = denorm_images_all_channels(stat_fl,gen_images_,["T2","MSL","gph500"])
-            print("gene_images_denorm:",gen_images_denorm[0][0])
+            print("gene_images_denorm shape",np.array(gen_images_denorm).shape)
             
             #Save input to netCDF file
             init_date_str = ts[0].strftime("%Y%m%d%H")
             save_to_netcdf_per_sequence(args.results_dir,input_images_denorm,gen_images_denorm,lons,lats,ts,context_frames,future_length,args.model,fl_name="vfp_{}.nc".format(init_date_str))
                                                              
             #Generate images inputs
-            plot_seq_imgs(imgs=input_images_denorm[:context_frames-1,:,:,0],lats=lats,lons=lons,ts=ts[:context_frames-1],label="Ground Truth",output_png_dir=args.results_dir)  
+            plot_seq_imgs(imgs=input_images_denorm[context_frames+1:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames+1:],label="Ground Truth",output_png_dir=args.results_dir)  
                                                              
             #Generate forecast images
-            plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) 
+            plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames+1:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) 
             
             #TODO: Scaret plot persistence image
             #implment get_persistence() function
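
The index shift used above (writing gen_images_[context_frames-1:] to the forecast variables and plotting from the later frames onwards) follows from the model returning sequence_length - 1 generated frames. A minimal NumPy sketch of that alignment, assuming sequence_length=20 and context_frames=10 (illustrative values only):

    import numpy as np

    sequence_length, context_frames = 20, 10            # assumed values for illustration
    future_length = sequence_length - context_frames    # number of forecast frames

    # the model returns sequence_length - 1 generated frames (see the assert above);
    # generated frame k corresponds to ground-truth frame k + 1
    gen_images_ = np.zeros((sequence_length - 1, 64, 64, 3))

    forecast_frames = gen_images_[context_frames - 1:]  # frames written to the "time_forecast" dimension
    assert forecast_frames.shape[0] == future_length
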
diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/train_dummy.py
index d4db02c3804c445ee434d2fbd4de01ae4ae2dd29..0417a36514fd6136fb9fbe934bfb396633fa6093 100644
--- a/video_prediction_savp/scripts/train_dummy.py
+++ b/video_prediction_savp/scripts/train_dummy.py
@@ -16,13 +16,6 @@ from json import JSONEncoder
 import pickle as pkl
 
 
-class NumpyArrayEncoder(JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, np.ndarray):
-            return obj.tolist()
-        return JSONEncoder.default(self, obj)
-
-
 def add_tag_suffix(summary, tag_suffix):
     summary_proto = tf.Summary()
     summary_proto.ParseFromString(summary)
@@ -80,7 +73,6 @@ def set_seed(seed):
         random.seed(seed)
 
 def load_params_from_checkpoints_dir(model_hparams_dict,checkpoint,dataset,model):
-   
     model_hparams_dict_load = {}
     if model_hparams_dict:
         with open(model_hparams_dict) as f:
@@ -159,8 +151,19 @@ def make_dataset_iterator(train_dataset, val_dataset, batch_size ):
     return inputs,train_handle, val_handle
 
 
-def plot_train(train_losses,val_losses,output_dir):
-    iterations = list(range(len(train_losses))) 
+def plot_train(train_losses,val_losses,step,output_dir):
+    """
+    Function to plot training losses for train and val datasets against steps
+    params:
+    train_losses/val_losses (list): training/validation losses; their length should equal the number of training steps so far (step + 1)
+    step (int): current training step
+    output_dir (str): the path to save the plot
+    
+    return: None
+    """
+   
+    iterations = list(range(len(train_losses)))
+    if len(train_losses) != len(val_losses) or len(train_losses) != step + 1: raise ValueError("The length of train_losses must equal the length of val_losses and step + 1!")
     plt.plot(iterations, train_losses, 'g', label='Training loss')
     plt.plot(iterations, val_losses, 'b', label='validation loss')
     plt.title('Training and Validation loss')
@@ -168,6 +171,8 @@ def plot_train(train_losses,val_losses,output_dir):
     plt.ylabel('Loss')
     plt.legend()
     plt.savefig(os.path.join(output_dir,'plot_train.png'))
+    plt.close()
+    return None
 
 def save_results_to_dict(results_dict,output_dir):
     with open(os.path.join(output_dir,"results.json"),"w") as fp:
@@ -232,7 +237,9 @@ def main():
     inputs, train_handle, val_handle = make_dataset_iterator(train_dataset, val_dataset, batch_size)
     
     #build model graph
-    del inputs["T_start"]
+    #since the era5 tfrecords include T_start, we need to remove it from the inputs when training the model, otherwise the model will raise an error
+    if args.dataset == "era5":
+       del inputs["T_start"]
     model.build_graph(inputs)
     
     #save all the model, data params to output dirctory
@@ -255,6 +262,7 @@ def main():
     num_examples_per_epoch = train_dataset.num_examples_per_epoch()
     print ("number of exmaples per epoch:",num_examples_per_epoch)
     steps_per_epoch = int(num_examples_per_epoch/batch_size)
+    #the total number of steps equals the number of steps per epoch multiplied by the number of epochs
     total_steps = steps_per_epoch * max_epochs
     global_step = tf.train.get_or_create_global_step()
     #mock total_steps only for fast debugging
@@ -276,17 +284,18 @@ def main():
         val_losses=[]
         run_start_time = time.time()        
         for step in range(start_step,total_steps):
-            #global_step = sess.run(global_step):q
- 
+            #global_step = sess.run(global_step)
+            # +++ Scarlet 20200813
+            timeit_start = time.time()  
+            # --- Scarlet 20200813
             print ("step:", step)
             val_handle_eval = sess.run(val_handle)
-
             #Fetch variables in the graph
-
             fetches = {"train_op": model.train_op}
             #fetches["latent_loss"] = model.latent_loss
             fetches["summary"] = model.summary_op 
-            
+            fetches["global_step"] = model.global_step
+
             if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel" or model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
                 fetches["global_step"] = model.global_step
                 fetches["total_loss"] = model.total_loss
@@ -322,8 +331,8 @@ def main():
             val_results = sess.run(val_fetches,feed_dict={train_handle: val_handle_eval})
             val_losses.append(val_results["total_loss"])
 
-            summary_writer.add_summary(results["summary"])
-            summary_writer.add_summary(val_results["summary"])
+            summary_writer.add_summary(results["summary"],results["global_step"])
+            summary_writer.add_summary(val_results["summary"],results["global_step"])
             summary_writer.flush()
 
             # global_step will have the correct step count if we resume from a checkpoint
@@ -342,16 +351,30 @@ def main():
                 print ("The model name does not exist")
 
             #print("saving model to", args.output_dir)
-            saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)#
+
+            saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)
+            # +++ Scarlet 20200813
+            timeit_end = time.time()  
+            # --- Scarlet 20200813
+            print("time needed for this step", timeit_end - timeit_start, ' s')
+            if step % 20 == 0:
+                # save the pickle file and the plot inside the loop, in case the training process does not finish before the job ends.
+                save_results_to_pkl(train_losses,val_losses,args.output_dir)
+                plot_train(train_losses,val_losses,step,args.output_dir)
+                                
+
         train_time = time.time() - run_start_time
         results_dict = {"train_time":train_time,
                         "total_steps":total_steps}
         save_results_to_dict(results_dict,args.output_dir)
-        save_results_to_pkl(train_losses, val_losses, args.output_dir)
+        #save_results_to_pkl(train_losses, val_losses, args.output_dir)
         print("train_losses:",train_losses)
         print("val_losses:",val_losses) 
-        plot_train(train_losses,val_losses,args.output_dir)
+        #plot_train(train_losses,val_losses,args.output_dir)
         print("Done")
+        # +++ Scarlet 20200814
+        print("Total training time:", train_time/60., "min")
+        # --- Scarlet 20200814
         
 if __name__ == '__main__':
     main()
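
Passing the global step to add_summary, as done above, pins each scalar to the optimizer step on the TensorBoard x-axis, which matters when training resumes from a checkpoint. A minimal sketch of the pattern, assuming the TF 1.x API used throughout this repository; the log directory is a placeholder:

    import tensorflow as tf  # TF 1.x

    loss_ph = tf.placeholder(tf.float32, shape=(), name="loss")
    tf.summary.scalar("total_loss", loss_ph)
    summary_op = tf.summary.merge_all()
    global_step = tf.train.get_or_create_global_step()
    increment_step = tf.assign_add(global_step, 1)

    writer = tf.summary.FileWriter("/tmp/tb_demo")  # placeholder log directory
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(5):
            summary, gs, _ = sess.run([summary_op, global_step, increment_step],
                                      feed_dict={loss_ph: 1.0 / (step + 1)})
            writer.add_summary(summary, gs)  # the second argument sets the step on the x-axis
    writer.close()
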
diff --git a/video_prediction_savp/video_prediction/datasets/__init__.py b/video_prediction_savp/video_prediction/datasets/__init__.py
index e449a65bd48b14ef5a11e6846ce4d8f39f7ed193..f58607c2f4c14047aefb36956e98bd228a30aeb1 100644
--- a/video_prediction_savp/video_prediction/datasets/__init__.py
+++ b/video_prediction_savp/video_prediction/datasets/__init__.py
@@ -7,6 +7,7 @@ from .kth_dataset import KTHVideoDataset
 from .ucf101_dataset import UCF101VideoDataset
 from .cartgripper_dataset import CartgripperVideoDataset
 from .era5_dataset_v2 import ERA5Dataset_v2
+from .moving_mnist import MovingMnist
 #from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly
 
 def get_dataset_class(dataset):
@@ -19,6 +20,7 @@ def get_dataset_class(dataset):
         'ucf101': 'UCF101VideoDataset',
         'cartgripper': 'CartgripperVideoDataset',
         "era5":"ERA5Dataset_v2",
+        "moving_mnist":"MovingMnist"
 #        "era5_anomaly":"ERA5Dataset_v2_anomaly",
     }
     dataset_class = dataset_mappings.get(dataset, dataset)
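
With the new mapping entry, the Moving MNIST dataset resolves through the same registry as the other datasets. An illustrative lookup (no instantiation, since the constructor arguments depend on the base dataset class):

    # illustrative lookup through the extended dataset registry
    from video_prediction.datasets import get_dataset_class

    DatasetClass = get_dataset_class("moving_mnist")
    print(DatasetClass.__name__)   # -> MovingMnist
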
diff --git a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
index 04d11b6357cd02fff9f822b7c9105eccd5d5b46b..7a61aa090f9e115ffd140a54dd0784dbbd35c48d 100644
--- a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
+++ b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
@@ -196,46 +196,56 @@ def read_frames_and_save_tf_records(stats,output_dir,input_file, temp_input_file
     #sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
     #Bing 2020/07/16
     #print ("open intput dir,",input_file)
-    with open(input_file, "rb") as data_file:
-        X_train = pickle.load(data_file)
-    with open(temp_input_file,"rb") as temp_file:
-        T_train = pickle.load(temp_file)
-    #print("T_train:",T_train) 
-    #check to make sure the X_train and T_train has the same length 
-    assert (len(X_train) == len(T_train))
+    try:
+        with open(input_file, "rb") as data_file:
+            X_train = pickle.load(data_file)
+        with open(temp_input_file,"rb") as temp_file:
+            T_train = pickle.load(temp_file)
+            
+        #print("T_train:",T_train) 
+        #check that X_train and T_train have the same length
+        assert (len(X_train) == len(T_train))
+
+        X_possible_starts = [i for i in range(len(X_train) - seq_length)]
+        for X_start in X_possible_starts:
+            X_end = X_start + seq_length
+            #seq = X_train[X_start:X_end, :, :,:]
+            seq = X_train[X_start:X_end,:,:,:]
+            #Record the start point of the timestamps
+            T_start = T_train[X_start]
+            #print("T_start:",T_start)  
+            seq = list(np.array(seq).reshape((seq_length, height, width, nvars)))
+            if not sequences:
+                last_start_sequence_iter = sequence_iter
+
+
+            sequences.append(seq)
+            T_start_points.append(T_start)
+            sequence_iter += 1    
+
+            if len(sequences) == sequences_per_file:
+                ###Normalization should adopt the selected variables; here we used duplicated temperature channel variables
+                sequences = np.array(sequences)
+                ### normalization
+                for i in range(nvars):    
+                    sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm)
+
+                output_fname = 'sequence_Y_{}_M_{}_{}_to_{}.tfrecords'.format(year,month,last_start_sequence_iter,sequence_iter - 1)
+                output_fname = os.path.join(output_dir, output_fname)
+                print("T_start_points:",T_start_points)
+                if os.path.isfile(output_fname):
+                    print(output_fname, ' already exists, skip it')
+                else:
+                    save_tf_record(output_fname, list(sequences), T_start_points)
+                T_start_points = []
+                sequences = []
+        print("Finished for input file",input_file)
+        #sequence_lengths_file.close()
     
-    X_possible_starts = [i for i in range(len(X_train) - seq_length)]
-    for X_start in X_possible_starts:
-        X_end = X_start + seq_length
-        #seq = X_train[X_start:X_end, :, :,:]
-        seq = X_train[X_start:X_end,:,:,:]
-        #Recored the start point of the timestamps
-        T_start = T_train[X_start]
-        #print("T_start:",T_start)  
-        seq = list(np.array(seq).reshape((seq_length, height, width, nvars)))
-        if not sequences:
-            last_start_sequence_iter = sequence_iter
-           
-       
-        sequences.append(seq)
-        T_start_points.append(T_start)
-        sequence_iter += 1    
-        
-        if len(sequences) == sequences_per_file:
-            ###Normalization should adpot the selected variables, here we used duplicated channel temperature variables
-            sequences = np.array(sequences)
-            ### normalization
-            for i in range(nvars):    
-                sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm)
-
-            output_fname = 'sequence_Y_{}_M_{}_{}_to_{}.tfrecords'.format(year,month,last_start_sequence_iter,sequence_iter - 1)
-            output_fname = os.path.join(output_dir, output_fname)
-            print("T_start_points:",T_start_points)
-            save_tf_record(output_fname, list(sequences), T_start_points)
-            T_start_points = []
-            sequences = []
-    print("Finished for input file",input_file)
-    #sequence_lengths_file.close()
+    except FileNotFoundError as fnf_error:
+        print(fnf_error)
+        pass
+
     return 
 
 def write_sequence_file(output_dir,seq_length,sequences_per_file):
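
The loop above builds overlapping sequences by sliding a window of length seq_length over the monthly array, one sequence per possible start index. A minimal NumPy sketch of that extraction with toy sizes (the real data comes from the pickled X_train):

    import numpy as np

    # toy stand-ins for the pickled monthly data
    seq_length, height, width, nvars = 20, 4, 4, 3
    X_train = np.random.rand(50, height, width, nvars)

    sequences = []
    for X_start in range(len(X_train) - seq_length):
        sequences.append(X_train[X_start:X_start + seq_length])  # one overlapping sequence per start index

    sequences = np.array(sequences)
    print(sequences.shape)   # (30, 20, 4, 4, 3): every valid start index yields one sequence
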
@@ -274,18 +284,18 @@ def main():
     ############################################################
     partition = {
             "train":{
-               # "2222":[1,2,3,5,6,7,8,9,10,11,12],
-               # "2010_1":[1,2,3,4,5,6,7,8,9,10,11,12],
-               # "2012":[1,2,3,4,5,6,7,8,9,10,11,12],
-               # "2013_complete":[1,2,3,4,5,6,7,8,9,10,11,12],
-               # "2015":[1,2,3,4,5,6,7,8,9,10,11,12],
-                "2017_test":[1,2,3,4,5,6,7,8,9,10]
+           #     "2222":[1,2,3,5,6,7,8,9,10,11,12], # issue: month 04 is missing
+                "2010":[1,2,3,4,5,6,7,8,9,10,11,12],
+           #     "2012":[1,2,3,4,5,6,7,8,9,10,11,12],
+                "2013":[1,2,3,4,5,6,7,8,9,10,11,12],
+                "2015":[1,2,3,4,5,6,7,8,9,10,11,12],
+                "2019":[1,2,3,4,5,6,7,8,9,10,11,12]
                  },
             "val":
-                {"2017_test":[11]
+                {"2017":[1,2,3,4,5,6,7,8,9,10,11,12]
                  },
             "test":
-                {"2017_test":[12]
+                {"2016":[1,2,3,4,5,6,7,8,9,10,11,12]
                  }
             }
     
diff --git a/video_prediction_savp/video_prediction/datasets/kth_dataset.py b/video_prediction_savp/video_prediction/datasets/kth_dataset.py
index 40fb6bf57b8219fc7e75c3759df9e6b38fffeb30..e1e11d51968e706868fd89f26faa25d1999d3a9b 100644
--- a/video_prediction_savp/video_prediction/datasets/kth_dataset.py
+++ b/video_prediction_savp/video_prediction/datasets/kth_dataset.py
@@ -136,11 +136,11 @@ def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequence
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--input_dir", type=str, help="directory containing the processed directories "
+    parser.add_argument("input_dir", type=str, help="directory containing the processed directories "
                                                     "boxing, handclapping, handwaving, "
                                                     "jogging, running, walking")
-    parser.add_argument("--output_dir", type=str)
-    parser.add_argument("--image_size", type=int)
+    parser.add_argument("output_dir", type=str)
+    parser.add_argument("image_size", type=int)
     args = parser.parse_args()
 
     partition_names = ['train', 'val', 'test']
diff --git a/video_prediction_savp/video_prediction/datasets/moving_mnist.py b/video_prediction_savp/video_prediction/datasets/moving_mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..12a517ac9540dd443866d47ea689803abe9e9a4d
--- /dev/null
+++ b/video_prediction_savp/video_prediction/datasets/moving_mnist.py
@@ -0,0 +1,241 @@
+import argparse
+import glob
+import itertools
+import os
+import pickle
+import random
+import re
+import numpy as np
+import json
+import tensorflow as tf
+from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset
+# ML 2020/04/14: hack for getting functions of process_netCDF_v2:
+from os import path
+import sys
+sys.path.append(path.abspath('../../workflow_parallel_frame_prediction/'))
+import DataPreprocess.process_netCDF_v2 
+from DataPreprocess.process_netCDF_v2 import get_unique_vars
+from DataPreprocess.process_netCDF_v2 import Calc_data_stat
+from metadata import MetaData
+#from base_dataset import VarLenFeatureVideoDataset
+from collections import OrderedDict
+from tensorflow.contrib.training import HParams
+from mpi4py import MPI
+import glob
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+
+class MovingMnist(VarLenFeatureVideoDataset):
+    def __init__(self, *args, **kwargs):
+        super(MovingMnist, self).__init__(*args, **kwargs)
+        from google.protobuf.json_format import MessageToDict
+        example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
+        dict_message = MessageToDict(tf.train.Example.FromString(example))
+        feature = dict_message['features']['feature']
+        print("features in dataset:",feature.keys())
+        self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
+        self.image_shape = self.video_shape[1:]
+        self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(MovingMnist, self).get_default_hparams_dict()
+        hparams = dict(
+            context_frames=10,#Bing: TODO original is 10
+            sequence_length=20,#Bing: TODO original is 20
+            shuffle_on_val=True, 
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+
+    @property
+    def jpeg_encoding(self):
+        return False
+
+
+
+    def num_examples_per_epoch(self):
+        with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
+            sequence_lengths = sequence_lengths_file.readlines()
+        sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
+        return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
+
+    def filter(self, serialized_example):
+        return tf.convert_to_tensor(True)
+
+
+    def make_dataset_v2(self, batch_size):
+        def parser(serialized_example):
+            seqs = OrderedDict()
+            keys_to_features = {
+                'width': tf.FixedLenFeature([], tf.int64),
+                'height': tf.FixedLenFeature([], tf.int64),
+                'sequence_length': tf.FixedLenFeature([], tf.int64),
+                'channels': tf.FixedLenFeature([], tf.int64),
+                'images/encoded': tf.VarLenFeature(tf.float32)
+            }
+            
+            # for i in range(20):
+            #     keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string)
+            parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
+            print ("Parse features", parsed_features)
+            seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
+            #width = tf.sparse_tensor_to_dense(parsed_features["width"])
+           # height = tf.sparse_tensor_to_dense(parsed_features["height"])
+           # channels  = tf.sparse_tensor_to_dense(parsed_features["channels"])
+           # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"])
+            images = []
+            print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]))
+            images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new")
+            seqs["images"] = images
+            return seqs
+        filenames = self.filenames
+        print ("FILENAMES",filenames)
+        #TODO:
+        #temporal_filenames = self.temporal_filenames
+        shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
+        if shuffle:
+            random.shuffle(filenames)
+        dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024)  # todo: what is buffer_size
+        print("files", self.filenames)
+        print("mode", self.mode)
+        dataset = dataset.filter(self.filter)
+        if shuffle:
+            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs))
+        else:
+            dataset = dataset.repeat(self.num_epochs)
+
+        num_parallel_calls = None if shuffle else 1
+        dataset = dataset.apply(tf.contrib.data.map_and_batch(
+            parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))
+        #dataset = dataset.map(parser)
+        # num_parallel_calls = None if shuffle else 1  # for reproducibility (e.g. sampled subclips from the test set)
+        # dataset = dataset.apply(tf.contrib.data.map_and_batch(
+        #    _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) # Bing: parallel data mapping; num_parallel_calls depends on the hardware and should normally equal the number of usable CPUs
+        dataset = dataset.prefetch(batch_size)  # Bing: prefetch the data into a buffer in order to reduce the waiting time for the GPU
+        return dataset
+
+
+
+    def make_batch(self, batch_size):
+        dataset = self.make_dataset_v2(batch_size)
+        iterator = dataset.make_one_shot_iterator()
+        return iterator.get_next()
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _bytes_list_feature(values):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
+
+def _floats_feature(value):
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+def save_tf_record(output_fname, sequences):
+    with tf.python_io.TFRecordWriter(output_fname) as writer:
+        for i in range(sequences.shape[1]): #iterate over the samples in this batch of sequences
+            sequence = sequences[:,i,:,:,:]
+            num_frames = len(sequence)
+            height, width = sequence[0,:,:,0].shape
+            encoded_sequence = np.array([list(image) for image in sequence])
+            features = tf.train.Features(feature={
+                'sequence_length': _int64_feature(num_frames),
+                'height': _int64_feature(height),
+                'width': _int64_feature(width),
+                'channels': _int64_feature(1),
+                'images/encoded': _floats_feature(encoded_sequence.flatten()),
+            })
+            example = tf.train.Example(features=features)
+            writer.write(example.SerializeToString())
+
+def read_frames_and_save_tf_records(output_dir,dat_npz, seq_length=20, sequences_per_file=128, height=64, width=64):#Bing: original 128
+    """
+    Read the moving_mnist data, which is in npz format, and save it to tfrecord files
+    The shape of dat_npz is [seq_length,number_samples,height,width]
+    moving_mnist has only one channel
+
+    """
+    os.makedirs(output_dir,exist_ok=True)
+    idx = 0
+    num_samples = dat_npz.shape[1]
+    dat_npz = np.expand_dims(dat_npz, axis=4) #add one dim to represent the channel, giving [seq_length,num_samples,height,width,channel]
+    print("data_npz_shape",dat_npz.shape)
+    dat_npz = dat_npz.astype(np.float32)
+    dat_npz /= 255.0 #normalize the pixel values by dividing by the maximum grey value (255)
+    while idx < num_samples - sequences_per_file:
+        sequences = dat_npz[:,idx:idx+sequences_per_file,:,:,:]
+        output_fname = 'sequence_{}_{}.tfrecords'.format(idx,idx+sequences_per_file)
+        output_fname = os.path.join(output_dir, output_fname)
+        save_tf_record(output_fname, sequences)
+        idx = idx + sequences_per_file
+    return None
+
+
+def write_sequence_file(output_dir,seq_length,sequences_per_file):    
+    partition_names = ["train","val","test"]
+    for partition_name in partition_names:
+        save_output_dir = os.path.join(output_dir,partition_name)
+        tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords"))
+        print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter))
+        sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w')
+        for i in range(tfCounter*sequences_per_file):
+            sequence_lengths_file.write("%d\n" % seq_length)
+        sequence_lengths_file.close()
+
+
+def plot_seq_imgs(imgs,output_png_dir,idx,label="Ground Truth"):
+    """
+    Plot the seq images 
+    """
+
+    if len(np.array(imgs).shape)!=3: raise ValueError("img dims should be three: (seq_len,lat,lon)")
+    img_len = imgs.shape[0]
+    fig = plt.figure(figsize=(18,6))
+    gs = gridspec.GridSpec(1, 10)
+    gs.update(wspace = 0., hspace = 0.)
+    for i in range(img_len):
+        ax1 = plt.subplot(gs[i])
+        plt.imshow(imgs[i] ,cmap = 'jet')
+        plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+    plt.savefig(os.path.join(output_png_dir, label + "_" +   str(idx) +  ".jpg"))
+    print("images_saved")
+    plt.clf()
+
+
+    
+    
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_dir", type=str, help="directory containing the moving_mnist data file mnist_test_seq.npy")
+    parser.add_argument("output_dir", type=str)
+    parser.add_argument("-sequences_per_file",type=int,default=2)
+    args = parser.parse_args()
+    current_path = os.getcwd()
+    data = np.load(os.path.join(args.input_dir,"mnist_test_seq.npy"))
+    print("data in mnist_test_seq shape",data.shape)
+    seq_length =  data.shape[0]
+    height = data.shape[2]
+    width = data.shape[3]
+    num_samples = data.shape[1] 
+    max_npz = np.max(data)
+    min_npz = np.min(data)
+    print("max_npz,",max_npz)
+    print("min_npz",min_npz)
+    #TODO: discuss how to split the data; we have 10000 samples in total, and the original ConvLSTM paper used 10000 for training, 2000 for validation and 3000 for testing
+    dat_train = data[:,:6000,:,:]
+    dat_val = data[:,6000:7000,:,:]
+    dat_test = data[:,7000:,:,:]
+    #plot_seq_imgs(dat_test[10:,0,:,:],output_png_dir="/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist/convLSTM",idx=1,label="Ground Truth from npz")
+    #save train
+    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"train"),dat_train, seq_length=20, sequences_per_file=40, height=height, width=width)
+    #save val
+    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"val"),dat_val, seq_length=20, sequences_per_file=40, height=height, width=width)
+    #save test     
+    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"test"),dat_test, seq_length=20, sequences_per_file=40, height=height, width=width)
+    #write_sequence_file(output_dir=args.output_dir,seq_length=20,sequences_per_file=40)
+if __name__ == '__main__':
+    main()
+
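
The frames are stored as a flat float list together with the sequence_length/height/width/channels features, which is exactly what the parser in MovingMnist.make_dataset_v2 reshapes back into a sequence. A minimal round-trip sketch with toy sizes, assuming the TF 1.x API used in this module:

    import numpy as np
    import tensorflow as tf  # TF 1.x

    seq_len, height, width = 4, 8, 8
    frames = np.random.rand(seq_len, height, width, 1).astype(np.float32)

    example = tf.train.Example(features=tf.train.Features(feature={
        "sequence_length": tf.train.Feature(int64_list=tf.train.Int64List(value=[seq_len])),
        "height": tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
        "width": tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
        "channels": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
        "images/encoded": tf.train.Feature(float_list=tf.train.FloatList(value=frames.flatten())),
    }))

    # parsing mirrors keys_to_features in MovingMnist.make_dataset_v2
    parsed = tf.parse_single_example(example.SerializeToString(),
                                     {"images/encoded": tf.VarLenFeature(tf.float32)})
    images = tf.reshape(tf.sparse_tensor_to_dense(parsed["images/encoded"]),
                        [seq_len, height, width, 1])

    with tf.Session() as sess:
        restored = sess.run(images)
    print(np.allclose(restored, frames))   # True: the flat float list reshapes back to the sequence
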
diff --git a/video_prediction_savp/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction_savp/video_prediction/layers/BasicConvLSTMCell.py
index 321f6cc7e05320cf83e1173d8004429edf07ec24..c4a095dc8fc3abdbd87c1eaf79adcd7dad99020b 100644
--- a/video_prediction_savp/video_prediction/layers/BasicConvLSTMCell.py
+++ b/video_prediction_savp/video_prediction/layers/BasicConvLSTMCell.py
@@ -88,10 +88,14 @@ class BasicConvLSTMCell(ConvRNNCell):
             else:
                 c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state)
             concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True)
-
+            print("concat1:",concat)
             # i = input_gate, j = new_input, f = forget_gate, o = output_gate
             i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat)
-
+            print("input gate i:",i)
+            print("new_input j:",j)
+            print("forget gate:",f)
+            print("output gate:",o)
+           
             new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) *
                      self._activation(j))
             new_h = self._activation(new_c) * tf.nn.sigmoid(o)
@@ -100,6 +104,8 @@ class BasicConvLSTMCell(ConvRNNCell):
                 new_state = LSTMStateTuple(new_c, new_h)
             else:
                 new_state = tf.concat(axis = 3, values = [new_c, new_h])
+            print("new h", new_h)
+            print("new state",new_state)
             return new_h, new_state
 
 
@@ -135,9 +141,14 @@ def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=No
         matrix = tf.get_variable(
             "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype = dtype)
         if len(args) == 1:
+            print("args[0]:",args[0])
             res = tf.nn.conv2d(args[0], matrix, strides = [1, 1, 1, 1], padding = 'SAME')
+            print("res1:",res)
         else:
+            print("matrix:",matrix)
+            print("tf.concat(axis = 3, values = args):",tf.concat(axis = 3, values = args))
             res = tf.nn.conv2d(tf.concat(axis = 3, values = args), matrix, strides = [1, 1, 1, 1], padding = 'SAME')
+            print("res2:",res)
         if not bias:
             return res
         bias_term = tf.get_variable(
@@ -146,3 +157,4 @@ def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=No
             initializer = tf.constant_initializer(
                 bias_start, dtype = dtype))
     return res + bias_term
+
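
The prints added above expose the four gate tensors; the cell update itself (visible in the surrounding context lines) is the standard LSTM update applied pixel-wise. A NumPy sketch of that update with toy shapes, assuming the default tanh activation:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    h, w, n = 4, 4, 8                        # toy spatial size and num_features
    c = np.zeros((h, w, n))                  # previous cell state
    i, j, f, o = (np.random.randn(h, w, n) for _ in range(4))   # stand-ins for the four gate maps
    forget_bias = 1.0

    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    print(new_h.shape)                       # (4, 4, 8): same spatial layout as the inputs
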
diff --git a/video_prediction_savp/video_prediction/layers/layer_def.py b/video_prediction_savp/video_prediction/layers/layer_def.py
index a59643c7a6d69141134ec01c9b147c4798bfed8e..273b5eaee3cab703841b214ccc09ef190b6dd3ae 100644
--- a/video_prediction_savp/video_prediction/layers/layer_def.py
+++ b/video_prediction_savp/video_prediction/layers/layer_def.py
@@ -3,8 +3,8 @@
 
 import tensorflow as tf
 import numpy as np
-
 weight_decay = 0.0005
+
 def _activation_summary(x):
     """Helper to create summaries for activations.
     Creates a summary that provides a histogram of activations.
@@ -55,8 +55,7 @@ def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.l
 
 
 def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer() , activate="relu"):
-    print("conv_layer activation function",activate)
-    
+    print("conv_layer activation function",activate) 
     with tf.variable_scope('{0}_conv'.format(idx)) as scope:
  
         input_channels = inputs.get_shape()[-1]
@@ -74,6 +73,8 @@ def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.co
             conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx))   
         elif activate == "leaky_relu":
             conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx))
+        elif activate == "sigmoid":
+            conv_rect = tf.nn.sigmoid(conv_biased, name = '{0}_conv'.format(idx))
         else:
             raise ("activation function is not correct")
         return conv_rect
@@ -157,4 +158,4 @@ def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None):
 
 def bn_layers_wrapper(inputs, is_training):
     pass
-   
\ No newline at end of file
+   
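
The new "sigmoid" option is what lets the ConvLSTM decoder below squash its 1 x 1 convolution output into (0, 1). An illustrative call with a toy tensor; the import path for layer_def is an assumption:

    import tensorflow as tf  # TF 1.x
    from video_prediction.layers import layer_def as ld  # assumed import path

    x = tf.zeros([4, 16, 16, 8])                         # [batch, height, width, channels], toy shape
    y = ld.conv_layer(x, 1, 1, 3, "decode_demo", activate="sigmoid")
    print(y)                                             # Tensor of shape (4, 16, 16, 3), values in (0, 1)
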
diff --git a/video_prediction_savp/video_prediction/models/__init__.py b/video_prediction_savp/video_prediction/models/__init__.py
index 6d7323f3750949b0ddb411d4a98934928537bc53..8010c4eeb2123fd94995c6474e7e1c8af6b02113 100644
--- a/video_prediction_savp/video_prediction/models/__init__.py
+++ b/video_prediction_savp/video_prediction/models/__init__.py
@@ -21,7 +21,6 @@ def get_model_class(model):
         'vae': 'VanillaVAEVideoPredictionModel',
         'convLSTM': 'VanillaConvLstmVideoPredictionModel',
         'mcnet': 'McNetVideoPredictionModel',
-        
         }
     model_class = model_mappings.get(model, model)
     model_class = globals().get(model_class)
diff --git a/video_prediction_savp/video_prediction/models/base_model.py b/video_prediction_savp/video_prediction/models/base_model.py
index 846621d8ca1e235c39618951be86fe184a2d974d..0d3bf6e4b554c70671d4678b530688c44f999b77 100644
--- a/video_prediction_savp/video_prediction/models/base_model.py
+++ b/video_prediction_savp/video_prediction/models/base_model.py
@@ -3,12 +3,10 @@ import itertools
 import os
 import re
 from collections import OrderedDict
-
 import numpy as np
 import tensorflow as tf
 from tensorflow.contrib.training import HParams
 from tensorflow.python.util import nest
-
 import video_prediction as vp
 from video_prediction.utils import tf_utils
 from video_prediction.utils.tf_utils import compute_averaged_gradients, reduce_tensors, local_device_setter, \
@@ -244,7 +242,9 @@ class BaseVideoPredictionModel(object):
                 savers.append(saver)
             restore_op = [saver.saver_def.restore_op_name for saver in savers]
             sess.run(restore_op)
-
+            return True
+        else:
+            return False
 
 class VideoPredictionModel(BaseVideoPredictionModel):
     def __init__(self,
@@ -487,6 +487,7 @@ class VideoPredictionModel(BaseVideoPredictionModel):
         # skip_vars = {" discriminator_encoder/video_sn_fc4/dense/bias"}
 
         if self.num_gpus <= 1:  # cpu or 1 gpu
+            print("self.inputs:>20200822",self.inputs)
             outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(self.inputs)
             self.outputs, self.eval_outputs = outputs_tuple
             self.d_losses, self.g_losses, g_losses_post = losses_tuple
diff --git a/video_prediction_savp/video_prediction/models/savp_model.py b/video_prediction_savp/video_prediction/models/savp_model.py
index c510d050c89908d0e06fe0f1a66e355e61c90530..039533864f34d1608c5a10d4a664d40ce73594a7 100644
--- a/video_prediction_savp/video_prediction/models/savp_model.py
+++ b/video_prediction_savp/video_prediction/models/savp_model.py
@@ -689,6 +689,10 @@ class SAVPCell(tf.nn.rnn_cell.RNNCell):
 def generator_given_z_fn(inputs, mode, hparams):
     # all the inputs needs to have the same length for unrolling the rnn
     print("inputs.items",inputs.items())
+    #20200822 Bing: keep only the "images" entry of the inputs for the generator
+    inputs = {"images":inputs["images"]}
+    print("inputs 20200822:",inputs)
+    #20200822
     inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
               for name, input in inputs.items()}
     cell = SAVPCell(inputs, mode, hparams)
diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index 5a8a2e1f3fffe5c66d5b93e53137300bf792317e..796486a453f9dc6807928deeb2b8962e2908a4f2 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -28,32 +28,29 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         self.sequence_length = self.hparams.sequence_length
         self.predict_frames = self.sequence_length - self.context_frames
         self.max_epochs = self.hparams.max_epochs
+        self.loss_fun = self.hparams.loss_fun
+
+
     def get_default_hparams_dict(self):
         """
         The keys of this dict define valid hyperparameters for instances of
         this class. A class inheriting from this one should override this
         method if it has a different set of hyperparameters.
-
         Returns:
             A dict with the following hyperparameters.
-
             batch_size: batch size for training.
             lr: learning rate. if decay steps is non-zero, this is the
                 learning rate for steps <= decay_step.
-            max_steps: number of training steps.
-            context_frames: the number of ground-truth frames to pass :qin at
-                start. Must be specified during instantiation.
-            sequence_length: the number of frames in the video sequence,
-                including the context frames, so this model predicts
-                `sequence_length - context_frames` future frames. Must be
-                specified during instantiation.
-        """
+            max_epochs: number of training epochs; each epoch corresponds to sample_size/batch_size steps
+            loss_fun: loss function, either "rmse" or "cross_entropy"; it has to be set by the user
+        """
         default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict()
         print ("default hparams",default_hparams)
         hparams = dict(
             batch_size=16,
             lr=0.001,
             max_epochs=3000,
+            loss_fun = None
         )
 
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
@@ -70,9 +67,22 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         #self.context_frames_loss = tf.reduce_mean(
         #    tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
         # This is the loss function (RMSE):
-        self.total_loss = tf.reduce_mean(
-            tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_context_frames[:, (self.context_frames-1):-1, :, :, 0]))
-
+        #This is the loss function for one channel only (temperature RMSE)
+        if self.loss_fun == "rmse":
+            self.total_loss = tf.reduce_mean(
+                tf.square(self.x[:, self.context_frames:,:,:,0] - self.x_hat_predict_frames[:,:,:,:,0]))
+        elif self.loss_fun == "cross_entropy":
+            x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0],[-1])
+            x_hat_predict_frames_flatten = tf.reshape(self.x_hat_predict_frames[:,:,:,:,0],[-1])
+            bce = tf.keras.losses.BinaryCrossentropy()
+            self.total_loss = bce(x_flatten,x_hat_predict_frames_flatten)  
+        else:
+            raise ValueError("Loss function is not selected properly, you should choose either 'rmse' or 'cross_entropy'")
+
+        #This is the loss over all channels (temperature, gph500, pressure)
+        #self.total_loss = tf.reduce_mean(
+        #    tf.square(self.x[:, self.context_frames:,:,:,:] - self.x_hat_predict_frames[:,:,:,:,:]))            
+ 
         self.train_op = tf.train.AdamOptimizer(
             learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step)
         self.outputs = {}
@@ -84,60 +94,40 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         self.saveable_variables = [self.global_step] + global_variables
         return None
 
-
     @staticmethod
     def convLSTM_cell(inputs, hidden):
-
-        conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1", activate = "leaky_relu")
-
-        conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2", activate = "leaky_relu")
-
-        conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3", activate = "leaky_relu")
-
-        y_0 = conv3
+        y_0 = inputs #we only use patch size 1, while the original paper uses patch size 4 for the moving mnist case and 2 for the Radar Echo dataset
+        channels = inputs.get_shape()[-1]
         # conv lstm cell
         cell_shape = y_0.get_shape().as_list()
+        channels = cell_shape[-1]
         with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)):
-            cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size = [3, 3], num_features = 8)
+            cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size = [5, 5], num_features = 256)
             if hidden is None:
                 hidden = cell.zero_state(y_0, tf.float32)
-
             output, hidden = cell(y_0, hidden)
-
-
         output_shape = output.get_shape().as_list()
-
-
         z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])
-
-        conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, "decode_5", activate = "leaky_relu")
-
-
-        conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, "decode_6", activate = "leaky_relu")
-
-
-        x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, "decode_7", activate = "sigmoid")  # set activation to linear
-
+        #we feed the learned representation into a 1 × 1 convolutional layer to generate the final prediction
+        x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid")
         return x_hat, hidden
 
     def convLSTM_network(self):
         network_template = tf.make_template('network',
                                             VanillaConvLstmVideoPredictionModel.convLSTM_cell)  # make the template to share the variables
         # create network
-        x_hat_context = []
         x_hat = []
-        hidden = None
-        #This is for training 
-        for i in range(self.sequence_length):
-            if i < self.context_frames:
-                x_1, hidden = network_template(self.x[:, i, :, :, :], hidden)
-            else:
-                x_1, hidden = network_template(x_1, hidden)
-            x_hat_context.append(x_1)
+
+        # for i in range(self.sequence_length-1):
+        #     if i < self.context_frames:
+        #         x_1, hidden = network_template(self.x[:, i, :, :, :], hidden)
+        #     else:
+        #         x_1, hidden = network_template(x_1, hidden)
+        #     x_hat_context.append(x_1)
         
-        #This is for generating video
+        #This is for training (optimization of convLSTM layer)
         hidden_g = None
-        for i in range(self.sequence_length):
+        for i in range(self.sequence_length-1):
             if i < self.context_frames:
                 x_1_g, hidden_g = network_template(self.x[:, i, :, :, :], hidden_g)
             else:
@@ -145,8 +135,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
             x_hat.append(x_1_g)
         
         # pack them all together
-        x_hat_context = tf.stack(x_hat_context)
         x_hat = tf.stack(x_hat)
-        self.x_hat_context_frames = tf.transpose(x_hat_context, [1, 0, 2, 3, 4])  # change first dim with sec dim
         self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4])  # change first dim with sec dim
-        self.x_hat_predict_frames = self.x_hat[:,self.context_frames:,:,:,:]
+        self.x_hat_predict_frames = self.x_hat[:,self.context_frames-1:,:,:,:]
+
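
Where the loss is selected above, the "rmse" branch is in fact a mean squared error over the predicted frames of the first channel, while the "cross_entropy" branch applies binary cross-entropy to the flattened frames (sensible for the [0, 1]-normalized Moving MNIST pixels). A NumPy sketch of the two options with toy tensors:

    import numpy as np

    def mse_loss(x_true, x_pred):
        # corresponds to the tf.reduce_mean(tf.square(...)) branch ("rmse")
        return np.mean((x_true - x_pred) ** 2)

    def bce_loss(x_true, x_pred, eps=1e-7):
        # corresponds to the tf.keras.losses.BinaryCrossentropy branch on flattened frames
        x_pred = np.clip(x_pred, eps, 1.0 - eps)
        return -np.mean(x_true * np.log(x_pred) + (1.0 - x_true) * np.log(1.0 - x_pred))

    x_true = np.random.rand(2, 10, 8, 8)   # stand-in for x[:, context_frames:, :, :, 0]
    x_pred = np.random.rand(2, 10, 8, 8)   # stand-in for x_hat_predict_frames[:, :, :, :, 0]
    print(mse_loss(x_true, x_pred), bce_loss(x_true, x_pred))
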