diff --git a/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh
new file mode 100755
index 0000000000000000000000000000000000000000..a81e9a1499ce2619c6d934d32396c7128bd6b565
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh
@@ -0,0 +1,37 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+##SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --output=DataPreprocess_to_tf-out.%j
+#SBATCH --error=DataPreprocess_to_tf-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --partition=devel
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+
+
+# Name of virtual environment 
+VIRT_ENV_NAME="vp"
+
+# Load modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z "${VIRTUAL_ENV}" ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+# declare directory variables for the source data and the TFRecord destination
+
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
+destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
+
+# run preprocessing (step 2, where TFRecords are generated)
+srun python ../video_prediction/datasets/moving_mnist.py ${source_dir} ${destination_dir}/tfrecords
diff --git a/video_prediction_savp/HPC_scripts/generate_movingmnist.sh b/video_prediction_savp/HPC_scripts/generate_movingmnist.sh
new file mode 100755
index 0000000000000000000000000000000000000000..f6a36d1a366a1e1492546e8a8ac14760cc434098
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/generate_movingmnist.sh
@@ -0,0 +1,44 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+##SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --output=generate_movingmnist-out.%j
+#SBATCH --error=generate_movingmnist-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --gres=gpu:1
+#SBATCH --partition=develgpus
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+##jutil env activate -p cjjsc42
+
+# Name of virtual environment 
+VIRT_ENV_NAME="vp"
+
+# Load modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z "${VIRTUAL_ENV}" ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+# declare directory variables for the input data, model checkpoints and results
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
+checkpoint_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist
+results_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist
+# name of model
+model=convLSTM
+
+# run postprocessing/generation of model results including evaluation metrics
+srun python -u ../scripts/generate_movingmnist.py \
+--input_dir ${source_dir}/ --dataset_hparams sequence_length=20 --checkpoint ${checkpoint_dir}/${model} \
+--mode test --model ${model} --results_dir ${results_dir}/${model} --batch_size 2 --dataset moving_mnist > generate_movingmnist-out.out
+
+#srun  python scripts/train.py --input_dir data/era5 --dataset era5  --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/video_prediction_savp/HPC_scripts/train_movingmnist.sh b/video_prediction_savp/HPC_scripts/train_movingmnist.sh
index 1ceaebbfcfe62bca30c9dd728b61f6ed2bc22d4d..f62d333dbf01db0affbf72a3e1ef1ecd96b94ec7 100755
--- a/video_prediction_savp/HPC_scripts/train_movingmnist.sh
+++ b/video_prediction_savp/HPC_scripts/train_movingmnist.sh
@@ -38,7 +38,7 @@ destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/model
 
 # for choosing the model, convLSTM,savp, mcnet,vae,convLSTM_Loliver
 model=convLSTM
-model_hparams=../hparams/era5/${model}/model_hpain_movingmnist.shrams.json
+model_hparams=../hparams/era5/${model}/model_hparams.json
 
 # run training
 srun python ../scripts/train_moving_mnist.py --input_dir  ${source_dir}/tfrecords/ --dataset moving_mnist  --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}/  --checkpoint ${destination_dir}/${model}/ 
diff --git a/video_prediction_savp/scripts/generate_movingmnist.py b/video_prediction_savp/scripts/generate_movingmnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ec2af488c81dddeef6bff2deeb867c4e7b4ffed
--- /dev/null
+++ b/video_prediction_savp/scripts/generate_movingmnist.py
@@ -0,0 +1,822 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import math
+import random
+import cv2
+import numpy as np
+import tensorflow as tf
+import pickle
+from random import seed
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import matplotlib.animation as animation
+import pandas as pd
+import re
+from video_prediction import datasets, models
+from matplotlib.colors import LinearSegmentedColormap
+#from matplotlib.ticker import MaxNLocator
+#from video_prediction.utils.ffmpeg_gif import save_gif
+from skimage.metrics import structural_similarity as ssim
+import datetime
+# Scarlet 2020/05/28: access to statistical values in json file 
+from os import path
+import sys
+sys.path.append(path.abspath('../video_prediction/datasets/'))
+from era5_dataset_v2 import Norm_data
+from os.path import dirname
+from netCDF4 import Dataset,date2num
+from metadata import MetaData
+
+def set_seed(seed):
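+    """Seed the TensorFlow, numpy and python RNGs for reproducibility (no-op if seed is None)."""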
+    if seed is not None:
+        tf.set_random_seed(seed)
+        np.random.seed(seed)
+        random.seed(seed) 
+
+def get_coordinates(metadata_fname):
+    """
+    Retrieves the latitudes and longitudes read from the metadata json file.
+    """
+    md = MetaData(json_file=metadata_fname)
+    md.get_metadata_from_file(metadata_fname)
+    
+    try:
+        print("lat:",md.lat)
+        print("lon:",md.lon)
+        return md.lat, md.lon
+    except AttributeError:
+        raise ValueError("Error when handling: '"+metadata_fname+"'")
+    
+
+def load_checkpoints_and_create_output_dirs(checkpoint,dataset,model):
+    # Initialize defaults so all return values are bound even if the optional
+    # hparams files are missing or no checkpoint is given.
+    options = {}
+    checkpoint_dir = None
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+    if checkpoint:
+        checkpoint_dir = os.path.normpath(checkpoint)
+        if not os.path.isdir(checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % checkpoint)
+            options = json.loads(f.read())
+            dataset = dataset or options['dataset']
+            model = model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+    else:
+        if not dataset:
+            raise ValueError('dataset is required when checkpoint is not specified')
+        if not model:
+            raise ValueError('model is required when checkpoint is not specified')
+
+    return options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict
+
+
+    
+def setup_dataset(dataset,input_dir,mode,seed,num_epochs,dataset_hparams,dataset_hparams_dict):
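+    """Instantiate the dataset class selected by name and return the dataset object for the given mode."""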
+    VideoDataset = datasets.get_dataset_class(dataset)
+    dataset = VideoDataset(
+        input_dir,
+        mode = mode,
+        num_epochs = num_epochs,
+        seed = seed,
+        hparams_dict = dataset_hparams_dict,
+        hparams = dataset_hparams)
+    return dataset
+
+
+def setup_dirs(input_dir,results_png_dir):
+    # use the passed-in input_dir rather than the global args
+    temporal_dir = os.path.split(input_dir)[0] + "/hickle/splits/"
+    print("temporal_dir:",temporal_dir)
+
+
+def update_hparams_dict(model_hparams_dict,dataset):
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    return hparams_dict
+
+
+def psnr(img1, img2):
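+    """Peak signal-to-noise ratio, 20*log10(PIXEL_MAX/sqrt(MSE)); assumes images scaled to [0,1] and returns 100 for identical images."""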
+    mse = np.mean((img1 - img2) ** 2)
+    if mse == 0: return 100
+    PIXEL_MAX = 1
+    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
+
+
+def setup_num_samples_per_epoch(num_samples, dataset):
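+    """Return the number of samples to process per epoch, capped at the dataset size."""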
+    if num_samples:
+        if num_samples > dataset.num_examples_per_epoch():
+            raise ValueError('num_samples cannot be larger than the dataset')
+        num_examples_per_epoch = num_samples
+    else:
+        num_examples_per_epoch = dataset.num_examples_per_epoch()
+    #if num_examples_per_epoch % args.batch_size != 0:
+    #    raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+    return num_examples_per_epoch
+
+
+def initia_save_data():
+    sample_ind = 0
+    gen_images_all = []
+    #Bing:20200410
+    persistent_images_all = []
+    input_images_all = []
+    return sample_ind, gen_images_all,persistent_images_all, input_images_all
+
+
+def write_params_to_results_dir(args,output_dir,dataset,model):
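+    """Write the run options and the dataset/model hyperparameters as json files into the results directory."""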
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys = True, indent = 4))
+    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4))
+    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4))
+    return None
+
+def get_one_seq_and_time(input_images,i):
+    assert (len(np.array(input_images).shape)==5)
+    input_images_ = input_images[i,:,:,:,:]
+    return input_images_
+
+
+def denorm_images_all_channels(input_images_):
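+    """Rescale images from [0,1] back to [0,255] pixel values (assumes the inputs were normalized by dividing by 255)."""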
+    input_images_ = np.array(input_images_)
+    input_images_denorm = input_images_ * 255.0
+    return input_images_denorm
+
+def plot_seq_imgs(imgs,output_png_dir,idx,label="Ground Truth"):
+    """
+    Plot the seq images 
+    """
+
+    if len(np.array(imgs).shape)!=3:raise("img dims should be three: (seq_len,lat,lon)")
+    img_len = imgs.shape[0]
+    fig = plt.figure(figsize=(18,6))
+    gs = gridspec.GridSpec(1, 10)
+    gs.update(wspace = 0., hspace = 0.)
+    for i in range(img_len):      
+        ax1 = plt.subplot(gs[i])
+        plt.imshow(imgs[i] ,cmap = 'jet')
+        plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+    plt.savefig(os.path.join(output_png_dir, label + "_" +   str(idx) +  ".jpg"))
+    print("images_saved")
+    plt.clf()
+ 
+
+    
+def get_persistence(ts):
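+    # TODO: return the persistence forecast for the timestamps in ts,
+    # i.e. reuse the most recent observed frames as the forecast baseline.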
+    pass
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type = str, required = True,
+                        help = "either a directory containing subdirectories "
+                               "train, val, test, etc, or a directory containing "
+                               "the tfrecords")
+    parser.add_argument("--results_dir", type = str, default = 'results',
+                        help = "ignored if output_gif_dir is specified")
+    parser.add_argument("--checkpoint",
+                        help = "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+    parser.add_argument("--mode", type = str, choices = ['train','val', 'test'], default = 'val',
+                        help = 'mode for dataset, val or test.')
+    parser.add_argument("--dataset", type = str, help = "dataset class name")
+    parser.add_argument("--dataset_hparams", type = str,
+                        help = "a string of comma separated list of dataset hyperparameters")
+    parser.add_argument("--model", type = str, help = "model class name")
+    parser.add_argument("--model_hparams", type = str,
+                        help = "a string of comma separated list of model hyperparameters")
+    parser.add_argument("--batch_size", type = int, default = 8, help = "number of samples in batch")
+    parser.add_argument("--num_samples", type = int, help = "number of samples in total (all of them by default)")
+    parser.add_argument("--num_epochs", type = int, default = 1)
+    parser.add_argument("--num_stochastic_samples", type = int, default = 1)
+    parser.add_argument("--gif_length", type = int, help = "default is sequence_length")
+    parser.add_argument("--fps", type = int, default = 4)
+    parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use")
+    parser.add_argument("--seed", type = int, default = 7)
+    args = parser.parse_args()
+    set_seed(args.seed)
+
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+
+    options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict = load_checkpoints_and_create_output_dirs(args.checkpoint,args.dataset,args.model)
+    print("Step 1 finished")
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    #setup dataset and model object
+    input_dir_tf = os.path.join(args.input_dir, "tfrecords") # where tensorflow records are stored
+    dataset = setup_dataset(dataset,input_dir_tf,args.mode,args.seed,args.num_epochs,args.dataset_hparams,dataset_hparams_dict)
+    
+    print("Step 2 finished")
+    VideoPredictionModel = models.get_model_class(model)
+    
+    hparams_dict = update_hparams_dict(model_hparams_dict, dataset)
+    
+    model = VideoPredictionModel(
+        mode = args.mode,
+        hparams_dict = hparams_dict,
+        hparams = args.model_hparams)
+
+    sequence_length = model.hparams.sequence_length
+    context_frames = model.hparams.context_frames
+    future_length = sequence_length - context_frames  #context_frames is the number of input frames
+
+    num_examples_per_epoch = setup_num_samples_per_epoch(args.num_samples,dataset)
+    
+    inputs = dataset.make_batch(args.batch_size)
+    print("inputs",inputs)
+    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
+    print("input_phs",input_phs)
+    
+    
+    # Build graph
+    with tf.variable_scope(''):
+        model.build_graph(input_phs)
+
+    #Write the update hparameters into results_dir    
+    write_params_to_results_dir(args=args,output_dir=args.results_dir,dataset=dataset,model=model)
+        
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_mem_frac)
+    config = tf.ConfigProto(gpu_options = gpu_options, allow_soft_placement = True)
+    sess = tf.Session(config = config)
+    sess.graph.as_default()
+    sess.run(tf.global_variables_initializer())
+    sess.run(tf.local_variables_initializer())
+    model.restore(sess, args.checkpoint)
+    
+    #model.restore(sess, args.checkpoint)#Bing: Todo: 20200728 Let's only focus on true and persistent data
+    sample_ind, gen_images_all, persistent_images_all, input_images_all = initia_save_data()
+    
+    is_first=True
+    #loop over samples
+    while sample_ind < 5:
+        gen_images_stochastic = []
+        if args.num_samples and sample_ind >= args.num_samples:
+            break
+        try:
+            input_results = sess.run(inputs)
+            input_images = input_results["images"]
+            #get the initial times
+            t_starts = input_results["T_start"]
+        except tf.errors.OutOfRangeError:
+            break
+            
+        #Get prediction values 
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel]
+        print("gen_images 20200822:",np.array(gen_images).shape)       
+        #Loop in batch size
+        for i in range(args.batch_size):
+            
+            #get one seq and the corresponding start time point
+            input_images_ = get_one_seq_and_time(input_images,i)
+            
+            #Denormalize the input data
+            input_images_denorm = denorm_images_all_channels(input_images_)  
+            print("input_images_denorm",input_images_denorm[0][0])
+                                                             
+            #Denormalize the generated data
+            gen_images_ = gen_images[i]
+            gen_images_denorm = denorm_images_all_channels(gen_images_)
+            print("gene_images_denorm:",gen_images_denorm[0][0])
+            
+            #Plot the ground-truth input images
+            plot_seq_imgs(imgs=input_images_denorm[:context_frames-1,:,:,0],idx = sample_ind + i, label="Ground Truth",output_png_dir=args.results_dir)  
+                                                             
+            #Generate forecast images
+            plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],idx = sample_ind + i,label="Forecast by Model " + args.model,output_png_dir=args.results_dir) 
+            
+            #TODO (Scarlet): plot persistence image
+            #implement get_persistence() function
+
+            #instead of generating images for all inputs, only the first few samples are plotted for visualization (see the loop condition above)
+
+        sample_ind += args.batch_size
+
+
+        #for input_image in input_images_:
+
+#             for stochastic_sample_ind in range(args.num_stochastic_samples):
+#                 input_images_all.extend(input_images)
+#                 with open(os.path.join(args.output_png_dir, "input_images_all.pkl"), "wb") as input_files:
+#                     pickle.dump(list(input_images_all), input_files)
+
+
+#                 gen_images_stochastic.append(gen_images)
+#                 #print("Stochastic_sample,", stochastic_sample_ind)
+#                 for i in range(args.batch_size):
+#                     #bing:20200417
+#                     t_stampe = test_temporal_pkl[sample_ind+i]
+#                     print("timestamp:",type(t_stampe))
+#                     persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1)
+#                     print ("persistent ts",persistent_ts)
+#                     persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts))
+#                     persistent_X = X_test[persistent_idx:persistent_idx+context_frames + future_length]
+#                     print("persistent index in test set:", persistent_idx)
+#                     print("persistent_X.shape",persistent_X.shape)
+#                     persistent_images_all.append(persistent_X)
+
+#                     cmap_name = 'my_list'
+#                     if sample_ind < 100:
+#                         #name = '_Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
+#                         #    sample_ind) + " + Sample_" + str(i)
+#                         name = '_Stochastic_id_' + str(stochastic_sample_ind) + "_Time_"+ t_stampe[0].strftime("%Y%m%d-%H%M%S")
+#                         print ("name",name)
+#                         gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
+#                         #gen_images_ =  gen_images[i, :]
+#                         input_images_ = input_images[i, :]
+#                         #Bing:20200417
+#                         #persistent_images = ?
+#                         #+++Scarlet:20200528   
+#                         #print('Scarlet1')
+#                         input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm)
+#                         persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm)
+#                         #---Scarlet:20200528    
+#                         gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
+#                                         range(sequence_length)]  # return the list with 10 (sequence) mse
+#                         persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in
+#                                         range(sequence_length)]  # return the list with 10 (sequence) mse
+
+#                         fig = plt.figure(figsize=(18,6))
+#                         gs = gridspec.GridSpec(1, 10)
+#                         gs.update(wspace = 0., hspace = 0.)
+#                         ts = list(range(10,20)) #[10,11,12,..]
+#                         xlables = [round(i,2) for i  in list(np.linspace(np.min(lon),np.max(lon),5))]
+#                         ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
+
+#                         for t in ts:
+
+#                             #if t==0 : ax1=plt.subplot(gs[t])
+#                             ax1 = plt.subplot(gs[ts.index(t)])
+#                             #+++Scarlet:20200528
+#                             #print('Scarlet2')
+#                             input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm)
+#                             #---Scarlet:20200528
+#                             plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
+#                             ax1.title.set_text("t = " + str(t+1-10))
+#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+#                             if t == 0:
+#                                 plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
+#                                 plt.ylabel("Ground Truth", fontsize=10)
+#                         plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
+#                         plt.clf()
+
+#                         fig = plt.figure(figsize=(12,6))
+#                         gs = gridspec.GridSpec(1, 10)
+#                         gs.update(wspace = 0., hspace = 0.)
+
+#                         for t in ts:
+#                             #if t==0 : ax1=plt.subplot(gs[t])
+#                             ax1 = plt.subplot(gs[ts.index(t)])
+#                             #+++Scarlet:20200528
+#                             #print('Scarlet3')
+#                             gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm)
+#                             #---Scarlet:20200528
+#                             plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
+#                             ax1.title.set_text("t = " + str(t+1-10))
+#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+
+#                         plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
+#                         plt.clf()
+
+
+#                         fig = plt.figure(figsize=(12,6))
+#                         gs = gridspec.GridSpec(1, 10)
+#                         gs.update(wspace = 0., hspace = 0.)
+#                         for t in ts:
+#                             #if t==0 : ax1=plt.subplot(gs[t])
+#                             ax1 = plt.subplot(gs[ts.index(t)])
+#                             #persistent_image = persistent_X[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+#                             plt.imshow(persistent_X[t, :, :, 0], cmap = 'jet', vmin=270, vmax=300)
+#                             ax1.title.set_text("t = " + str(t+1-10))
+#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+
+#                         plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg"))
+#                         plt.clf()
+
+                        
+#                 with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files:
+#                     pickle.dump(list(persistent_images_all), input_files)
+#                     print ("Save persistent all")
+#                 if is_first:
+#                     gen_images_all = gen_images_stochastic
+#                     is_first = False
+#                 else:
+#                     gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1)
+
+#                 if args.num_stochastic_samples == 1:
+#                     with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files:
+#                         pickle.dump(list(gen_images_all[0]), gen_files)
+#                         print ("Save generate all")
+#                 else:
+#                     with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files:
+#                         pickle.dump(list(gen_images_stochastic), gen_files)
+#                     with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files:
+#                         pickle.dump(list(gen_images_all), gen_files)
+
+#         sample_ind += args.batch_size
+
+
+#     with open(os.path.join(args.output_png_dir, "input_images_all.pkl"),"rb") as input_files:
+#         input_images_all = pickle.load(input_files)
+
+#     with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"),"rb") as gen_files:
+#         gen_images_all = pickle.load(gen_files)
+
+#     with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files:
+#         persistent_images_all = pickle.load(gen_files)
+
+#     #+++Scarlet:20200528
+#     #print('Scarlet4')
+#     input_images_all = np.array(input_images_all)
+#     input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm)
+#     #---Scarlet:20200528
+#     persistent_images_all = np.array(persistent_images_all)
+#     if len(np.array(gen_images_all).shape) == 6:
+#         for i in range(len(gen_images_all)):
+#             #+++Scarlet:20200528
+#             #print('Scarlet5')
+#             gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:]
+#             gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm)
+#             #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+#             #---Scarlet:20200528
+#             mse_all = []
+#             psnr_all = []
+#             ssim_all = []
+#             f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w')
+#             for i in range(future_length):
+#                 mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all_stochastic[:, i + 9, :, :,
+#                                                                             0]) ** 2)  # look at all timesteps except the first
+#                 psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0])
+#                 ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0],
+#                                   data_range = max(gen_images_all_stochastic[:, i + 9, :, :, 0].flatten()) - min(
+#                                       input_images_all[:, i + 10, :, :, 0].flatten()))
+#                 mse_all.extend([mse_model])
+#                 psnr_all.extend([psnr_model])
+#                 ssim_all.extend([ssim_model])
+#                 results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
+#                 f.write("##########Predicted Frame {}\n".format(str(i + 1)))
+#                 f.write("Model MSE: %f\n" % mse_model)
+#                 # f.write("Previous Frame MSE: %f\n" % mse_prev)
+#                 f.write("Model PSNR: %f\n" % psnr_model)
+#                 f.write("Model SSIM: %f\n" % ssim_model)
+
+
+#             pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb"))
+#             # f.write("Previous frame PSNR: %f\n" % psnr_prev)
+#             f.write("Shape of X_test: " + str(input_images_all.shape))
+#             f.write("")
+#             f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape))
+
+#     else:
+#         #+++Scarlet:20200528
+#         #print('Scarlet6')
+#         gen_images_all = np.array(gen_images_all)
+#         gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm)
+#         #---Scarlet:20200528
+        
+#         # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
+#         # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)
+#         # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 )
+#         mse_all = []
+#         psnr_all = []
+#         ssim_all = []
+#         persistent_mse_all = []
+#         persistent_psnr_all = []
+#         persistent_ssim_all = []
+#         f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w')
+#         for i in range(future_length):
+#             mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :,
+#                                                                         0]) ** 2)  # look at all timesteps except the first
+#             persistent_mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - persistent_images_all[:, i + 9, :, :,
+#                                                                         0]) ** 2)  # look at all timesteps except the first
+            
+#             psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0])
+#             ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0],
+#                               data_range = max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min(
+#                                   input_images_all[:, i + 10, :, :, 0].flatten()))
+#             persistent_psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0])
+#             persistent_ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0],
+#                               data_range = max(gen_images_all[:1268, i + 9, :, :, 0].flatten()) - min(input_images_all[:1268, i + 10, :, :, 0].flatten()))
+#             mse_all.extend([mse_model])
+#             psnr_all.extend([psnr_model])
+#             ssim_all.extend([ssim_model])
+#             persistent_mse_all.extend([persistent_mse_model])
+#             persistent_psnr_all.extend([persistent_psnr_model])
+#             persistent_ssim_all.extend([persistent_ssim_model])
+#             results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
+
+#             persistent_results = {"mse": persistent_mse_all, "psnr": persistent_psnr_all, "ssim": persistent_ssim_all}
+#             f.write("##########Predicted Frame {}\n".format(str(i + 1)))
+#             f.write("Model MSE: %f\n" % mse_model)
+#             # f.write("Previous Frame MSE: %f\n" % mse_prev)
+#             f.write("Model PSNR: %f\n" % psnr_model)
+#             f.write("Model SSIM: %f\n" % ssim_model)
+
+#         pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb"))
+#         pickle.dump(persistent_results, open(os.path.join(args.output_png_dir, "persistent_results.pkl"), "wb"))
+#         # f.write("Previous frame PSNR: %f\n" % psnr_prev)
+#         f.write("Shape of X_test: " + str(input_images_all.shape))
+#         f.write("")
+#         f.write("Shape of X_hat: " + str(gen_images_all.shape)      
+
+if __name__ == '__main__':
+    main()        
+
+    #psnr_model = psnr(input_images_all[:, :10, :, :, 0],  gen_images_all[:, :10, :, :, 0])
+    #psnr_model_last = psnr(input_images_all[:, 10, :, :, 0],  gen_images_all[:,10, :, :, 0])
+    #psnr_prev = psnr(input_images_all[:, :, :, :, 0],  input_images_all[:, 1:10, :, :, 0])
+
+    # ims = []
+    # fig = plt.figure()
+    # for frame in range(20):
+    #     input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel)
+    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
+    #     #pix_std = np.std(input_gen_diff, axis=0)
+    #     im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu')
+    #     if frame == 0:
+    #         fig.colorbar(im)
+    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame +1))
+    #     ims.append([im, ttl])
+    # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000)
+    # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4"))
+    # plt.close("all")
+
+    # ims = []
+    # fig = plt.figure()
+    # for frame in range(19):
+    #     pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # Get the first prediction frame (batch,height, width, channel)
+    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
+    #     #pix_std = np.std(input_gen_diff, axis=0)
+    #     im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu')
+    #     if frame == 0:
+    #         fig.colorbar(im)
+    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame+1))
+    #     ims.append([im, ttl])
+    # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4"))
+
+    # seed(1)
+    # s = random.sample(range(len(gen_images_all)), 100)
+    # print("******KDP******")
+    # #kernel density plot for checking the model collapse
+    # fig = plt.figure()
+    # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images")
+    # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True")
+    # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability')
+    # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400)
+    # plt.clf()
+
+    #line plot for evaluating the prediction and groud-truth
+    # for i in [0,3,6,9,12,15,18]:
+    #     fig = plt.figure()
+    #     plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3)
+    #     #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3)
+    #     plt.xlabel("Prediction")
+    #     plt.ylabel("Real values")
+    #     plt.title("Frame_{}".format(i+1))
+    #     plt.plot([250,300], [250,300],color="black")
+    #     plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i))))
+    #     plt.clf()
+    #
+    # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence)
+    # x = [str(i+1) for i in list(range(19))]
+    # fig,axis = plt.subplots()
+    # mean_f = np.mean(mse_model_by_frames, axis = 0)
+    # median = np.median(mse_model_by_frames, axis=0)
+    # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0)
+    # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0)
+    # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0)
+    # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0)
+    # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range")
+    # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range")
+    # plt.plot(x, median, color="grey", linewidth=0.6, label="Median")
+    # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean")
+    # plt.title(f'MSE percentile')
+    # plt.xlabel("Frames")
+    # plt.legend(loc=2, fontsize=8)
+    # plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png"))
+
+
+##                
+##
+##                    # fig = plt.figure()
+##                    # gs = gridspec.GridSpec(4,6)
+##                    # gs.update(wspace = 0.7,hspace=0.8)
+##                    # ax1 = plt.subplot(gs[0:2,0:3])
+##                    # ax2 = plt.subplot(gs[0:2,3:],sharey=ax1)
+##                    # ax3 = plt.subplot(gs[2:4,0:3])
+##                    # ax4 = plt.subplot(gs[2:4,3:])
+##                    # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
+##                    # ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
+##                    # plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels)
+##                    # ax1.title.set_text("(a) Ground Truth")
+##                    # ax2.title.set_text("(b) SAVP")
+##                    # ax3.title.set_text("(c) Diff.")
+##                    # ax4.title.set_text("(d) MSE")
+##                    #
+##                    # ax1.xaxis.set_tick_params(labelsize=7)
+##                    # ax1.yaxis.set_tick_params(labelsize = 7)
+##                    # ax2.xaxis.set_tick_params(labelsize=7)
+##                    # ax2.yaxis.set_tick_params(labelsize = 7)
+##                    # ax3.xaxis.set_tick_params(labelsize=7)
+##                    # ax3.yaxis.set_tick_params(labelsize = 7)
+##                    #
+##                    # init_images = np.zeros((input_images_.shape[1], input_images_.shape[2]))
+##                    # print("inti images shape", init_images.shape)
+##                    # xdata, ydata = [], []
+##                    # #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
+##                    # #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
+##                    # plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+##                    # plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+##                    # #x = np.linspace(0, 64, 64)
+##                    # #y = np.linspace(0, 64, 64)
+##                    # #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+##                    # #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+##                    # fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7)
+##                    # fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7)
+##                    #
+##                    # cm = LinearSegmentedColormap.from_list(
+##                    #     cmap_name, "bwr", N = 5)
+##                    #
+##                    # plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r',
+##                    # #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm)  # cmap = 'PuBu_r',
+##                    # plot4, = ax4.plot([], [], color = "r")
+##                    # ax4.set_xlim(0, future_length-1)
+##                    # ax4.set_ylim(0, 20)
+##                    # #ax4.set_ylim(0, 0.5)
+##                    # ax4.set_xlabel("Frames", fontsize=10)
+##                    # #ax4.set_ylabel("MSE", fontsize=10)
+##                    # ax4.xaxis.set_tick_params(labelsize=7)
+##                    # ax4.yaxis.set_tick_params(labelsize=7)
+##                    #
+##                    #
+##                    # plots = [plot1, plot2, plot3, plot4]
+##                    #
+##                    # #fig.colorbar(plots[1], ax = [ax1, ax2])
+##                    #
+##                    # fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7)
+##                    # #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7)
+##                    # #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7)
+##                    #
+##                    # def animation_sample(t):
+##                    #     input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
+##                    #     gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
+##                    #     diff_image = input_gen_diff[t,:,:]
+##                    #     # p = sns.lineplot(x=x,y=data,color="b")
+##                    #     # p.tick_params(labelsize=17)
+##                    #     # plt.setp(p.lines, linewidth=6)
+##                    #     plots[0].set_data(input_image)
+##                    #     plots[1].set_data(gen_image)
+##                    #     #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+##                    #     #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+##                    #     plots[2].set_data(diff_image)
+##                    #
+##                    #     if t >= future_length:
+##                    #         #data = gen_mse_avg_[:t + 1]
+##                    #         # x = list(range(len(gen_mse_avg_)))[:t+1]
+##                    #         xdata.append(t-future_length)
+##                    #         print("xdata", xdata)
+##                    #         ydata.append(gen_mse_avg_[t])
+##                    #         print("ydata", ydata)
+##                    #         plots[3].set_data(xdata, ydata)
+##                    #         fig.suptitle("Predicted Frame " + str(t-future_length))
+##                    #     else:
+##                    #         #plots[3].set_data(xdata, ydata)
+##                    #         fig.suptitle("Context Frame " + str(t))
+##                    #     return plots
+##                    #
+##                    # ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000,
+##                    #                               repeat_delay=2000)
+##                    # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
+##
+####                else:
+####                    pass
+##
+
+
+
+
+
+    #         # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     plt.xlim(0,len(gen_mse_avg_))
+    #         #     plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg))
+    #         #     plt.xlabel("Frames")
+    #         #     plt.ylabel("MSE_AVG")
+    #         #     #X = list(range(len(gen_mse_avg_)))
+    #         #     #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
+    #         #     def animate_metric(j):
+    #         #         data = gen_mse_avg_[:(j+1)]
+    #         #         x = list(range(len(gen_mse_avg_)))[:(j+1)]
+    #         #         p = sns.lineplot(x=x,y=data,color="b")
+    #         #         p.tick_params(labelsize=17)
+    #         #         plt.setp(p.lines, linewidth=6)
+    #         #     ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000)
+    #         #     ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif"))
+    #         #
+    #         #
+    #         # for i, input_images_ in enumerate(input_images):
+    #         #     #context_images_ = (input_results['images'][i])
+    #         #     #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     for t, input_image in enumerate(input_images_):
+    #         #         im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')
+    #         #         ttl = plt.text(1.5, 2,"Frame_" + str(t))
+    #         #         ims.append([im,ttl])
+    #         #     ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000)
+    #         #     ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif"))
+    #         #     #plt.show()
+    #         #
+    #         # for i,gen_images_ in enumerate(gen_images):
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     for t, gen_image in enumerate(gen_images_):
+    #         #         im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
+    #         #         ttl = plt.text(1.5, 2, "Frame_" + str(t))
+    #         #         ims.append([im, ttl])
+    #         #     ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    #         #     ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif"))
+    #
+    #
+    #             # for i, gen_images_ in enumerate(gen_images):
+    #             #     #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
+    #             #     #gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
+    #             #     #bing
+    #             #     context_images_ = (input_results['images'][i])
+    #             #     gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+    #             #     context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
+    #             #     plt.figure(figsize = (10,2))
+    #             #     gs = gridspec.GridSpec(2,10)
+    #             #     gs.update(wspace=0.,hspace=0.)
+    #             #     for t, gen_image in enumerate(gen_images_):
+    #             #         gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1)))
+    #             #         gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
+    #             #         plt.subplot(gs[t])
+    #             #         plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')  # the last index sets the channel. 0 = t2
+    #             #         # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
+    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
+    #             #                         right = False, labelbottom = False, labelleft = False)
+    #             #         if t == 0: plt.ylabel('Actual', fontsize = 10)
+    #             #
+    #             #         plt.subplot(gs[t + 10])
+    #             #         plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
+    #             #         # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
+    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
+    #             #                         right = False, labelbottom = False, labelleft = False)
+    #             #         if t == 0: plt.ylabel('Predicted', fontsize = 10)
+    #             #     plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png')
+    #             #     plt.clf()
+    #
+    #             # if args.gif_length:
+    #             #     context_and_gen_images = context_and_gen_images[:args.gif_length]
+    #             # save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
+    #             #          context_and_gen_images, fps=args.fps)
+    #             #
+    #             # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
+    #             # for t, gen_image in enumerate(gen_images_):
+    #             #     gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
+    #             #     if gen_image.shape[-1] == 1:
+    #             #       gen_image = np.tile(gen_image, (1, 1, 3))
+    #             #     else:
+    #             #       gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
+    #             #     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
index 0e250b47df28d115c8cdfc77fc708eab5e094ce6..1fcbd1cf97442f1ea440039a0bb6769473b957f3 100644
--- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
+++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
@@ -202,11 +202,6 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t
     ts_len = len(ts)
     ts_input = ts[:context_frames]
     ts_forecast = ts[context_frames:]
-    #print("context_frame:",context_frames)
-    #print("future_frame",future_length)
-    #print("length of ts input:",len(ts_input))
-
-    print("input_images_ shape in netcdf,",input_images_.shape)
     gen_images_ = np.array(gen_images_)
 
     output_file = os.path.join(output_dir,fl_name)
@@ -289,18 +284,19 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t
         #Temperature:
         t2 = nc_file.createVariable("/forecast/{}/T2".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True)
         t2.units = 'K'
-        t2[:,:,:] = gen_images_[context_frames:,:,:,0]
+        print ("gen_images_ 20200822:",np.array(gen_images_).shape)
+        t2[:,:,:] = gen_images_[context_frames-1:,:,:,0]
         print("NetCDF created")
 
         #mean sea level pressure
         msl = nc_file.createVariable("/forecast/{}/MSL".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True)
         msl.units = 'Pa'
-        msl[:,:,:] = gen_images_[context_frames:,:,:,1]
+        msl[:,:,:] = gen_images_[context_frames-1:,:,:,1]
 
         #Geopotential at 500 
         gph500 = nc_file.createVariable("/forecast/{}/GPH500".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True)
         gph500.units = 'm'
-        gph500[:,:,:] = gen_images_[context_frames:,:,:,2]        
+        gph500[:,:,:] = gen_images_[context_frames-1:,:,:,2]        
 
         print("{} created".format(output_file)) 
 
@@ -451,7 +447,8 @@ def main():
         #Get prediction values 
         feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
         gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel]
-        
+        assert gen_images.shape[1] == sequence_length-1 #The generated images seq_len should be sequence_length-1, since the last frame is not used for comparison with the ground truth
+        print("gen_images 20200822:",np.array(gen_images).shape)       
         #Loop in batch size
         for i in range(args.batch_size):
             
@@ -459,26 +456,26 @@ def main():
             input_images_,t_start = get_one_seq_and_time(input_images,t_starts,i)
             #generate time stamps for sequences
             ts = generate_seq_timestamps(t_start,len_seq=sequence_length)
-            
+             
             #Renormalized data for inputs
             stat_fl = os.path.join(args.input_dir,"pickle/statistics.json")
             input_images_denorm = denorm_images_all_channels(stat_fl,input_images_,["T2","MSL","gph500"])  
-            print("input_images_denorm",input_images_denorm[0][0])
+            print("input_images_denorm shape",np.array(input_images_denorm).shape)
                                                              
             #Renormalized data for inputs
             gen_images_ = gen_images[i]
             gen_images_denorm = denorm_images_all_channels(stat_fl,gen_images_,["T2","MSL","gph500"])
-            print("gene_images_denorm:",gen_images_denorm[0][0])
+            print("gene_images_denorm shape",np.array(gen_images_denorm).shape)
             
             #Save input to netCDF file
             init_date_str = ts[0].strftime("%Y%m%d%H")
             save_to_netcdf_per_sequence(args.results_dir,input_images_denorm,gen_images_denorm,lons,lats,ts,context_frames,future_length,args.model,fl_name="vfp_{}.nc".format(init_date_str))
                                                              
             #Generate images inputs
-            plot_seq_imgs(imgs=input_images_denorm[:context_frames-1,:,:,0],lats=lats,lons=lons,ts=ts[:context_frames-1],label="Ground Truth",output_png_dir=args.results_dir)  
+            plot_seq_imgs(imgs=input_images_denorm[context_frames+1:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames+1:],label="Ground Truth",output_png_dir=args.results_dir)  
                                                              
             #Generate forecast images
-            plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) 
+            plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames+1:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) 
             
             #TODO: Scaret plot persistence image
             #implment get_persistence() function
diff --git a/video_prediction_savp/scripts/train_moving_mnist.py b/video_prediction_savp/scripts/train_moving_mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cb9a6be0d23449dbdf30cc316815e2a33b29de1
--- /dev/null
+++ b/video_prediction_savp/scripts/train_moving_mnist.py
@@ -0,0 +1,357 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import random
+import time
+import numpy as np
+import tensorflow as tf
+from video_prediction import datasets, models
+import matplotlib.pyplot as plt
+from json import JSONEncoder
+import pickle as pkl
+
+
+class NumpyArrayEncoder(JSONEncoder):
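+    """JSON encoder that serializes numpy arrays as nested lists."""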
+    def default(self, obj):
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return JSONEncoder.default(self, obj)
+
+
+def add_tag_suffix(summary, tag_suffix):
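+    """Append tag_suffix to the first component of every tag in a serialized summary proto."""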
+    summary_proto = tf.Summary()
+    summary_proto.ParseFromString(summary)
+    summary = summary_proto
+
+    for value in summary.value:
+        tag_split = value.tag.split('/')
+        value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:])
+    return summary.SerializeToString()
+
+def generate_output_dir(output_dir, model,model_hparams,logs_dir,output_dir_postfix):
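+    """
+    Derive a default output directory name from the model name and its hparams string
+    (sanitizing separators and brackets) when no output_dir is given.
+    """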
+    if output_dir is None:
+        list_depth = 0
+        model_fname = ''
+        for t in ('model=%s,%s' % (model, model_hparams)):
+            if t == '[':
+                list_depth += 1
+            if t == ']':
+                list_depth -= 1
+            if list_depth and t == ',':
+                t = '..'
+            if t in '=,':
+                t = '.'
+            if t in '[]':
+                t = ''
+            model_fname += t
+        output_dir = os.path.join(logs_dir, model_fname) + output_dir_postfix
+    return output_dir
+
+
+def get_model_hparams_dict(model_hparams_dict):
+    """
+    Load the model hyperparameters from a json file
+    """
+    model_hparams_dict_load = {}
+    if model_hparams_dict:
+        with open(model_hparams_dict) as f:
+            model_hparams_dict_load.update(json.loads(f.read()))
+    return model_hparams_dict_load
+
+def resume_checkpoint(resume,checkpoint,output_dir):
+    """
+    Resume the existing model checkpoints and set checkpoint directory
+    """
+    if resume:
+        if checkpoint:
+            raise ValueError('resume and checkpoint cannot both be specified')
+        checkpoint = output_dir
+    return checkpoint
+
+def set_seed(seed):
+    if seed is not None:
+        tf.set_random_seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+
+def load_params_from_checkpoints_dir(model_hparams_dict,checkpoint,dataset,model):
+   
+    model_hparams_dict_load = {}
+    if model_hparams_dict:
+        with open(model_hparams_dict) as f:
+            model_hparams_dict_load.update(json.loads(f.read()))
+ 
+    if checkpoint:
+        checkpoint_dir = os.path.normpath(checkpoint)
+        if not os.path.isdir(checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            dataset = dataset or options['dataset']
+            model = model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict_load.update(json.loads(f.read()))
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+    return dataset, model, model_hparams_dict_load
+
+def setup_dataset(dataset,input_dir,val_input_dir):
+    VideoDataset = datasets.get_dataset_class(dataset)
+    train_dataset = VideoDataset(
+        input_dir,
+        mode='train')
+    val_dataset = VideoDataset(
+        val_input_dir or input_dir,
+        mode='val')
+    variable_scope = tf.get_variable_scope()
+    variable_scope.set_use_resource(True)
+    return train_dataset,val_dataset,variable_scope
+
+def setup_model(model,model_hparams_dict,train_dataset,model_hparams):
+    """
+    Set up model instance
+    """
+    VideoPredictionModel = models.get_model_class(model)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': train_dataset.hparams.context_frames,
+        'sequence_length': train_dataset.hparams.sequence_length,
+        'repeat': train_dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        hparams_dict=hparams_dict,
+        hparams=model_hparams)
+    return model
+
+def save_dataset_model_params_to_checkpoint_dir(args,output_dir,train_dataset,model):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
+    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+    return None
+
+def make_dataset_iterator(train_dataset, val_dataset, batch_size):
+    train_tf_dataset = train_dataset.make_dataset_v2(batch_size)
+    train_iterator = train_tf_dataset.make_one_shot_iterator()
+    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
+    # and used to feed the `handle` placeholder.
+    train_handle = train_iterator.string_handle()
+    val_tf_dataset = val_dataset.make_dataset_v2(batch_size)
+    val_iterator = val_tf_dataset.make_one_shot_iterator()
+    val_handle = val_iterator.string_handle()
+    #iterator = tf.data.Iterator.from_string_handle(
+    #    train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
+    inputs = train_iterator.get_next()
+    val = val_iterator.get_next()
+    return inputs, train_handle, val_handle
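+# Note: the canonical feedable-iterator pattern (a sketch, not what is wired up above)
+# would build one shared iterator from a string-handle placeholder and switch between
+# the train and val datasets at run time:
+#   handle = tf.placeholder(tf.string, shape=[])
+#   iterator = tf.data.Iterator.from_string_handle(
+#       handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
+#   inputs = iterator.get_next()
+#   sess.run(fetches, feed_dict={handle: sess.run(val_handle)})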
+
+
+def plot_train(train_losses,val_losses,output_dir):
+    iterations = list(range(len(train_losses))) 
+    plt.plot(iterations, train_losses, 'g', label='Training loss')
+    plt.plot(iterations, val_losses, 'b', label='Validation loss')
+    plt.title('Training and Validation loss')
+    plt.xlabel('Iterations')
+    plt.ylabel('Loss')
+    plt.legend()
+    plt.savefig(os.path.join(output_dir,'plot_train.png'))
+
+def save_results_to_dict(results_dict,output_dir):
+    with open(os.path.join(output_dir,"results.json"),"w") as fp:
+        json.dump(results_dict,fp)    
+
+def save_results_to_pkl(train_losses, val_losses, output_dir):
+    with open(os.path.join(output_dir, "train_losses.pkl"), "wb") as f:
+        pkl.dump(train_losses, f)
+    with open(os.path.join(output_dir, "val_losses.pkl"), "wb") as f:
+        pkl.dump(val_losses, f)
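+# The pickled losses can be reloaded later for offline analysis, e.g.:
+#   with open(os.path.join(output_dir, "train_losses.pkl"), "rb") as f:
+#       train_losses = pkl.load(f)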
+ 
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
+    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
+    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
+                                             "default is logs_dir/model_fname, where model_fname consists of "
+                                             "information from model and model_hparams")
+    parser.add_argument("--output_dir_postfix", default="")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+    parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')
+
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") 
+    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
+
+    parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="fraction of gpu memory to use")
+    parser.add_argument("--seed",default=1234, type=int)
+
+    args = parser.parse_args()
+     
+    #Set seed  
+    set_seed(args.seed)
+    
+    #setup output directory
+    args.output_dir = generate_output_dir(args.output_dir, args.model, args.model_hparams, args.logs_dir, args.output_dir_postfix)
+    
+    #resume the existing checkpoint and set up the checkpoint directory to output directory
+    args.checkpoint = resume_checkpoint(args.resume,args.checkpoint,args.output_dir)
+ 
+    #get the model hparams dict from the json file
+    #load dataset and model options from an existing checkpoint (they were stored in the checkpoint dir by the previous training run)
+    args.dataset, args.model, model_hparams_dict = load_params_from_checkpoints_dir(args.model_hparams_dict, args.checkpoint, args.dataset, args.model)
+     
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+    #set up the training and validation dataset instances
+    train_dataset,val_dataset,variable_scope = setup_dataset(args.dataset,args.input_dir,args.val_input_dir)
+    
+    #setup model instance 
+    model=setup_model(args.model,model_hparams_dict,train_dataset,args.model_hparams)
+
+    batch_size = model.hparams.batch_size
+    #Create input and val iterator
+    inputs, train_handle, val_handle = make_dataset_iterator(train_dataset, val_dataset, batch_size)
+    
+    #build model graph
+    #del inputs["T_start"]
+    model.build_graph(inputs)
+    
+    #save all the model and dataset params to the output directory
+    save_dataset_model_params_to_checkpoint_dir(args,args.output_dir,train_dataset,model)
+    
+    with tf.name_scope("parameter_count"):
+        # exclude trainable variables that are replicas (used in multi-gpu setting)
+        trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
+        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables])
+
+    saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)
+
+    summary_writer = tf.summary.FileWriter(args.output_dir)
+
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
+    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+ 
+    max_epochs = model.hparams.max_epochs #the number of epochs
+    num_examples_per_epoch = train_dataset.num_examples_per_epoch()
+    print ("number of exmaples per epoch:",num_examples_per_epoch)
+    steps_per_epoch = int(num_examples_per_epoch/batch_size)
+    total_steps = steps_per_epoch * max_epochs
+    global_step = tf.train.get_or_create_global_step()
+    #mock total_steps only for fast debugging
+    #total_steps = 10
+    print ("Total steps for training:",total_steps)
+    results_dict = {}
+    with tf.Session(config=config) as sess:
+        print("parameter_count =", sess.run(parameter_count))
+        sess.run(tf.global_variables_initializer())
+        sess.run(tf.local_variables_initializer())
+        model.restore(sess, args.checkpoint)
+        sess.graph.finalize()
+        #start_step = sess.run(model.global_step)
+        start_step = sess.run(global_step)
+        print("start_step", start_step)
+        # steps are counted relative to start_step so that training resumes correctly from a checkpoint
+        train_losses=[]
+        val_losses=[]
+        run_start_time = time.time()        
+        for step in range(start_step,total_steps):
+ 
+            print ("step:", step)
+            val_handle_eval = sess.run(val_handle)
+
+            #Fetch variables in the graph
+
+            fetches = {"train_op": model.train_op}
+            #fetches["latent_loss"] = model.latent_loss
+            fetches["summary"] = model.summary_op 
+            
+            if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel" or model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
+                fetches["global_step"] = model.global_step
+                fetches["total_loss"] = model.total_loss
+                #fetch the specific loss function only for mcnet
+                if model.__class__.__name__ == "McNetVideoPredictionModel":
+                    fetches["L_p"] = model.L_p
+                    fetches["L_gdl"] = model.L_gdl
+                    fetches["L_GAN"]  =model.L_GAN
+                if model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
+                    fetches["latent_loss"] = model.latent_loss
+                    fetches["recon_loss"] = model.recon_loss
+                results = sess.run(fetches)
+                train_losses.append(results["total_loss"])
+                #Fetch losses for validation data
+                val_fetches = {}
+                #val_fetches["latent_loss"] = model.latent_loss
+                val_fetches["total_loss"] = model.total_loss
+
+
+            if model.__class__.__name__ == "SAVPVideoPredictionModel":
+                fetches['d_loss'] = model.d_loss
+                fetches['g_loss'] = model.g_loss
+                fetches['d_losses'] = model.d_losses
+                fetches['g_losses'] = model.g_losses
+                results = sess.run(fetches)
+                train_losses.append(results["g_losses"])
+                val_fetches = {}
+                #val_fetches["latent_loss"] = model.latent_loss
+                #For SAVP the total loss is the generator losses
+                val_fetches["total_loss"] = model.g_losses
+
+            val_fetches["summary"] = model.summary_op
+            val_results = sess.run(val_fetches,feed_dict={train_handle: val_handle_eval})
+            val_losses.append(val_results["total_loss"])
+
+            summary_writer.add_summary(results["summary"])
+            summary_writer.add_summary(val_results["summary"])
+            summary_writer.flush()
+
+            # global_step will have the correct step count if we resume from a checkpoint
+            # global step is read before it's incremented
+            train_epoch = step/steps_per_epoch
+            print("progress  global step %d  epoch %0.1f" % (step + 1, train_epoch))
+            if model.__class__.__name__ == "McNetVideoPredictionModel":
+                print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"]))
+            elif model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
+                print ("Total_loss:{}".format(results["total_loss"]))
+            elif model.__class__.__name__ == "SAVPVideoPredictionModel":
+                print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"]))
+            elif model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
+                print("Total_loss:{}; latent_losses:{}; reconst_loss:{}".format(results["total_loss"],results["latent_loss"],results["recon_loss"]))
+            else:
+                print("Model class %s is not supported" % model.__class__.__name__)
+
+            #print("saving model to", args.output_dir)
+            saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)
+        train_time = time.time() - run_start_time
+        results_dict = {"train_time":train_time,
+                        "total_steps":total_steps}
+        save_results_to_dict(results_dict,args.output_dir)
+        save_results_to_pkl(train_losses, val_losses, args.output_dir)
+        print("train_losses:",train_losses)
+        print("val_losses:",val_losses) 
+        plot_train(train_losses,val_losses,args.output_dir)
+        print("Done")
+        
+if __name__ == '__main__':
+    main()
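+
+# Example invocation (all paths are placeholders):
+#   python train_moving_mnist.py --input_dir /path/to/moving_mnist/tfrecords \
+#       --dataset moving_mnist --model convLSTM \
+#       --model_hparams_dict model_hparams.json --output_dir /path/to/models/moving_mnist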
diff --git a/video_prediction_savp/video_prediction/datasets/__init__.py b/video_prediction_savp/video_prediction/datasets/__init__.py
index e449a65bd48b14ef5a11e6846ce4d8f39f7ed193..f58607c2f4c14047aefb36956e98bd228a30aeb1 100644
--- a/video_prediction_savp/video_prediction/datasets/__init__.py
+++ b/video_prediction_savp/video_prediction/datasets/__init__.py
@@ -7,6 +7,7 @@ from .kth_dataset import KTHVideoDataset
 from .ucf101_dataset import UCF101VideoDataset
 from .cartgripper_dataset import CartgripperVideoDataset
 from .era5_dataset_v2 import ERA5Dataset_v2
+from .moving_mnist import MovingMnist
 #from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly
 
 def get_dataset_class(dataset):
@@ -19,6 +20,7 @@ def get_dataset_class(dataset):
         'ucf101': 'UCF101VideoDataset',
         'cartgripper': 'CartgripperVideoDataset',
         "era5":"ERA5Dataset_v2",
+        "moving_mnist":"MovingMnist"
 #        "era5_anomaly":"ERA5Dataset_v2_anomaly",
     }
     dataset_class = dataset_mappings.get(dataset, dataset)
diff --git a/video_prediction_savp/video_prediction/datasets/kth_dataset.py b/video_prediction_savp/video_prediction/datasets/kth_dataset.py
index 40fb6bf57b8219fc7e75c3759df9e6b38fffeb30..e1e11d51968e706868fd89f26faa25d1999d3a9b 100644
--- a/video_prediction_savp/video_prediction/datasets/kth_dataset.py
+++ b/video_prediction_savp/video_prediction/datasets/kth_dataset.py
@@ -136,11 +136,11 @@ def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequence
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--input_dir", type=str, help="directory containing the processed directories "
+    parser.add_argument("input_dir", type=str, help="directory containing the processed directories "
                                                     "boxing, handclapping, handwaving, "
                                                     "jogging, running, walking")
-    parser.add_argument("--output_dir", type=str)
-    parser.add_argument("--image_size", type=int)
+    parser.add_argument("output_dir", type=str)
+    parser.add_argument("image_size", type=int)
     args = parser.parse_args()
 
     partition_names = ['train', 'val', 'test']
diff --git a/video_prediction_savp/video_prediction/datasets/moving_mnist.py b/video_prediction_savp/video_prediction/datasets/moving_mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..12a517ac9540dd443866d47ea689803abe9e9a4d
--- /dev/null
+++ b/video_prediction_savp/video_prediction/datasets/moving_mnist.py
@@ -0,0 +1,241 @@
+import argparse
+import glob
+import itertools
+import os
+import pickle
+import random
+import re
+import numpy as np
+import json
+import tensorflow as tf
+from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset
+# ML 2020/04/14: hack for getting functions of process_netCDF_v2:
+from os import path
+import sys
+sys.path.append(path.abspath('../../workflow_parallel_frame_prediction/'))
+import DataPreprocess.process_netCDF_v2 
+from DataPreprocess.process_netCDF_v2 import get_unique_vars
+from DataPreprocess.process_netCDF_v2 import Calc_data_stat
+from metadata import MetaData
+#from base_dataset import VarLenFeatureVideoDataset
+from collections import OrderedDict
+from tensorflow.contrib.training import HParams
+from mpi4py import MPI
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+
+class MovingMnist(VarLenFeatureVideoDataset):
+    def __init__(self, *args, **kwargs):
+        super(MovingMnist, self).__init__(*args, **kwargs)
+        from google.protobuf.json_format import MessageToDict
+        example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
+        dict_message = MessageToDict(tf.train.Example.FromString(example))
+        feature = dict_message['features']['feature']
+        print("features in dataset:",feature.keys())
+        self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
+        self.image_shape = self.video_shape[1:]
+        self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape
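+        # e.g. for the moving MNIST tfrecords written below this yields video_shape=(20, 64, 64, 1)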
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(MovingMnist, self).get_default_hparams_dict()
+        hparams = dict(
+            context_frames=10,   #Bing: TODO original is 10
+            sequence_length=20,  #Bing: TODO original is 20
+            shuffle_on_val=True, 
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+
+    @property
+    def jpeg_encoding(self):
+        return False
+
+
+
+    def num_examples_per_epoch(self):
+        with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
+            sequence_lengths = sequence_lengths_file.readlines()
+        sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
+        return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
+
+    def filter(self, serialized_example):
+        return tf.convert_to_tensor(True)
+
+
+    def make_dataset_v2(self, batch_size):
+        def parser(serialized_example):
+            seqs = OrderedDict()
+            keys_to_features = {
+                'width': tf.FixedLenFeature([], tf.int64),
+                'height': tf.FixedLenFeature([], tf.int64),
+                'sequence_length': tf.FixedLenFeature([], tf.int64),
+                'channels': tf.FixedLenFeature([], tf.int64),
+                'images/encoded': tf.VarLenFeature(tf.float32)
+            }
+            
+            # for i in range(20):
+            #     keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string)
+            parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
+            print ("Parse features", parsed_features)
+            seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
+            #width = tf.sparse_tensor_to_dense(parsed_features["width"])
+            #height = tf.sparse_tensor_to_dense(parsed_features["height"])
+            #channels = tf.sparse_tensor_to_dense(parsed_features["channels"])
+            #sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"])
+            print("Image shape {}, {}, {}, {}".format(self.video_shape[0], self.image_shape[0], self.image_shape[1], self.image_shape[2]))
+            images = tf.reshape(seq, [self.video_shape[0], self.image_shape[0], self.image_shape[1], self.image_shape[2]], name="reshape_new")
+            seqs["images"] = images
+            return seqs
+        filenames = self.filenames
+        print ("FILENAMES",filenames)
+        #TODO: temporal_filenames = self.temporal_filenames
+        shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
+        if shuffle:
+            random.shuffle(filenames)
+        dataset = tf.data.TFRecordDataset(filenames, buffer_size=8 * 1024 * 1024)  # buffer_size is the TFRecord read buffer in bytes (8 MiB here)
+        print("files", self.filenames)
+        print("mode", self.mode)
+        dataset = dataset.filter(self.filter)
+        if shuffle:
+            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1024, count=self.num_epochs))
+        else:
+            dataset = dataset.repeat(self.num_epochs)
+
+        num_parallel_calls = None if shuffle else 1  # single-threaded in val/test mode for reproducibility (e.g. sampled subclips from the test set)
+        dataset = dataset.apply(tf.contrib.data.map_and_batch(
+            parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))
+        dataset = dataset.prefetch(batch_size)  # Bing: prefetch into a buffer to reduce the waiting time for the GPU
+        return dataset
+
+
+
+    def make_batch(self, batch_size):
+        dataset = self.make_dataset_v2(batch_size)
+        iterator = dataset.make_one_shot_iterator()
+        return iterator.get_next()
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _bytes_list_feature(values):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
+
+def _floats_feature(value):
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+def save_tf_record(output_fname, sequences):
+    with tf.python_io.TFRecordWriter(output_fname) as writer:
+        for i in range(sequences.shape[1]):  # iterate over the sequences in this chunk (axis 1), not over the time axis
+            sequence = sequences[:,i,:,:,:] 
+            num_frames = len(sequence)
+            height, width = sequence[0,:,:,0].shape
+            encoded_sequence = np.array([list(image) for image in sequence])
+            features = tf.train.Features(feature={
+                'sequence_length': _int64_feature(num_frames),
+                'height': _int64_feature(height),
+                'width': _int64_feature(width),
+                'channels': _int64_feature(1),
+                'images/encoded': _floats_feature(encoded_sequence.flatten()),
+            })
+            example = tf.train.Example(features=features)
+            writer.write(example.SerializeToString())
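+# Note: each sequence is stored as one flat float list under 'images/encoded',
+# which matches the tf.VarLenFeature(tf.float32) entry in MovingMnist.make_dataset_v2.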
+
+def read_frames_and_save_tf_records(output_dir, dat_npz, seq_length=20, sequences_per_file=128, height=64, width=64):  #Bing: original 128
+    """
+    Read the moving_mnist data, which comes in npz format, and save it to tfrecord files.
+    The shape of dat_npz is [seq_length, number_samples, height, width];
+    moving_mnist has only one channel.
+    """
+    os.makedirs(output_dir,exist_ok=True)
+    idx = 0
+    num_samples = dat_npz.shape[1]
+    dat_npz = np.expand_dims(dat_npz, axis=4)  #add a channel dim: [seq_length, num_samples, height, width, channel]
+    print("data_npz_shape", dat_npz.shape)
+    dat_npz = dat_npz.astype(np.float32)
+    dat_npz /= 255.0  #normalize pixel values to [0, 1] by dividing by the maximum value (255)
+    while idx + sequences_per_file <= num_samples:  #was `idx < num_samples - sequences_per_file`, which dropped the last chunk
+        sequences = dat_npz[:,idx:idx+sequences_per_file,:,:,:]
+        output_fname = 'sequence_{}_{}.tfrecords'.format(idx,idx+sequences_per_file)
+        output_fname = os.path.join(output_dir, output_fname)
+        save_tf_record(output_fname, sequences)
+        idx = idx + sequences_per_file
+    return None
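+# Example (illustrative): mnist_test_seq.npy has shape (20, 10000, 64, 64), so calling
+# read_frames_and_save_tf_records(os.path.join(output_dir, "train"), dat_train,
+# sequences_per_file=40) on the 6000-sample training split writes 150 tfrecord files
+# of 40 sequences each.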
+
+
+def write_sequence_file(output_dir, seq_length, sequences_per_file):
+    partition_names = ["train", "val", "test"]
+    for partition_name in partition_names:
+        save_output_dir = os.path.join(output_dir, partition_name)
+        tfCounter = len(glob.glob1(save_output_dir, "*.tfrecords"))
+        print("Partition_name: {}, number of tfrecords: {}".format(partition_name, tfCounter))
+        with open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w') as sequence_lengths_file:
+            for i in range(tfCounter * sequences_per_file):
+                sequence_lengths_file.write("%d\n" % seq_length)
+
+
+def plot_seq_imgs(imgs, output_png_dir, idx, label="Ground Truth"):
+    """
+    Plot a sequence of images
+    """
+    if len(np.array(imgs).shape) != 3:
+        raise ValueError("img dims should be three: (seq_len, lat, lon)")
+    img_len = imgs.shape[0]
+    fig = plt.figure(figsize=(18, 6))
+    gs = gridspec.GridSpec(1, img_len)  # one column per frame (was hard-coded to 10)
+    gs.update(wspace=0., hspace=0.)
+    for i in range(img_len):
+        ax1 = plt.subplot(gs[i])
+        plt.imshow(imgs[i], cmap='jet')
+        plt.setp([ax1], xticks=[], xticklabels=[], yticks=[], yticklabels=[])
+    plt.savefig(os.path.join(output_png_dir, label + "_" + str(idx) + ".jpg"))
+    print("images_saved")
+    plt.clf()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_dir", type=str, help="directory containing the moving MNIST data file mnist_test_seq.npy")
+    parser.add_argument("output_dir", type=str)
+    parser.add_argument("-sequences_per_file", type=int, default=2)
+    args = parser.parse_args()
+    data = np.load(os.path.join(args.input_dir, "mnist_test_seq.npy"))
+    print("data in mnist_test_seq shape", data.shape)
+    seq_length = data.shape[0]
+    height = data.shape[2]
+    width = data.shape[3]
+    num_samples = data.shape[1]
+    max_npz = np.max(data)
+    min_npz = np.min(data)
+    print("max_npz,",max_npz)
+    print("min_npz",min_npz)
+    #TODO: discuss how to split the data; we have 10000 samples in total, while the original ConvLSTM paper used 10000 for training, 2000 for validation and 3000 for testing
+    dat_train = data[:, :6000, :, :]
+    dat_val = data[:, 6000:7000, :, :]
+    dat_test = data[:, 7000:, :, :]
+    #plot_seq_imgs(dat_test[10:,0,:,:],output_png_dir="/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist/convLSTM",idx=1,label="Ground Truth from npz")
+    #save train
+    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"train"),dat_train, seq_length=20, sequences_per_file=40, height=height, width=width)
+    #save val
+    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"val"),dat_val, seq_length=20, sequences_per_file=40, height=height, width=width)
+    #save test     
+    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"test"),dat_test, seq_length=20, sequences_per_file=40, height=height, width=width)
+    #write_sequence_file(output_dir=args.output_dir,seq_length=20,sequences_per_file=40)
+if __name__ == '__main__':
+    main()
+
diff --git a/video_prediction_savp/video_prediction/layers/layer_def.py b/video_prediction_savp/video_prediction/layers/layer_def.py
index a59643c7a6d69141134ec01c9b147c4798bfed8e..738e139df0e155ca294fe43edd03a2d79fc1f532 100644
--- a/video_prediction_savp/video_prediction/layers/layer_def.py
+++ b/video_prediction_savp/video_prediction/layers/layer_def.py
@@ -3,8 +3,8 @@
 
 import tensorflow as tf
 import numpy as np
-
 weight_decay = 0.0005
+
 def _activation_summary(x):
     """Helper to create summaries for activations.
     Creates a summary that provides a histogram of activations.
@@ -157,4 +157,4 @@ def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None):
 
 def bn_layers_wrapper(inputs, is_training):
     pass
-   
\ No newline at end of file
+   
diff --git a/video_prediction_savp/video_prediction/models/__init__.py b/video_prediction_savp/video_prediction/models/__init__.py
index 6d7323f3750949b0ddb411d4a98934928537bc53..b71769a9d8cc523e4f108fc222fc2ed0284019f7 100644
--- a/video_prediction_savp/video_prediction/models/__init__.py
+++ b/video_prediction_savp/video_prediction/models/__init__.py
@@ -21,7 +21,7 @@ def get_model_class(model):
         'vae': 'VanillaVAEVideoPredictionModel',
         'convLSTM': 'VanillaConvLstmVideoPredictionModel',
         'mcnet': 'McNetVideoPredictionModel',
-        
+        'convLSTM_Loliver': 'ConvLstmLoliverVideoPredictionModel',
         }
     model_class = model_mappings.get(model, model)
     model_class = globals().get(model_class)
diff --git a/video_prediction_savp/video_prediction/models/base_model.py b/video_prediction_savp/video_prediction/models/base_model.py
index 846621d8ca1e235c39618951be86fe184a2d974d..df479968325946a9d61896d498428d65692c1848 100644
--- a/video_prediction_savp/video_prediction/models/base_model.py
+++ b/video_prediction_savp/video_prediction/models/base_model.py
@@ -487,6 +487,7 @@ class VideoPredictionModel(BaseVideoPredictionModel):
         # skip_vars = {" discriminator_encoder/video_sn_fc4/dense/bias"}
 
         if self.num_gpus <= 1:  # cpu or 1 gpu
+            print("self.inputs:>20200822",self.inputs)
             outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(self.inputs)
             self.outputs, self.eval_outputs = outputs_tuple
             self.d_losses, self.g_losses, g_losses_post = losses_tuple
diff --git a/video_prediction_savp/video_prediction/models/savp_model.py b/video_prediction_savp/video_prediction/models/savp_model.py
index c510d050c89908d0e06fe0f1a66e355e61c90530..039533864f34d1608c5a10d4a664d40ce73594a7 100644
--- a/video_prediction_savp/video_prediction/models/savp_model.py
+++ b/video_prediction_savp/video_prediction/models/savp_model.py
@@ -689,6 +689,10 @@ class SAVPCell(tf.nn.rnn_cell.RNNCell):
 def generator_given_z_fn(inputs, mode, hparams):
     # all the inputs needs to have the same length for unrolling the rnn
     print("inputs.items",inputs.items())
+    #2020-08-22 (Bing): restrict the inputs to the images only
+    inputs = {"images": inputs["images"]}
+    print("inputs 20200822:", inputs)
     inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
               for name, input in inputs.items()}
     cell = SAVPCell(inputs, mode, hparams)
diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index 744284fc6c5b52bcde249f1a58e04a41e80339fa..01a4f7ce5d6430f19a1e4b99c4cba956b3f7682b 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -94,6 +94,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
     def convLSTM_cell(inputs, hidden):
 
         y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset
+        channels = inputs.get_shape()[-1]
         # conv lstm cell
         cell_shape = y_0.get_shape().as_list()
         channels = cell_shape[-1]