diff --git a/video_prediction_savp/HPC_scripts/DataExtraction_template.sh b/video_prediction_savp/HPC_scripts/data_extraction_era5_template.sh
similarity index 100%
rename from video_prediction_savp/HPC_scripts/DataExtraction_template.sh
rename to video_prediction_savp/HPC_scripts/data_extraction_era5_template.sh
diff --git a/video_prediction_savp/HPC_scripts/DataPreprocess_template.sh b/video_prediction_savp/HPC_scripts/preprocess_data_era5_step1_template.sh
similarity index 100%
rename from video_prediction_savp/HPC_scripts/DataPreprocess_template.sh
rename to video_prediction_savp/HPC_scripts/preprocess_data_era5_step1_template.sh
diff --git a/video_prediction_savp/HPC_scripts/preprocess_data_era5_step2_template.sh b/video_prediction_savp/HPC_scripts/preprocess_data_era5_step2_template.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e953b5bc3fd2a836a74b647c1066735d19e39640
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/preprocess_data_era5_step2_template.sh
@@ -0,0 +1,40 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=13
+##SBATCH --ntasks-per-node=13
+#SBATCH --cpus-per-task=1
+#SBATCH --output=DataPreprocess_to_tf-out.%j
+#SBATCH --error=DataPreprocess_to_tf-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --partition=devel
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+
+######### Template identifier (don't remove) #########
+echo "Do not run the template scripts"
+exit 99
+######### Template identifier (don't remove) #########
+
+# Name of virtual environment 
+VIRT_ENV_NAME="vp"
+
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z ${VIRTUAL_ENV} ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/
+destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/
+
+# run Preprocessing (step 2, where the TFRecord files are generated)
+srun python ../video_prediction/datasets/era5_dataset_v2.py ${source_dir}/pickle ${destination_dir}/tfrecords -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20 
diff --git a/video_prediction_savp/HPC_scripts/preprocess_data_moving_mnist_template.sh b/video_prediction_savp/HPC_scripts/preprocess_data_moving_mnist_template.sh
new file mode 100644
index 0000000000000000000000000000000000000000..dc1fbb4a83788a4cc1f69fdf151d8419129dc06d
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/preprocess_data_moving_mnist_template.sh
@@ -0,0 +1,41 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+##SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --output=DataPreprocess_to_tf-out.%j
+#SBATCH --error=DataPreprocess_to_tf-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --partition=devel
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+
+######### Template identifier (don't remove) #########
+echo "Do not run the template scripts"
+exit 99
+######### Template identifier (don't remove) #########
+
+# Name of virtual environment 
+VIRT_ENV_NAME="vp"
+
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z ${VIRTUAL_ENV} ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
+
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist 
+destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
+
+# run Preprocessing (generate TFRecord files from the moving_mnist data)
+srun python ../video_prediction/datasets/moving_mnist.py ${source_dir} ${destination_dir}/tfrecords
diff --git a/video_prediction_savp/HPC_scripts/train_era5_template.sh b/video_prediction_savp/HPC_scripts/train_model_era5_template.sh
similarity index 100%
rename from video_prediction_savp/HPC_scripts/train_era5_template.sh
rename to video_prediction_savp/HPC_scripts/train_model_era5_template.sh
diff --git a/video_prediction_savp/HPC_scripts/train_movingmnist_template.sh b/video_prediction_savp/HPC_scripts/train_model_moving_mnist_template.sh
similarity index 100%
rename from video_prediction_savp/HPC_scripts/train_movingmnist_template.sh
rename to video_prediction_savp/HPC_scripts/train_model_moving_mnist_template.sh
diff --git a/video_prediction_savp/HPC_scripts/generate_era5_template.sh b/video_prediction_savp/HPC_scripts/visualize_postprocess_era5_template.sh
similarity index 100%
rename from video_prediction_savp/HPC_scripts/generate_era5_template.sh
rename to video_prediction_savp/HPC_scripts/visualize_postprocess_era5_template.sh
diff --git a/video_prediction_savp/HPC_scripts/generate_movingmnist_template.sh b/video_prediction_savp/HPC_scripts/visualize_postprocess_moving_mnist_template.sh
similarity index 100%
rename from video_prediction_savp/HPC_scripts/generate_movingmnist_template.sh
rename to video_prediction_savp/HPC_scripts/visualize_postprocess_moving_mnist_template.sh
diff --git a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2_anomaly.py b/video_prediction_savp/deprecate/datasets/era5_dataset_v2_anomaly.py
similarity index 100%
rename from video_prediction_savp/video_prediction/datasets/era5_dataset_v2_anomaly.py
rename to video_prediction_savp/deprecate/datasets/era5_dataset_v2_anomaly.py
diff --git a/video_prediction_savp/scripts/generate_movingmnist.py b/video_prediction_savp/scripts/generate_movingmnist.py
deleted file mode 100644
index d4fbf5eb5d8d8f4cad87ae26d15bc2787d9e6c0a..0000000000000000000000000000000000000000
--- a/video_prediction_savp/scripts/generate_movingmnist.py
+++ /dev/null
@@ -1,822 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import errno
-import json
-import os
-import math
-import random
-import cv2
-import numpy as np
-import tensorflow as tf
-import pickle
-from random import seed
-import random
-import json
-import numpy as np
-import matplotlib
-matplotlib.use('Agg')
-import matplotlib.pyplot as plt
-import matplotlib.gridspec as gridspec
-import matplotlib.animation as animation
-import pandas as pd
-import re
-from video_prediction import datasets, models
-from matplotlib.colors import LinearSegmentedColormap
-#from matplotlib.ticker import MaxNLocator
-#from video_prediction.utils.ffmpeg_gif import save_gif
-from skimage.metrics import structural_similarity as ssim
-import datetime
-# Scarlet 2020/05/28: access to statistical values in json file 
-from os import path
-import sys
-sys.path.append(path.abspath('../video_prediction/datasets/'))
-from era5_dataset_v2 import Norm_data
-from os.path import dirname
-from netCDF4 import Dataset,date2num
-from metadata import MetaData as MetaData
-
-def set_seed(seed):
-    if seed is not None:
-        tf.set_random_seed(seed)
-        np.random.seed(seed)
-        random.seed(seed) 
-
-def get_coordinates(metadata_fname):
-    """
-    Retrieves the latitudes and longitudes read from the metadata json file.
-    """
-    md = MetaData(json_file=metadata_fname)
-    md.get_metadata_from_file(metadata_fname)
-    
-    try:
-        print("lat:",md.lat)
-        print("lon:",md.lon)
-        return md.lat, md.lon
-    except:
-        raise ValueError("Error when handling: '"+metadata_fname+"'")
-    
-
-def load_checkpoints_and_create_output_dirs(checkpoint,dataset,model):
-    if checkpoint:
-        checkpoint_dir = os.path.normpath(checkpoint)
-        if not os.path.isdir(checkpoint):
-            checkpoint_dir, _ = os.path.split(checkpoint_dir)
-        if not os.path.exists(checkpoint_dir):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
-        with open(os.path.join(checkpoint_dir, "options.json")) as f:
-            print("loading options from checkpoint %s" % checkpoint)
-            options = json.loads(f.read())
-            dataset = dataset or options['dataset']
-            model = model or options['model']
-        try:
-            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
-                dataset_hparams_dict = json.loads(f.read())
-        except FileNotFoundError:
-            print("dataset_hparams.json was not loaded because it does not exist")
-        try:
-            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
-                model_hparams_dict = json.loads(f.read())
-        except FileNotFoundError:
-            print("model_hparams.json was not loaded because it does not exist")
-    else:
-        if not dataset:
-            raise ValueError('dataset is required when checkpoint is not specified')
-        if not model:
-            raise ValueError('model is required when checkpoint is not specified')
-
-    return options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict
-
-
-    
-def setup_dataset(dataset,input_dir,mode,seed,num_epochs,dataset_hparams,dataset_hparams_dict):
-    VideoDataset = datasets.get_dataset_class(dataset)
-    dataset = VideoDataset(
-        input_dir,
-        mode = mode,
-        num_epochs = num_epochs,
-        seed = seed,
-        hparams_dict = dataset_hparams_dict,
-        hparams = dataset_hparams)
-    return dataset
-
-
-def setup_dirs(input_dir,results_png_dir):
-    input_dir = args.input_dir
-    temporal_dir = os.path.split(input_dir)[0] + "/hickle/splits/"
-    print ("temporal_dir:",temporal_dir)
-
-
-def update_hparams_dict(model_hparams_dict,dataset):
-    hparams_dict = dict(model_hparams_dict)
-    hparams_dict.update({
-        'context_frames': dataset.hparams.context_frames,
-        'sequence_length': dataset.hparams.sequence_length,
-        'repeat': dataset.hparams.time_shift,
-    })
-    return hparams_dict
-
-
-def psnr(img1, img2):
-    mse = np.mean((img1 - img2) ** 2)
-    if mse == 0: return 100
-    PIXEL_MAX = 1
-    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
-
-
-def setup_num_samples_per_epoch(num_samples, dataset):
-    if num_samples:
-        if num_samples > dataset.num_examples_per_epoch():
-            raise ValueError('num_samples cannot be larger than the dataset')
-        num_examples_per_epoch = num_samples
-    else:
-        num_examples_per_epoch = dataset.num_examples_per_epoch()
-    #if num_examples_per_epoch % args.batch_size != 0:
-    #    raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
-    return num_examples_per_epoch
-
-
-def initia_save_data():
-    sample_ind = 0
-    gen_images_all = []
-    #Bing:20200410
-    persistent_images_all = []
-    input_images_all = []
-    return sample_ind, gen_images_all,persistent_images_all, input_images_all
-
-
-def write_params_to_results_dir(args,output_dir,dataset,model):
-    if not os.path.exists(output_dir):
-        os.makedirs(output_dir)
-    with open(os.path.join(output_dir, "options.json"), "w") as f:
-        f.write(json.dumps(vars(args), sort_keys = True, indent = 4))
-    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
-        f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4))
-    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
-        f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4))
-    return None
-
-def get_one_seq_and_time(input_images,i):
-    assert (len(np.array(input_images).shape)==5)
-    input_images_ = input_images[i,:,:,:,:]
-    return input_images_
-
-
-def denorm_images_all_channels(input_images_):
-    input_images_all_channles_denorm = []
-    input_images_ = np.array(input_images_)
-    input_images_denorm = input_images_ * 255.0
-    #print("input_images_denorm shape",input_images_denorm.shape)
-    return input_images_denorm
-
-def plot_seq_imgs(imgs,output_png_dir,idx,label="Ground Truth"):
-    """
-    Plot the seq images 
-    """
-
-    if len(np.array(imgs).shape)!=3:raise("img dims should be three: (seq_len,lat,lon)")
-    img_len = imgs.shape[0]
-    fig = plt.figure(figsize=(18,6))
-    gs = gridspec.GridSpec(1, 10)
-    gs.update(wspace = 0., hspace = 0.)
-    for i in range(img_len):      
-        ax1 = plt.subplot(gs[i])
-        plt.imshow(imgs[i] ,cmap = 'jet')
-        plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
-    plt.savefig(os.path.join(output_png_dir, label + "_" +   str(idx) +  ".jpg"))
-    print("images_saved")
-    plt.clf()
- 
-
-    
-def get_persistence(ts):
-    pass
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input_dir", type = str, required = True,
-                        help = "either a directory containing subdirectories "
-                               "train, val, test, etc, or a directory containing "
-                               "the tfrecords")
-    parser.add_argument("--results_dir", type = str, default = 'results',
-                        help = "ignored if output_gif_dir is specified")
-    parser.add_argument("--checkpoint",
-                        help = "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
-    parser.add_argument("--mode", type = str, choices = ['train','val', 'test'], default = 'val',
-                        help = 'mode for dataset, val or test.')
-    parser.add_argument("--dataset", type = str, help = "dataset class name")
-    parser.add_argument("--dataset_hparams", type = str,
-                        help = "a string of comma separated list of dataset hyperparameters")
-    parser.add_argument("--model", type = str, help = "model class name")
-    parser.add_argument("--model_hparams", type = str,
-                        help = "a string of comma separated list of model hyperparameters")
-    parser.add_argument("--batch_size", type = int, default = 8, help = "number of samples in batch")
-    parser.add_argument("--num_samples", type = int, help = "number of samples in total (all of them by default)")
-    parser.add_argument("--num_epochs", type = int, default = 1)
-    parser.add_argument("--num_stochastic_samples", type = int, default = 1)
-    parser.add_argument("--gif_length", type = int, help = "default is sequence_length")
-    parser.add_argument("--fps", type = int, default = 4)
-    parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use")
-    parser.add_argument("--seed", type = int, default = 7)
-    args = parser.parse_args()
-    set_seed(args.seed)
-
-    dataset_hparams_dict = {}
-    model_hparams_dict = {}
-
-    options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict = load_checkpoints_and_create_output_dirs(args.checkpoint,args.dataset,args.model)
-    print("Step 1 finished")
-
-    print('----------------------------------- Options ------------------------------------')
-    for k, v in args._get_kwargs():
-        print(k, "=", v)
-    print('------------------------------------- End --------------------------------------')
-
-    #setup dataset and model object
-    input_dir_tf = os.path.join(args.input_dir, "tfrecords") # where tensorflow records are stored
-    dataset = setup_dataset(dataset,input_dir_tf,args.mode,args.seed,args.num_epochs,args.dataset_hparams,dataset_hparams_dict)
-    
-    print("Step 2 finished")
-    VideoPredictionModel = models.get_model_class(model)
-    
-    hparams_dict = dict(model_hparams_dict)
-    hparams_dict.update({
-        'context_frames': dataset.hparams.context_frames,
-        'sequence_length': dataset.hparams.sequence_length,
-        'repeat': dataset.hparams.time_shift,
-    })
-    
-    model = VideoPredictionModel(
-        mode = args.mode,
-        hparams_dict = hparams_dict,
-        hparams = args.model_hparams)
-
-    sequence_length = model.hparams.sequence_length
-    context_frames = model.hparams.context_frames
-    future_length = sequence_length - context_frames #context_Frames is the number of input frames
-
-    num_examples_per_epoch = setup_num_samples_per_epoch(args.num_samples,dataset)
-    
-    inputs = dataset.make_batch(args.batch_size)
-    print("inputs",inputs)
-    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
-    print("input_phs",input_phs)
-    
-    
-    # Build graph
-    with tf.variable_scope(''):
-        model.build_graph(input_phs)
-
-    #Write the update hparameters into results_dir    
-    write_params_to_results_dir(args=args,output_dir=args.results_dir,dataset=dataset,model=model)
-        
-    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_mem_frac)
-    config = tf.ConfigProto(gpu_options = gpu_options, allow_soft_placement = True)
-    sess = tf.Session(config = config)
-    sess.graph.as_default()
-    sess.run(tf.global_variables_initializer())
-    sess.run(tf.local_variables_initializer())
-    model.restore(sess, args.checkpoint)
-    
-    #model.restore(sess, args.checkpoint)#Bing: Todo: 20200728 Let's only focus on true and persistend data
-    sample_ind, gen_images_all, persistent_images_all, input_images_all = initia_save_data()
-    
-    is_first=True
-    #loop for in samples
-    while sample_ind < 5:
-        gen_images_stochastic = []
-        if args.num_samples and sample_ind >= args.num_samples:
-            break
-        try:
-            input_results = sess.run(inputs)
-            input_images = input_results["images"]
-            #get the intial times
-            t_starts = input_results["T_start"]
-        except tf.errors.OutOfRangeError:
-            break
-            
-        #Get prediction values 
-        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
-        gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel]
-        print("gen_images 20200822:",np.array(gen_images).shape)       
-        #Loop in batch size
-        for i in range(args.batch_size):
-            
-            #get one seq and the corresponding start time point
-            input_images_ = get_one_seq_and_time(input_images,i)
-            
-            #Renormalized data for inputs
-            input_images_denorm = denorm_images_all_channels(input_images_)  
-            print("input_images_denorm",input_images_denorm[0][0])
-                                                             
-            #Renormalized data for inputs
-            gen_images_ = gen_images[i]
-            gen_images_denorm = denorm_images_all_channels(gen_images_)
-            print("gene_images_denorm:",gen_images_denorm[0][0])
-            
-            #Generate images inputs
-            plot_seq_imgs(imgs=input_images_denorm[context_frames+1:,:,:,0],idx = sample_ind + i, label="Ground Truth",output_png_dir=args.results_dir)  
-                                                             
-            #Generate forecast images
-            plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],idx = sample_ind + i,label="Forecast by Model " + args.model,output_png_dir=args.results_dir) 
-            
-            #TODO: Scaret plot persistence image
-            #implment get_persistence() function
-
-            #in case of generate the images for all the input, we just generate the first 5 sampe_ind examples for visuliation
-
-        sample_ind += args.batch_size
-
-
-        #for input_image in input_images_:
-
-#             for stochastic_sample_ind in range(args.num_stochastic_samples):
-#                 input_images_all.extend(input_images)
-#                 with open(os.path.join(args.output_png_dir, "input_images_all.pkl"), "wb") as input_files:
-#                     pickle.dump(list(input_images_all), input_files)
-
-
-#                 gen_images_stochastic.append(gen_images)
-#                 #print("Stochastic_sample,", stochastic_sample_ind)
-#                 for i in range(args.batch_size):
-#                     #bing:20200417
-#                     t_stampe = test_temporal_pkl[sample_ind+i]
-#                     print("timestamp:",type(t_stampe))
-#                     persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1)
-#                     print ("persistent ts",persistent_ts)
-#                     persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts))
-#                     persistent_X = X_test[persistent_idx:persistent_idx+context_frames + future_length]
-#                     print("persistent index in test set:", persistent_idx)
-#                     print("persistent_X.shape",persistent_X.shape)
-#                     persistent_images_all.append(persistent_X)
-
-#                     cmap_name = 'my_list'
-#                     if sample_ind < 100:
-#                         #name = '_Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
-#                         #    sample_ind) + " + Sample_" + str(i)
-#                         name = '_Stochastic_id_' + str(stochastic_sample_ind) + "_Time_"+ t_stampe[0].strftime("%Y%m%d-%H%M%S")
-#                         print ("name",name)
-#                         gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
-#                         #gen_images_ =  gen_images[i, :]
-#                         input_images_ = input_images[i, :]
-#                         #Bing:20200417
-#                         #persistent_images = ?
-#                         #+++Scarlet:20200528   
-#                         #print('Scarlet1')
-#                         input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm)
-#                         persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm)
-#                         #---Scarlet:20200528    
-#                         gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
-#                                         range(sequence_length)]  # return the list with 10 (sequence) mse
-#                         persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in
-#                                         range(sequence_length)]  # return the list with 10 (sequence) mse
-
-#                         fig = plt.figure(figsize=(18,6))
-#                         gs = gridspec.GridSpec(1, 10)
-#                         gs.update(wspace = 0., hspace = 0.)
-#                         ts = list(range(10,20)) #[10,11,12,..]
-#                         xlables = [round(i,2) for i  in list(np.linspace(np.min(lon),np.max(lon),5))]
-#                         ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
-
-#                         for t in ts:
-
-#                             #if t==0 : ax1=plt.subplot(gs[t])
-#                             ax1 = plt.subplot(gs[ts.index(t)])
-#                             #+++Scarlet:20200528
-#                             #print('Scarlet2')
-#                             input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm)
-#                             #---Scarlet:20200528
-#                             plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
-#                             ax1.title.set_text("t = " + str(t+1-10))
-#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
-#                             if t == 0:
-#                                 plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
-#                                 plt.ylabel("Ground Truth", fontsize=10)
-#                         plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
-#                         plt.clf()
-
-#                         fig = plt.figure(figsize=(12,6))
-#                         gs = gridspec.GridSpec(1, 10)
-#                         gs.update(wspace = 0., hspace = 0.)
-
-#                         for t in ts:
-#                             #if t==0 : ax1=plt.subplot(gs[t])
-#                             ax1 = plt.subplot(gs[ts.index(t)])
-#                             #+++Scarlet:20200528
-#                             #print('Scarlet3')
-#                             gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm)
-#                             #---Scarlet:20200528
-#                             plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
-#                             ax1.title.set_text("t = " + str(t+1-10))
-#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
-
-#                         plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
-#                         plt.clf()
-
-
-#                         fig = plt.figure(figsize=(12,6))
-#                         gs = gridspec.GridSpec(1, 10)
-#                         gs.update(wspace = 0., hspace = 0.)
-#                         for t in ts:
-#                             #if t==0 : ax1=plt.subplot(gs[t])
-#                             ax1 = plt.subplot(gs[ts.index(t)])
-#                             #persistent_image = persistent_X[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
-#                             plt.imshow(persistent_X[t, :, :, 0], cmap = 'jet', vmin=270, vmax=300)
-#                             ax1.title.set_text("t = " + str(t+1-10))
-#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
-
-#                         plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg"))
-#                         plt.clf()
-
-                        
-#                 with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files:
-#                     pickle.dump(list(persistent_images_all), input_files)
-#                     print ("Save persistent all")
-#                 if is_first:
-#                     gen_images_all = gen_images_stochastic
-#                     is_first = False
-#                 else:
-#                     gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1)
-
-#                 if args.num_stochastic_samples == 1:
-#                     with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files:
-#                         pickle.dump(list(gen_images_all[0]), gen_files)
-#                         print ("Save generate all")
-#                 else:
-#                     with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files:
-#                         pickle.dump(list(gen_images_stochastic), gen_files)
-#                     with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files:
-#                         pickle.dump(list(gen_images_all), gen_files)
-
-#         sample_ind += args.batch_size
-
-
-#     with open(os.path.join(args.output_png_dir, "input_images_all.pkl"),"rb") as input_files:
-#         input_images_all = pickle.load(input_files)
-
-#     with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"),"rb") as gen_files:
-#         gen_images_all = pickle.load(gen_files)
-
-#     with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files:
-#         persistent_images_all = pickle.load(gen_files)
-
-#     #+++Scarlet:20200528
-#     #print('Scarlet4')
-#     input_images_all = np.array(input_images_all)
-#     input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm)
-#     #---Scarlet:20200528
-#     persistent_images_all = np.array(persistent_images_all)
-#     if len(np.array(gen_images_all).shape) == 6:
-#         for i in range(len(gen_images_all)):
-#             #+++Scarlet:20200528
-#             #print('Scarlet5')
-#             gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:]
-#             gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm)
-#             #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
-#             #---Scarlet:20200528
-#             mse_all = []
-#             psnr_all = []
-#             ssim_all = []
-#             f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w')
-#             for i in range(future_length):
-#                 mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all_stochastic[:, i + 9, :, :,
-#                                                                             0]) ** 2)  # look at all timesteps except the first
-#                 psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0])
-#                 ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0],
-#                                   data_range = max(gen_images_all_stochastic[:, i + 9, :, :, 0].flatten()) - min(
-#                                       input_images_all[:, i + 10, :, :, 0].flatten()))
-#                 mse_all.extend([mse_model])
-#                 psnr_all.extend([psnr_model])
-#                 ssim_all.extend([ssim_model])
-#                 results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
-#                 f.write("##########Predicted Frame {}\n".format(str(i + 1)))
-#                 f.write("Model MSE: %f\n" % mse_model)
-#                 # f.write("Previous Frame MSE: %f\n" % mse_prev)
-#                 f.write("Model PSNR: %f\n" % psnr_model)
-#                 f.write("Model SSIM: %f\n" % ssim_model)
-
-
-#             pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb"))
-#             # f.write("Previous frame PSNR: %f\n" % psnr_prev)
-#             f.write("Shape of X_test: " + str(input_images_all.shape))
-#             f.write("")
-#             f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape))
-
-#     else:
-#         #+++Scarlet:20200528
-#         #print('Scarlet6')
-#         gen_images_all = np.array(gen_images_all)
-#         gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm)
-#         #---Scarlet:20200528
-        
-#         # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
-#         # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)
-#         # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 )
-#         mse_all = []
-#         psnr_all = []
-#         ssim_all = []
-#         persistent_mse_all = []
-#         persistent_psnr_all = []
-#         persistent_ssim_all = []
-#         f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w')
-#         for i in range(future_length):
-#             mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :,
-#                                                                         0]) ** 2)  # look at all timesteps except the first
-#             persistent_mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - persistent_images_all[:, i + 9, :, :,
-#                                                                         0]) ** 2)  # look at all timesteps except the first
-            
-#             psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0])
-#             ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0],
-#                               data_range = max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min(
-#                                   input_images_all[:, i + 10, :, :, 0].flatten()))
-#             persistent_psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0])
-#             persistent_ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0],
-#                               data_range = max(gen_images_all[:1268, i + 9, :, :, 0].flatten()) - min(input_images_all[:1268, i + 10, :, :, 0].flatten()))
-#             mse_all.extend([mse_model])
-#             psnr_all.extend([psnr_model])
-#             ssim_all.extend([ssim_model])
-#             persistent_mse_all.extend([persistent_mse_model])
-#             persistent_psnr_all.extend([persistent_psnr_model])
-#             persistent_ssim_all.extend([persistent_ssim_model])
-#             results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
-
-#             persistent_results = {"mse": persistent_mse_all, "psnr": persistent_psnr_all, "ssim": persistent_ssim_all}
-#             f.write("##########Predicted Frame {}\n".format(str(i + 1)))
-#             f.write("Model MSE: %f\n" % mse_model)
-#             # f.write("Previous Frame MSE: %f\n" % mse_prev)
-#             f.write("Model PSNR: %f\n" % psnr_model)
-#             f.write("Model SSIM: %f\n" % ssim_model)
-
-#         pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb"))
-#         pickle.dump(persistent_results, open(os.path.join(args.output_png_dir, "persistent_results.pkl"), "wb"))
-#         # f.write("Previous frame PSNR: %f\n" % psnr_prev)
-#         f.write("Shape of X_test: " + str(input_images_all.shape))
-#         f.write("")
-#         f.write("Shape of X_hat: " + str(gen_images_all.shape)      
-
-if __name__ == '__main__':
-    main()        
-
-    #psnr_model = psnr(input_images_all[:, :10, :, :, 0],  gen_images_all[:, :10, :, :, 0])
-    #psnr_model_last = psnr(input_images_all[:, 10, :, :, 0],  gen_images_all[:,10, :, :, 0])
-    #psnr_prev = psnr(input_images_all[:, :, :, :, 0],  input_images_all[:, 1:10, :, :, 0])
-
-    # ims = []
-    # fig = plt.figure()
-    # for frame in range(20):
-    #     input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel)
-    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
-    #     #pix_std = np.std(input_gen_diff, axis=0)
-    #     im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu')
-    #     if frame == 0:
-    #         fig.colorbar(im)
-    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame +1))
-    #     ims.append([im, ttl])
-    # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000)
-    # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4"))
-    # plt.close("all")
-
-    # ims = []
-    # fig = plt.figure()
-    # for frame in range(19):
-    #     pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # Get the first prediction frame (batch,height, width, channel)
-    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
-    #     #pix_std = np.std(input_gen_diff, axis=0)
-    #     im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu')
-    #     if frame == 0:
-    #         fig.colorbar(im)
-    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame+1))
-    #     ims.append([im, ttl])
-    # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
-    # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4"))
-
-    # seed(1)
-    # s = random.sample(range(len(gen_images_all)), 100)
-    # print("******KDP******")
-    # #kernel density plot for checking the model collapse
-    # fig = plt.figure()
-    # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images")
-    # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True")
-    # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability')
-    # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400)
-    # plt.clf()
-
-    #line plot for evaluating the prediction and groud-truth
-    # for i in [0,3,6,9,12,15,18]:
-    #     fig = plt.figure()
-    #     plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3)
-    #     #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3)
-    #     plt.xlabel("Prediction")
-    #     plt.ylabel("Real values")
-    #     plt.title("Frame_{}".format(i+1))
-    #     plt.plot([250,300], [250,300],color="black")
-    #     plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i))))
-    #     plt.clf()
-    #
-    # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence)
-    # x = [str(i+1) for i in list(range(19))]
-    # fig,axis = plt.subplots()
-    # mean_f = np.mean(mse_model_by_frames, axis = 0)
-    # median = np.median(mse_model_by_frames, axis=0)
-    # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0)
-    # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0)
-    # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0)
-    # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0)
-    # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range")
-    # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range")
-    # plt.plot(x, median, color="grey", linewidth=0.6, label="Median")
-    # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean")
-    # plt.title(f'MSE percentile')
-    # plt.xlabel("Frames")
-    # plt.legend(loc=2, fontsize=8)
-    # plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png"))
-
-
-##                
-##
-##                    # fig = plt.figure()
-##                    # gs = gridspec.GridSpec(4,6)
-##                    # gs.update(wspace = 0.7,hspace=0.8)
-##                    # ax1 = plt.subplot(gs[0:2,0:3])
-##                    # ax2 = plt.subplot(gs[0:2,3:],sharey=ax1)
-##                    # ax3 = plt.subplot(gs[2:4,0:3])
-##                    # ax4 = plt.subplot(gs[2:4,3:])
-##                    # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
-##                    # ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
-##                    # plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels)
-##                    # ax1.title.set_text("(a) Ground Truth")
-##                    # ax2.title.set_text("(b) SAVP")
-##                    # ax3.title.set_text("(c) Diff.")
-##                    # ax4.title.set_text("(d) MSE")
-##                    #
-##                    # ax1.xaxis.set_tick_params(labelsize=7)
-##                    # ax1.yaxis.set_tick_params(labelsize = 7)
-##                    # ax2.xaxis.set_tick_params(labelsize=7)
-##                    # ax2.yaxis.set_tick_params(labelsize = 7)
-##                    # ax3.xaxis.set_tick_params(labelsize=7)
-##                    # ax3.yaxis.set_tick_params(labelsize = 7)
-##                    #
-##                    # init_images = np.zeros((input_images_.shape[1], input_images_.shape[2]))
-##                    # print("inti images shape", init_images.shape)
-##                    # xdata, ydata = [], []
-##                    # #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
-##                    # #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
-##                    # plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
-##                    # plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
-##                    # #x = np.linspace(0, 64, 64)
-##                    # #y = np.linspace(0, 64, 64)
-##                    # #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
-##                    # #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
-##                    # fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7)
-##                    # fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7)
-##                    #
-##                    # cm = LinearSegmentedColormap.from_list(
-##                    #     cmap_name, "bwr", N = 5)
-##                    #
-##                    # plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r',
-##                    # #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm)  # cmap = 'PuBu_r',
-##                    # plot4, = ax4.plot([], [], color = "r")
-##                    # ax4.set_xlim(0, future_length-1)
-##                    # ax4.set_ylim(0, 20)
-##                    # #ax4.set_ylim(0, 0.5)
-##                    # ax4.set_xlabel("Frames", fontsize=10)
-##                    # #ax4.set_ylabel("MSE", fontsize=10)
-##                    # ax4.xaxis.set_tick_params(labelsize=7)
-##                    # ax4.yaxis.set_tick_params(labelsize=7)
-##                    #
-##                    #
-##                    # plots = [plot1, plot2, plot3, plot4]
-##                    #
-##                    # #fig.colorbar(plots[1], ax = [ax1, ax2])
-##                    #
-##                    # fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7)
-##                    # #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7)
-##                    # #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7)
-##                    #
-##                    # def animation_sample(t):
-##                    #     input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
-##                    #     gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
-##                    #     diff_image = input_gen_diff[t,:,:]
-##                    #     # p = sns.lineplot(x=x,y=data,color="b")
-##                    #     # p.tick_params(labelsize=17)
-##                    #     # plt.setp(p.lines, linewidth=6)
-##                    #     plots[0].set_data(input_image)
-##                    #     plots[1].set_data(gen_image)
-##                    #     #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
-##                    #     #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
-##                    #     plots[2].set_data(diff_image)
-##                    #
-##                    #     if t >= future_length:
-##                    #         #data = gen_mse_avg_[:t + 1]
-##                    #         # x = list(range(len(gen_mse_avg_)))[:t+1]
-##                    #         xdata.append(t-future_length)
-##                    #         print("xdata", xdata)
-##                    #         ydata.append(gen_mse_avg_[t])
-##                    #         print("ydata", ydata)
-##                    #         plots[3].set_data(xdata, ydata)
-##                    #         fig.suptitle("Predicted Frame " + str(t-future_length))
-##                    #     else:
-##                    #         #plots[3].set_data(xdata, ydata)
-##                    #         fig.suptitle("Context Frame " + str(t))
-##                    #     return plots
-##                    #
-##                    # ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000,
-##                    #                               repeat_delay=2000)
-##                    # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
-##
-####                else:
-####                    pass
-##
-
-
-
-
-
-    #         # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
-    #         #     ims = []
-    #         #     fig = plt.figure()
-    #         #     plt.xlim(0,len(gen_mse_avg_))
-    #         #     plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg))
-    #         #     plt.xlabel("Frames")
-    #         #     plt.ylabel("MSE_AVG")
-    #         #     #X = list(range(len(gen_mse_avg_)))
-    #         #     #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
-    #         #     def animate_metric(j):
-    #         #         data = gen_mse_avg_[:(j+1)]
-    #         #         x = list(range(len(gen_mse_avg_)))[:(j+1)]
-    #         #         p = sns.lineplot(x=x,y=data,color="b")
-    #         #         p.tick_params(labelsize=17)
-    #         #         plt.setp(p.lines, linewidth=6)
-    #         #     ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000)
-    #         #     ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif"))
-    #         #
-    #         #
-    #         # for i, input_images_ in enumerate(input_images):
-    #         #     #context_images_ = (input_results['images'][i])
-    #         #     #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
-    #         #     ims = []
-    #         #     fig = plt.figure()
-    #         #     for t, input_image in enumerate(input_images_):
-    #         #         im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')
-    #         #         ttl = plt.text(1.5, 2,"Frame_" + str(t))
-    #         #         ims.append([im,ttl])
-    #         #     ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000)
-    #         #     ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif"))
-    #         #     #plt.show()
-    #         #
-    #         # for i,gen_images_ in enumerate(gen_images):
-    #         #     ims = []
-    #         #     fig = plt.figure()
-    #         #     for t, gen_image in enumerate(gen_images_):
-    #         #         im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
-    #         #         ttl = plt.text(1.5, 2, "Frame_" + str(t))
-    #         #         ims.append([im, ttl])
-    #         #     ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
-    #         #     ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif"))
-    #
-    #
-    #             # for i, gen_images_ in enumerate(gen_images):
-    #             #     #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
-    #             #     #gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
-    #             #     #bing
-    #             #     context_images_ = (input_results['images'][i])
-    #             #     gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
-    #             #     context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
-    #             #     plt.figure(figsize = (10,2))
-    #             #     gs = gridspec.GridSpec(2,10)
-    #             #     gs.update(wspace=0.,hspace=0.)
-    #             #     for t, gen_image in enumerate(gen_images_):
-    #             #         gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1)))
-    #             #         gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
-    #             #         plt.subplot(gs[t])
-    #             #         plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')  # the last index sets the channel. 0 = t2
-    #             #         # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
-    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
-    #             #                         right = False, labelbottom = False, labelleft = False)
-    #             #         if t == 0: plt.ylabel('Actual', fontsize = 10)
-    #             #
-    #             #         plt.subplot(gs[t + 10])
-    #             #         plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
-    #             #         # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
-    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
-    #             #                         right = False, labelbottom = False, labelleft = False)
-    #             #         if t == 0: plt.ylabel('Predicted', fontsize = 10)
-    #             #     plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png')
-    #             #     plt.clf()
-    #
-    #             # if args.gif_length:
-    #             #     context_and_gen_images = context_and_gen_images[:args.gif_length]
-    #             # save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
-    #             #          context_and_gen_images, fps=args.fps)
-    #             #
-    #             # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
-    #             # for t, gen_image in enumerate(gen_images_):
-    #             #     gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
-    #             #     if gen_image.shape[-1] == 1:
-    #             #       gen_image = np.tile(gen_image, (1, 1, 3))
-    #             #     else:
-    #             #       gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
-    #             #     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
diff --git a/video_prediction_savp/scripts/mpi_stager_v2.py b/video_prediction_savp/scripts/main_data_extraction.py
similarity index 100%
rename from video_prediction_savp/scripts/mpi_stager_v2.py
rename to video_prediction_savp/scripts/main_data_extraction.py
diff --git a/video_prediction_savp/scripts/mpi_stager_v2_process_netCDF.py b/video_prediction_savp/scripts/main_preprocess_data_step1.py
similarity index 100%
rename from video_prediction_savp/scripts/mpi_stager_v2_process_netCDF.py
rename to video_prediction_savp/scripts/main_preprocess_data_step1.py
diff --git a/video_prediction_savp/scripts/main_preprocess_data_step2.py b/video_prediction_savp/scripts/main_preprocess_data_step2.py
new file mode 100644
index 0000000000000000000000000000000000000000..1582ff2d0671e8d3806307fdd6288b97e6a1e441
--- /dev/null
+++ b/video_prediction_savp/scripts/main_preprocess_data_step2.py
@@ -0,0 +1,127 @@
+
+
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking")
+    parser.add_argument("output_dir", type=str)
+    # ML 2020/04/08 S
+    # Add vars for ensuring proper normalization and reshaping of sequences
+    parser.add_argument("-vars","--variables",dest="variables", nargs='+', type=str, help="Names of input variables.")
+    parser.add_argument("-height",type=int,default=64)
+    parser.add_argument("-width",type = int,default=64)
+    parser.add_argument("-seq_length",type=int,default=20)
+    parser.add_argument("-sequences_per_file",type=int,default=2)
+    args = parser.parse_args()
+    current_path = os.getcwd()
+    #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
+    #output_dir = "/Users/gongbing/PycharmProjects/video_prediction/data/era5"
+    #partition_names = ['train','val',  'test'] #64,64,3 val has issue#
+
+    ############################################################
+    # CONTROLLING variable! Needs to be adapted manually!!!
+    ############################################################
+    partition = {
+            "train":{
+           #     "2222":[1,2,3,5,6,7,8,9,10,11,12], # Issue due to month 04, it is missing
+                "2010":[1,2,3,4,5,6,7,8,9,10,11,12],
+           #     "2012":[1,2,3,4,5,6,7,8,9,10,11,12],
+                "2013":[1,2,3,4,5,6,7,8,9,10,11,12],
+                "2015":[1,2,3,4,5,6,7,8,9,10,11,12],
+                "2019":[1,2,3,4,5,6,7,8,9,10,11,12]
+                 },
+            "val":
+                {"2017":[1,2,3,4,5,6,7,8,9,10,11,12]
+                 },
+            "test":
+                {"2016":[1,2,3,4,5,6,7,8,9,10,11,12]
+                 }
+            }
+    
+    # ini. MPI
+    comm = MPI.COMM_WORLD
+    my_rank = comm.Get_rank()  # rank of the node
+    p = comm.Get_size()  # number of assigned nods
+  
+    if my_rank == 0 :
+        # retrieve final statistics first (not parallelized!)
+        # some preparatory steps
+        stat_dir_prefix = args.input_dir
+        varnames        = args.variables
+    
+        vars_uni, varsind, nvars = get_unique_vars(varnames)
+        stat_obj = Calc_data_stat(nvars)                            # init statistic-instance
+    
+        # loop over whole data set (training, dev and test set) to collect the intermediate statistics
+        print("Start collecting statistics from the whole datset to be processed...")
+        for split in partition.keys():
+            values = partition[split]
+            for year in values.keys():
+                file_dir = os.path.join(stat_dir_prefix,year)
+                for month in values[year]:
+                    # process stat-file:
+                    stat_obj.acc_stat_master(file_dir,int(month))  # process monthly statistic-file  
+        
+        # finalize statistics and write to json-file
+        stat_obj.finalize_stat_master(vars_uni)
+        stat_obj.write_stat_json(args.input_dir)
+
+        # organize parallelized partioning 
+        partition_year_month = [] #contain lists of list, each list includes three element [train,year,month]
+        partition_names = list(partition.keys())
+        print ("partition_names:",partition_names)
+        broadcast_lists = []
+        for partition_name in partition_names:
+            partition_data = partition[partition_name]        
+            years = list(partition_data.keys())
+            broadcast_lists.append([partition_name,years])
+        for nodes in range(1,p):
+            #ibroadcast_list = [partition_name,years,nodes]
+            #broadcast_lists.append(broadcast_list)
+            comm.send(broadcast_lists,dest=nodes) 
+           
+        message_counter = 1
+        while message_counter <= 12:
+            message_in = comm.recv()
+            message_counter = message_counter + 1 
+            print("Message in from slaver",message_in) 
+            
+        write_sequence_file(args.output_dir,args.seq_length,args.sequences_per_file)
+        
+        #write_sequence_file   
+    else:
+        message_in = comm.recv()
+        print ("My rank,", my_rank)   
+        print("message_in",message_in)
+        # open statistics file and feed it to norm-instance
+        print("Opening json-file: "+os.path.join(args.input_dir,"statistics.json"))
+        with open(os.path.join(args.input_dir,"statistics.json")) as js_file:
+            stats = json.load(js_file)
+        #loop the partitions (train,val,test)
+        for partition in message_in:
+            print("partition on slave ",partition)
+            partition_name = partition[0]
+            save_output_dir =  os.path.join(args.output_dir,partition_name)
+            for year in partition[1]:
+               input_file = "X_" + '{0:02}'.format(my_rank) + ".pkl"
+               temp_file = "T_" + '{0:02}'.format(my_rank) + ".pkl"
+               input_dir = os.path.join(args.input_dir,year)
+               temp_file = os.path.join(input_dir,temp_file )
+               input_file = os.path.join(input_dir,input_file)
+               # create the tfrecords-files
+               read_frames_and_save_tf_records(year=year,month=my_rank,stats=stats,output_dir=save_output_dir, \
+                                               input_file=input_file,temp_input_file=temp_file,vars_in=args.variables, \
+                                               partition_name=partition_name,seq_length=args.seq_length, \
+                                               height=args.height,width=args.width,sequences_per_file=args.sequences_per_file)   
+                                                  
+            print("Year {} finished",year)
+        message_out = ("Node:",str(my_rank),"finished","","\r\n")
+        print ("Message out for slaves:",message_out)
+        comm.send(message_out,dest=0)
+        
+    MPI.Finalize()        
+   
+if __name__ == '__main__':
+     main()
+
diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/main_train_models.py
similarity index 100%
rename from video_prediction_savp/scripts/train_dummy.py
rename to video_prediction_savp/scripts/main_train_models.py
diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/main_visualize_postprocess.py
similarity index 100%
rename from video_prediction_savp/scripts/generate_transfer_learning_finetune.py
rename to video_prediction_savp/scripts/main_visualize_postprocess.py
diff --git a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py b/video_prediction_savp/video_prediction/datasets/era5_dataset.py
similarity index 66%
rename from video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
rename to video_prediction_savp/video_prediction/datasets/era5_dataset.py
index 7a61aa090f9e115ffd140a54dd0784dbbd35c48d..8835363a002c06f7e5cfcf337e4db07d280a3bc6 100644
--- a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
+++ b/video_prediction_savp/video_prediction/datasets/era5_dataset.py
@@ -262,126 +262,3 @@ def write_sequence_file(output_dir,seq_length,sequences_per_file):
     
     
 
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking")
-    parser.add_argument("output_dir", type=str)
-    # ML 2020/04/08 S
-    # Add vars for ensuring proper normalization and reshaping of sequences
-    parser.add_argument("-vars","--variables",dest="variables", nargs='+', type=str, help="Names of input variables.")
-    parser.add_argument("-height",type=int,default=64)
-    parser.add_argument("-width",type = int,default=64)
-    parser.add_argument("-seq_length",type=int,default=20)
-    parser.add_argument("-sequences_per_file",type=int,default=2)
-    args = parser.parse_args()
-    current_path = os.getcwd()
-    #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
-    #output_dir = "/Users/gongbing/PycharmProjects/video_prediction/data/era5"
-    #partition_names = ['train','val',  'test'] #64,64,3 val has issue#
-
-    ############################################################
-    # CONTROLLING variable! Needs to be adapted manually!!!
-    ############################################################
-    partition = {
-            "train":{
-           #     "2222":[1,2,3,5,6,7,8,9,10,11,12], # Issue due to month 04, it is missing
-                "2010":[1,2,3,4,5,6,7,8,9,10,11,12],
-           #     "2012":[1,2,3,4,5,6,7,8,9,10,11,12],
-                "2013":[1,2,3,4,5,6,7,8,9,10,11,12],
-                "2015":[1,2,3,4,5,6,7,8,9,10,11,12],
-                "2019":[1,2,3,4,5,6,7,8,9,10,11,12]
-                 },
-            "val":
-                {"2017":[1,2,3,4,5,6,7,8,9,10,11,12]
-                 },
-            "test":
-                {"2016":[1,2,3,4,5,6,7,8,9,10,11,12]
-                 }
-            }
-    
-    # ini. MPI
-    comm = MPI.COMM_WORLD
-    my_rank = comm.Get_rank()  # rank of the node
-    p = comm.Get_size()  # number of assigned nods
-  
-    if my_rank == 0 :
-        # retrieve final statistics first (not parallelized!)
-        # some preparatory steps
-        stat_dir_prefix = args.input_dir
-        varnames        = args.variables
-    
-        vars_uni, varsind, nvars = get_unique_vars(varnames)
-        stat_obj = Calc_data_stat(nvars)                            # init statistic-instance
-    
-        # loop over whole data set (training, dev and test set) to collect the intermediate statistics
-        print("Start collecting statistics from the whole datset to be processed...")
-        for split in partition.keys():
-            values = partition[split]
-            for year in values.keys():
-                file_dir = os.path.join(stat_dir_prefix,year)
-                for month in values[year]:
-                    # process stat-file:
-                    stat_obj.acc_stat_master(file_dir,int(month))  # process monthly statistic-file  
-        
-        # finalize statistics and write to json-file
-        stat_obj.finalize_stat_master(vars_uni)
-        stat_obj.write_stat_json(args.input_dir)
-
-        # organize parallelized partioning 
-        partition_year_month = [] #contain lists of list, each list includes three element [train,year,month]
-        partition_names = list(partition.keys())
-        print ("partition_names:",partition_names)
-        broadcast_lists = []
-        for partition_name in partition_names:
-            partition_data = partition[partition_name]        
-            years = list(partition_data.keys())
-            broadcast_lists.append([partition_name,years])
-        for nodes in range(1,p):
-            #ibroadcast_list = [partition_name,years,nodes]
-            #broadcast_lists.append(broadcast_list)
-            comm.send(broadcast_lists,dest=nodes) 
-           
-        message_counter = 1
-        while message_counter <= 12:
-            message_in = comm.recv()
-            message_counter = message_counter + 1 
-            print("Message in from slaver",message_in) 
-            
-        write_sequence_file(args.output_dir,args.seq_length,args.sequences_per_file)
-        
-        #write_sequence_file   
-    else:
-        message_in = comm.recv()
-        print ("My rank,", my_rank)   
-        print("message_in",message_in)
-        # open statistics file and feed it to norm-instance
-        print("Opening json-file: "+os.path.join(args.input_dir,"statistics.json"))
-        with open(os.path.join(args.input_dir,"statistics.json")) as js_file:
-            stats = json.load(js_file)
-        #loop the partitions (train,val,test)
-        for partition in message_in:
-            print("partition on slave ",partition)
-            partition_name = partition[0]
-            save_output_dir =  os.path.join(args.output_dir,partition_name)
-            for year in partition[1]:
-               input_file = "X_" + '{0:02}'.format(my_rank) + ".pkl"
-               temp_file = "T_" + '{0:02}'.format(my_rank) + ".pkl"
-               input_dir = os.path.join(args.input_dir,year)
-               temp_file = os.path.join(input_dir,temp_file )
-               input_file = os.path.join(input_dir,input_file)
-               # create the tfrecords-files
-               read_frames_and_save_tf_records(year=year,month=my_rank,stats=stats,output_dir=save_output_dir, \
-                                               input_file=input_file,temp_input_file=temp_file,vars_in=args.variables, \
-                                               partition_name=partition_name,seq_length=args.seq_length, \
-                                               height=args.height,width=args.width,sequences_per_file=args.sequences_per_file)   
-                                                  
-            print("Year {} finished",year)
-        message_out = ("Node:",str(my_rank),"finished","","\r\n")
-        print ("Message out for slaves:",message_out)
-        comm.send(message_out,dest=0)
-        
-    MPI.Finalize()        
-   
-if __name__ == '__main__':
-     main()
-