diff --git a/.gitignore b/.gitignore
index 027f7926e575fe674bbb2f703a436999d63d7d26..a7dfc3d7d776b25730741f12db6f98cdfa127751 100644
--- a/.gitignore
+++ b/.gitignore
@@ -88,6 +88,7 @@ celerybeat-schedule
 venv/
 ENV/
 virtual_env*/
+virt_env*/
 
 # Spyder project settings
 .spyderproject
@@ -108,8 +109,8 @@ virtual_env*/
 *.DS_Store
 
 # Ignore log- and errorfiles 
-*-err.???????
-*-out.???????
+*-err.[0-9]*
+*-out.[0-9]*
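+# ([0-9]* matches suffixes starting with a digit, e.g. *-err.1234567, regardless of length)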
 
 
 #Ignore the results files
diff --git a/video_prediction_savp/HPC_scripts/DataExtraction.sh b/video_prediction_savp/HPC_scripts/DataExtraction.sh
index 0c1491a3f044919f01771ca6ff94ca09a18d3a56..b44065e7babb0411cda6d2849ec429f3672c60d5 100755
--- a/video_prediction_savp/HPC_scripts/DataExtraction.sh
+++ b/video_prediction_savp/HPC_scripts/DataExtraction.sh
@@ -1,4 +1,5 @@
 #!/bin/bash -x
+## Controlling Batch-job
 #SBATCH --account=deepacf
 #SBATCH --nodes=1
 #SBATCH --ntasks=13
@@ -6,18 +7,40 @@
 #SBATCH --cpus-per-task=1
 #SBATCH --output=DataExtraction-out.%j
 #SBATCH --error=DataExtraction-err.%j
-#SBATCH --time=00:20:00
+#SBATCH --time=05:00:00
 #SBATCH --partition=devel
 #SBATCH --mail-type=ALL
-#SBATCH --mail-user=s.stadtler@fz-juelich.de
-##jutil env activate -p deepacf
-
-module purge
-module use $OTHERSTAGES
-module load Stages/2019a
-module addad Intel/2019.3.199-GCC-8.3.0  ParaStationMPI/5.2.2-1
-module load h5py/2.9.0-Python-3.6.8
-module load mpi4py/3.0.1-Python-3.6.8
-module load netcdf4-python/1.5.0.1-Python-3.6.8
-
-srun python ../../workflow_parallel_frame_prediction/DataExtraction/mpi_stager_v2.py --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/2017/ --destination_dir /p/scratch/deepacf/scarlet/extractedData
+#SBATCH --mail-user=b.gong@fz-juelich.de
+
+
+jutil env activate -p deepacf
+
+# Name of virtual environment 
+VIRT_ENV_NAME="virt_env_hdfml"
+
+# Loading modules
+source ../env_setup/modules_preprocess.sh
+# Activate virtual environment if needed (and possible)
+if [ -z "${VIRTUAL_ENV}" ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+# Declare path-variables
+source_dir="/p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/"
+dest_dir="/p/scratch/deepacf/video_prediction_shared_folder/extractedData/"
+
+year="2010"
+
+# Run data extraction
+srun python ../../workflow_parallel_frame_prediction/DataExtraction/mpi_stager_v2.py --source_dir ${source_dir}/${year}/ --destination_dir ${dest_dir}/${year}/
+
+
+
+# 2-tier pystager (alternative invocation)
+#srun python ../../workflow_parallel_frame_prediction/DataExtraction/main_single_master.py --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/${year}/ --destination_dir ${SAVE_DIR}/extractedData/${year}
diff --git a/video_prediction_savp/HPC_scripts/DataPreprocess.sh b/video_prediction_savp/HPC_scripts/DataPreprocess.sh
index 48ed0581802fe5f629019e729425a7ca1445af4f..aa84de9de7dce7015b26f040aaec48d0b096a816 100755
--- a/video_prediction_savp/HPC_scripts/DataPreprocess.sh
+++ b/video_prediction_savp/HPC_scripts/DataPreprocess.sh
@@ -1,4 +1,5 @@
 #!/bin/bash -x
+## Controlling Batch-job
 #SBATCH --account=deepacf
 #SBATCH --nodes=1
 #SBATCH --ntasks=12
@@ -6,38 +7,58 @@
 #SBATCH --cpus-per-task=1
 #SBATCH --output=DataPreprocess-out.%j
 #SBATCH --error=DataPreprocess-err.%j
-#SBATCH --time=02:20:00
-#SBATCH --partition=batch
+#SBATCH --time=00:20:00
+#SBATCH --partition=devel
 #SBATCH --mail-type=ALL
 #SBATCH --mail-user=b.gong@fz-juelich.de
 
-module --force purge
-module use  $OTHERSTAGES
-module load Stages/2019a
-module load Intel/2019.3.199-GCC-8.3.0  ParaStationMPI/5.2.2-1
-module load h5py/2.9.0-Python-3.6.8
-module load mpi4py/3.0.1-Python-3.6.8
 
+# Name of virtual environment 
+VIRT_ENV_NAME="virt_env_hdfml"
 
-srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
- --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2015/ \
- --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2015/ \
- --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+# Activate virtual environment if needed (and possible)
+if [ -z "${VIRTUAL_ENV}" ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+# Loading modules
+source ../env_setup/modules_preprocess.sh
 
-srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
- --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2016/ \
- --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2016/ \
- --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+source_dir=${SAVE_DIR}/extractedData
+destination_dir=${SAVE_DIR}/preprocessedData/era5-Y2015to2017M01to12
+script_dir=`pwd`
 
-srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
- --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017/ \
- --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2017/ \
- --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+# Full list of available years (kept for reference; superseded by the active declaration below)
+#declare -a years=("2222"
+#                 "2010_1"
+#                 "2012"
+#                 "2013_complete"
+#                 "2015"
+#                 "2016"
+#                 "2017"
+#                 "2019"
+#                  )
 
 
+declare -a years=(
+                 "2015"
+                 "2016"
+                 "2017"
+                  )
 
 
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017 \
-# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016toY2017M01to12-128x160-74d00N71d0E-T_MSL_gph500/2017 \
-# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+# execute Python-scripts
+for year in "${years[@]}"; do
+    echo "Year ${year}"
+    echo "source_dir ${source_dir}/${year}"
+    srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+         --source_dir ${source_dir} -scr_dir ${script_dir} \
+         --destination_dir ${destination_dir} --years ${year} --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+done
+
+
+#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir} --varnames T2 MSL gph500    
diff --git a/video_prediction_savp/HPC_scripts/DataPreprocess_dev.sh b/video_prediction_savp/HPC_scripts/DataPreprocess_dev.sh
deleted file mode 100755
index b5aa2010cbe2b5b9f87b5b65bade29db974bcc8d..0000000000000000000000000000000000000000
--- a/video_prediction_savp/HPC_scripts/DataPreprocess_dev.sh
+++ /dev/null
@@ -1,68 +0,0 @@
-#!/bin/bash -x
-#SBATCH --account=deepacf
-#SBATCH --nodes=1
-#SBATCH --ntasks=12
-##SBATCH --ntasks-per-node=12
-#SBATCH --cpus-per-task=1
-#SBATCH --output=DataPreprocess-out.%j
-#SBATCH --error=DataPreprocess-err.%j
-#SBATCH --time=00:20:00
-#SBATCH --partition=devel
-#SBATCH --mail-type=ALL
-#SBATCH --mail-user=m.langguth@fz-juelich.de
-
-module --force purge
-module use  $OTHERSTAGES
-module load Stages/2019a
-module load Intel/2019.3.199-GCC-8.3.0  ParaStationMPI/5.2.2-1
-module load h5py/2.9.0-Python-3.6.8
-module load mpi4py/3.0.1-Python-3.6.8
-
-source_dir=/p/scratch/deepacf/video_prediction_shared_folder/extractedData
-destination_dir=/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle
-declare -a years=("2015" 
-                 "2016" 
-                 "2017"
-                  )
-
-
-
-for year in "${years[@]}"; 
-    do 
-        echo "Year $year"
-	echo "source_dir ${source_dir}/${year}"
-	srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-         --source_dir ${source_dir}/${year}/ \
-         --destination_dir ${destination_dir}/${year}/ --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
-    done
-
-
-srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir}
-
-
-
-
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2015/ \
-# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2015/ \
-# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
-
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2016/ \
-# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2016/ \
-# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
-
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017/ \
-# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2017/ \
-# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
-
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py \
-#--destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-#T_MSL_gph500/
-
-
-
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017 \
-# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016toY2017M01to12-128x160-74d00N71d0E-T_MSL_gph500/2017 \
-# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
diff --git a/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf.sh b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf.sh
index 6f541b9d31f582dfd8b9318f7980930716c6c09b..bcf950e93145bcc8b0d15892a606d4cc5d7dd66e 100755
--- a/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf.sh
+++ b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf.sh
@@ -1,8 +1,8 @@
 #!/bin/bash -x
 #SBATCH --account=deepacf
 #SBATCH --nodes=1
-#SBATCH --ntasks=12
-##SBATCH --ntasks-per-node=12
+#SBATCH --ntasks=13
+##SBATCH --ntasks-per-node=13
 #SBATCH --cpus-per-task=1
 #SBATCH --output=DataPreprocess_to_tf-out.%j
 #SBATCH --error=DataPreprocess_to_tf-err.%j
@@ -11,12 +11,26 @@
 #SBATCH --mail-type=ALL
 #SBATCH --mail-user=b.gong@fz-juelich.de
 
-module purge
-module use $OTHERSTAGES
-module load Stages/2019a
-module load Intel/2019.3.199-GCC-8.3.0  ParaStationMPI/5.2.2-1
-module load h5py/2.9.0-Python-3.6.8
-module load mpi4py/3.0.1-Python-3.6.8
-module load TensorFlow/1.13.1-GPU-Python-3.6.8
 
-srun python ../video_prediction/datasets/era5_dataset_v2.py /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016M01to12-128_160-74.00N710E-T_T_T/splits/ /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016M01to12-128_160-74.00N710E-T_T_T/tfrecords/ -vars T2 T2 T2 
+# Name of virtual environment 
+VIRT_ENV_NAME="vp"
+
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z "${VIRTUAL_ENV}" ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/
+destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/
+
+# run preprocessing (step 2, where the TFRecord-files are generated)
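+# (positional arguments: the input directory holding the pickle-files and the output directory for
+#  the TFRecord-files; presumably -seq_length has to be consistent with sequence_length in model_hparams.json)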
+srun python ../video_prediction/datasets/era5_dataset_v2.py ${source_dir}/pickle ${destination_dir}/tfrecords -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20 
diff --git a/video_prediction_savp/HPC_scripts/generate_era5.sh b/video_prediction_savp/HPC_scripts/generate_era5.sh
index c6121ebfc4e441d587cd97fabb366c106324a8be..bb36609129d2c45c5c0cbfaaf0675a7e338eb09c 100755
--- a/video_prediction_savp/HPC_scripts/generate_era5.sh
+++ b/video_prediction_savp/HPC_scripts/generate_era5.sh
@@ -13,19 +13,33 @@
 #SBATCH --mail-user=b.gong@fz-juelich.de
 ##jutil env activate -p cjjsc42
 
+# Name of virtual environment 
+VIRT_ENV_NAME="vp"
 
-module purge
-module load GCC/8.3.0
-module load ParaStationMPI/5.2.2-1
-module load TensorFlow/1.13.1-GPU-Python-3.6.8
-module load netcdf4-python/1.5.0.1-Python-3.6.8
-module load h5py/2.9.0-Python-3.6.8
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z "${VIRTUAL_ENV}" ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
 
+# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
+source_dir=/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/
+checkpoint_dir=/p/scratch/deepacf/video_prediction_shared_folder/models/
+results_dir=/p/scratch/deepacf/video_prediction_shared_folder/results/
 
-python -u ../scripts/generate_transfer_learning_finetune.py \
---input_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/tfrecords/  \
---dataset_hparams sequence_length=20 --checkpoint  /p/scratch/deepacf/video_prediction_shared_folder/models/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/ours_gan \
---mode test --results_dir /p/scratch/deepacf/video_prediction_shared_folder/results/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T \
---batch_size 4 --dataset era5   > generate_era5-out.out
+# name of model
+model=convLSTM
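+# (must match the model used during training, since the checkpoint is read from ${checkpoint_dir}/${model})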
+
+# run postprocessing/generation of model results including evaluation metrics
+srun python -u ../scripts/generate_transfer_learning_finetune.py \
+--input_dir ${source_dir}/tfrecords --dataset_hparams sequence_length=20 --checkpoint  ${checkpoint_dir}/${model} \
+--mode test --model ${model} --results_dir ${results_dir}/${model}/ --batch_size 2 --dataset era5   > generate_era5-out.out
 
 #srun  python scripts/train.py --input_dir data/era5 --dataset era5  --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/video_prediction_savp/HPC_scripts/hyperparam_setup.sh b/video_prediction_savp/HPC_scripts/hyperparam_setup.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a6c24a062ca30b06879641806d771beacc4b34f8
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/hyperparam_setup.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+
+# for choosing the model
+export model=convLSTM
+export model_hparams=../hparams/era5/${model}/model_hparams.json
+
+# Create a subfolder named after the creation time and the user name, which can be considered a hyperparameter-tuning folder.
+# This avoids overwriting a previously trained model that used different hyperparameters.
+export hyperdir="$(date +"%Y%m%dT%H%M")_"$USER""
+
+echo "model: ${model}"
+echo "hparams: ${model_hparams}"
+echo "experiment dir: ${hyperdir}"
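+
+# Usage sketch (assumption): source this script in the shell from which the training job is
+# submitted so that the exported variables are propagated to the job, e.g.
+#   source hyperparam_setup.sh && sbatch train_era5.sh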
diff --git a/video_prediction_savp/HPC_scripts/reset_dirs.sh b/video_prediction_savp/HPC_scripts/reset_dirs.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8de5247e044150d1c01eccfa512b9ae1c0e4cdfa
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/reset_dirs.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
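+# Note: SAVE_DIR is assumed to end with a trailing slash, since it is concatenated directly
+# with the subdirectory names below.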
+sed -i "s|source_dir=.*|source_dir=${SAVE_DIR}preprocessedData/|g" DataPreprocess_to_tf.sh
+sed -i "s|destination_dir=.*|destination_dir=${SAVE_DIR}preprocessedData/|g" DataPreprocess_to_tf.sh
+
+sed -i "s|source_dir=.*|source_dir=${SAVE_DIR}preprocessedData/|g" train_era5.sh
+sed -i "s|destination_dir=.*|destination_dir=${SAVE_DIR}models/|g" train_era5.sh
+
+sed -i "s|source_dir=.*|source_dir=${SAVE_DIR}preprocessedData/|g" generate_era5.sh
+sed -i "s|checkpoint_dir=.*|checkpoint_dir=${SAVE_DIR}models/|g" generate_era5.sh
+sed -i "s|results_dir=.*|results_dir=${SAVE_DIR}results/|g" generate_era5.sh
diff --git a/video_prediction_savp/HPC_scripts/train_era5.sh b/video_prediction_savp/HPC_scripts/train_era5.sh
index ef060b0d985aa0141a1a1cdb974bf04f37b7204b..f605866056f6b2d9fa179a00850468fee0c72d87 100755
--- a/video_prediction_savp/HPC_scripts/train_era5.sh
+++ b/video_prediction_savp/HPC_scripts/train_era5.sh
@@ -7,21 +7,41 @@
 #SBATCH --output=train_era5-out.%j
 #SBATCH --error=train_era5-err.%j
 #SBATCH --time=00:20:00
-#SBATCH --gres=gpu:1
+#SBATCH --gres=gpu:2
 #SBATCH --partition=develgpus
 #SBATCH --mail-type=ALL
 #SBATCH --mail-user=b.gong@fz-juelich.de
 ##jutil env activate -p cjjsc42
 
-module --force purge 
-module use $OTHERSTAGES
-module load Stages/2019a
-module load GCCcore/.8.3.0
-module load mpi4py/3.0.1-Python-3.6.8
-module load h5py/2.9.0-serial-Python-3.6.8
-module load TensorFlow/1.13.1-GPU-Python-3.6.8
-module load cuDNN/7.5.1.10-CUDA-10.1.105
 
+# Name of virtual environment 
+VIRT_ENV_NAME="vp"
 
-srun python ../scripts/train_v2.py --input_dir  /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/2017M01to12-64_64-50.00N11.50E-T_T_T/tfrecords  --dataset era5  --model savp --model_hparams_dict ../hparams/kth/ours_savp/model_hparams.json --output_dir /p/scratch/deepacf/video_prediction_shared_folder/models/2017M01to12-64_64-50.00N11.50E-T_T_T/ours_savp
-#srun  python scripts/train.py --input_dir data/era5 --dataset era5  --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z "${VIRTUAL_ENV}" ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else 
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/
+destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/
+
+# for choosing the model
+model=convLSTM
+model_hparams=../hparams/era5/${model}/model_hparams.json
+
+# run training
+srun python ../scripts/train_dummy.py --input_dir  ${source_dir}/tfrecords/ --dataset era5  --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}/
+
diff --git a/video_prediction_savp/env_setup/create_env.sh b/video_prediction_savp/env_setup/create_env.sh
old mode 100755
new mode 100644
index 7d0f0a10bd8586e59fe129198a5a6f7c21121502..ad388826caf1d077c1c6434acae29d6cbaa9c6fc
--- a/video_prediction_savp/env_setup/create_env.sh
+++ b/video_prediction_savp/env_setup/create_env.sh
@@ -1,39 +1,111 @@
 #!/usr/bin/env bash
+#
+# __authors__ = Bing Gong, Michael Langguth
+# __date__  = '2020_07_24'
 
-if [[ ! -n "$1" ]]; then
-  echo "Provide the user name, which will be taken as folder name"
+# This script can be used to set up the virtual environment needed for the ambs-project
+# or to simply activate it.
+#
+# some first sanity checks
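+# (BASH_SOURCE[0] only equals $0 when the script is executed rather than sourced)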
+if [[ ${BASH_SOURCE[0]} == ${0} ]]; then
+  echo "ERROR: 'create_env.sh' must be sourced, i.e. run 'source create_env.sh [virt_env_name]'"
   exit 1
 fi
 
-if [[ ! -n "$2" ]]; then
-  echo "Provide the env name, which will be taken as folder name"
-  exit 1
+# from now on, just return if something unexpected occurs instead of exiting,
+# as exiting would close the terminal (i.e. log the user out)
+if [[ ! -n "$1" ]]; then
+  echo "ERROR: Provide a name to set up the virtual environment, i.e. run 'source create_env.sh [virt_env_name]'"
+  return
 fi
 
-ENV_NAME=$2
-FOLDER_NAME=$1
-WORKING_DIR=/p/project/deepacf/deeprain/${FOLDER_NAME}/video_prediction_savp
-ENV_SETUP_DIR=${WORKING_DIR}/env_setup
+HOST_NAME=`hostname`
+ENV_NAME=$1
+ENV_SETUP_DIR=`pwd`
+WORKING_DIR="$(dirname "$ENV_SETUP_DIR")"
+EXE_DIR="$(basename "$ENV_SETUP_DIR")"
 ENV_DIR=${WORKING_DIR}/${ENV_NAME}
 
-source ${ENV_SETUP_DIR}/modules.sh
-# Install additional Python packages.
-python3 -m venv $ENV_DIR
-source ${ENV_DIR}/bin/activate
-pip3 install -r ${ENV_SETUP_DIR}/requirements.txt
-#pip3 install --user netCDF4
-#pip3 install --user numpy
+# further sanity checks:
+# * ensure execution from env_setup-directory
+# * check if virtual env has already been set up
+
+if [[ "${EXE_DIR}" != "env_setup"  ]]; then
+  echo "ERROR: The setup-script for the virtual environment must be executed from the env_setup-directory!"
+  return
+fi
+
+if [[ -d ${ENV_DIR} ]]; then
+  echo "Virtual environment has already been set up under ${ENV_DIR}. The existing virtual environment will be activated now."
+  echo "NOTE: If you wish to set up a new virtual environment, delete the existing one or provide a different name."
+  
+  ENV_EXIST=1
+else
+  ENV_EXIST=0
+fi
 
-#Copy the hickle package from bing's account
-cp  -r /p/project/deepacf/deeprain/bing/hickle ${WORKING_DIR}
+# add personal email-address to Batch-scripts
+if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
+    USER_EMAIL=$(jutil user show -o json | grep email | cut -f2 -d':' | cut -f1 -d',' | cut -f2 -d'"')
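+    # (the chained cut-commands extract the bare address from the 'email'-entry of the JSON output)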
+    #replace the email in sbatch script with the USER_EMAIL
+    sed -i "s/--mail-user=.*/--mail-user=$USER_EMAIL/g" ../HPC_scripts/*.sh
+    # load modules and check for their availability
+    echo "***** Checking modules required during the workflow... *****"
+    source ${ENV_SETUP_DIR}/modules_preprocess.sh
+    source ${ENV_SETUP_DIR}/modules_train.sh
 
-source ${ENV_SETUP_DIR}/modules.sh
-source ${ENV_DIR}/bin/activate
+elif [[ "${HOST_NAME}" == "zam347" ]]; then
+    unset PYTHONPATH
+fi
 
-export PYTHONPATH=${WORKING_DIR}/hickle/lib/python3.6/site-packages:$PYTHONPATH
-export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH
-export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:$PYTHONPATH
-#export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:$PYTHONPATH
-export PYTHONPATH=${WORKING_DIR}/lpips-tensorflow:$PYTHONPATH
+if [[ "$ENV_EXIST" == 0 ]]; then
+  # Activate virtual environment and install additional Python packages.
+  echo "Configuring and activating virtual environment on ${HOST_NAME}"
+    
+  python3 -m venv $ENV_DIR
+  
+  activate_virt_env=${ENV_DIR}/bin/activate
+  echo ${activate_virt_env}
+  
+  source ${activate_virt_env}
+  
+  # install some requirements and/or check for modules
+  if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
+    # install the required Python packages on known HPC-systems
+    echo "***** Start installing additional Python modules with pip... *****"
+    pip3 install --no-cache-dir --ignore-installed -r ${ENV_SETUP_DIR}/requirements.txt
+    #pip3 install --user netCDF4
+    #pip3 install --user numpy
+  elif [[ "${HOST_NAME}" == "zam347" ]]; then
+    echo "***** Start installing additional Python modules with pip... *****"
+    pip3 install --upgrade pip
+    pip3 install -r ${ENV_SETUP_DIR}/requirements.txt
+    pip3 install  mpi4py 
+    pip3 install netCDF4
+    pip3 install  numpy
+    pip3 install h5py
+    pip3 install tensorflow-gpu==1.13.1
+  fi
+  # expand PYTHONPATH for the current session...
+  export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH
+  #export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:$PYTHONPATH
+  export PYTHONPATH=${WORKING_DIR}/external_package/lpips-tensorflow:$PYTHONPATH
 
+  if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
+     export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:$PYTHONPATH
+  fi
+  # ...and ensure that this is also done whenever the virtual environment is activated later
+  echo "" >> ${activate_virt_env}
+  echo "# Expand PYTHONPATH..." >> ${activate_virt_env}
+  echo "export PYTHONPATH=${WORKING_DIR}:\$PYTHONPATH" >> ${activate_virt_env}
+  #export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:\$PYTHONPATH
+  echo "export PYTHONPATH=${WORKING_DIR}/external_package/lpips-tensorflow:\$PYTHONPATH" >> ${activate_virt_env}
+
+  if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
+    echo "export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:\$PYTHONPATH" >> ${activate_virt_env}
+  fi  
+elif [[ "$ENV_EXIST" == 1 ]]; then
+  # activating the virtual env is sufficient
+  source ${ENV_DIR}/bin/activate  
+fi
 
diff --git a/video_prediction_savp/env_setup/modules.sh b/video_prediction_savp/env_setup/modules.sh
deleted file mode 100755
index e6793787ad59988cfc6646dc8dd789d1573c6b23..0000000000000000000000000000000000000000
--- a/video_prediction_savp/env_setup/modules.sh
+++ /dev/null
@@ -1,12 +0,0 @@
-#!/usr/bin/env bash
-module purge
-module use $OTHERSTAGES
-module load Stages/2019a
-module load GCC/8.3.0
-module load MVAPICH2/.2.3.1-GDR
-module load GCCcore/.8.3.0
-module load mpi4py/3.0.1-Python-3.6.8
-module load h5py/2.9.0-serial-Python-3.6.8
-module load TensorFlow/1.13.1-GPU-Python-3.6.8
-module load cuDNN/7.5.1.10-CUDA-10.1.105
-
diff --git a/video_prediction_savp/env_setup/modules_preprocess.sh b/video_prediction_savp/env_setup/modules_preprocess.sh
new file mode 100755
index 0000000000000000000000000000000000000000..c4a242b0a739eb65c7a340dd8e3a3fbb57b04408
--- /dev/null
+++ b/video_prediction_savp/env_setup/modules_preprocess.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+# __author__ = Bing Gong, Michael Langguth
+# __date__  = '2020_06_26'
+
+# This script loads the required modules for ambs on Juwels and HDF-ML.
+# Note that some other packages have to be installed into a venv (see create_env.sh and requirements.txt).
+
+HOST_NAME=`hostname`
+
+echo "Start loading modules on ${HOST_NAME} required for preprocessing..."
+echo "This script is used by: "
+echo "* DataExtraction.sh"
+echo "* DataPreprocess.sh"
+
+module purge
+module use $OTHERSTAGES
+module load Stages/2019a
+module load GCC/8.3.0
+module load ParaStationMPI/5.2.2-1
+module load mpi4py/3.0.1-Python-3.6.8
+# the parallel h5py-version is not available on HDF-ML
+# see https://gitlab.version.fz-juelich.de/haf/Wiki/-/wikis/HDF-ML%20System
+if [[ "${HOST_NAME}" == hdfml* ]]; then
+  module load h5py/2.9.0-serial-Python-3.6.8
+elif [[ "${HOST_NAME}" == juwels* ]]; then
+  module load h5py/2.9.0-Python-3.6.8
+fi
+module load netcdf4-python/1.5.0.1-Python-3.6.8
+
diff --git a/video_prediction_savp/env_setup/modules_train.sh b/video_prediction_savp/env_setup/modules_train.sh
new file mode 100755
index 0000000000000000000000000000000000000000..cdbc436e1976773132aa636a3f1cedfa506d3a14
--- /dev/null
+++ b/video_prediction_savp/env_setup/modules_train.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/env bash
+
+# __author__ = Bing Gong, Michael Langguth
+# __date__  = '2020_06_26'
+
+# This script loads the required modules for ambs on Juwels and HDF-ML.
+# Note that some other packages have to be installed into a venv (see create_env.sh and requirements.txt).
+
+HOST_NAME=`hostname`
+
+echo "Start loading modules on ${HOST_NAME}..."
+
+module purge
+module use $OTHERSTAGES
+module load Stages/2019a
+module load GCC/8.3.0
+module load GCCcore/.8.3.0
+module load ParaStationMPI/5.2.2-1
+module load mpi4py/3.0.1-Python-3.6.8
+# serialized version of HDF5 is used since only this version is compatible with TensorFlow/1.13.1-GPU-Python-3.6.8
+module load h5py/2.9.0-serial-Python-3.6.8
+module load TensorFlow/1.13.1-GPU-Python-3.6.8
+module load cuDNN/7.5.1.10-CUDA-10.1.105
+module load netcdf4-python/1.5.0.1-Python-3.6.8
+
diff --git a/video_prediction_savp/env_setup/requirements.txt b/video_prediction_savp/env_setup/requirements.txt
index 76dd1f57d64577cc565968bb7106656e53687261..4bf2f0b25d082c4c503bbd56f46d28360a48df43 100644
--- a/video_prediction_savp/env_setup/requirements.txt
+++ b/video_prediction_savp/env_setup/requirements.txt
@@ -2,3 +2,4 @@ opencv-python
 scipy
 scikit-image
 pandas
+hickle
diff --git a/video_prediction_savp/hparams/era5/model_hparams.json b/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
similarity index 56%
rename from video_prediction_savp/hparams/era5/model_hparams.json
rename to video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
index b121ee2f005b6db753b2536deb804204dd41b78d..c2edaad9f9ac158f6e7b8d94bb81db16d55d05e8 100644
--- a/video_prediction_savp/hparams/era5/model_hparams.json
+++ b/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
@@ -1,11 +1,12 @@
+
 {
-    "batch_size": 8,
+    "batch_size": 10,
     "lr": 0.001,
-    "nz": 16,
-    "max_steps":500,
+    "max_epochs":2,
     "context_frames":10,
     "sequence_length":20
 
 }
 
 
+
diff --git a/video_prediction_savp/metadata.py b/video_prediction_savp/metadata.py
index b65db76d71a549b3405f9f10b26dcccb09b30230..8f61d5766169c08d793b665253a8cd1e86a80548 100644
--- a/video_prediction_savp/metadata.py
+++ b/video_prediction_savp/metadata.py
@@ -4,6 +4,7 @@ Classes and routines to retrieve and handle meta-data
 
 import os
 import sys
+import time
 import numpy as np
 import json
 from netCDF4 import Dataset
@@ -23,7 +24,9 @@ class MetaData:
         method_name = MetaData.__init__.__name__+" of Class "+MetaData.__name__
         
         if not json_file is None: 
-            MetaData.get_metadata_from_file(json_file)
+            MetaData.get_metadata_from_file(self,json_file)
             
         else:
             # No dictionary from json-file available, all other arguments have to set
@@ -90,9 +93,11 @@ class MetaData:
         self.nx, self.ny = np.abs(slices['lon_e'] - slices['lon_s']), np.abs(slices['lat_e'] - slices['lat_s'])    
         sw_c             = [float(datafile.variables['lat'][slices['lat_e']-1]),float(datafile.variables['lon'][slices['lon_s']])]                # meridional axis lat is oriented from north to south (i.e. monotonically decreasing)
         self.sw_c        = sw_c
+        self.lat = datafile.variables['lat'][slices['lat_s']:slices['lat_e']]
+        self.lon = datafile.variables['lon'][slices['lon_s']:slices['lon_e']]
         
-        # Now start constructing exp_dir-string
-        # switch sign and coordinate-flags to avoid negative values appearing in exp_dir-name
+        # Now start constructing expdir-string
+        # switch sign and coordinate-flags to avoid negative values appearing in expdir-name
         if sw_c[0] < 0.:
             sw_c[0] = np.abs(sw_c[0])
             flag_coords[0] = "S"
@@ -112,7 +117,7 @@ class MetaData:
         
         expdir, expname = path_parts[0], path_parts[1] 
 
-        # extend exp_dir_in successively (splitted up for better readability)
+        # extend expdir_in successively (split up for better readability)
         expname += "-"+str(self.nx) + "x" + str(self.ny)
         expname += "-"+(("{0: 05.2f}"+flag_coords[0]+"{1:05.2f}"+flag_coords[1]).format(*sw_c)).strip().replace(".","")+"-"  
         
@@ -139,10 +144,15 @@ class MetaData:
                      "expdir" : self.expdir}
         
         meta_dict["sw_corner_frame"] = {
-            "lat" : self.sw_c[0],
-            "lon" : self.sw_c[1]
+            "lat" : np.around(self.sw_c[0],decimals=2),
+            "lon" : np.around(self.sw_c[1],decimals=2)
             }
         
+        meta_dict["coordinates"] = {
+            "lat" : np.around(self.lat,decimals=2).tolist(),
+            "lon" : np.around(self.lon,decimals=2).tolist()
+            }
+            
         meta_dict["frame_size"] = {
             "nx" : int(self.nx),
             "ny" : int(self.ny)
@@ -150,7 +160,7 @@ class MetaData:
         
         meta_dict["variables"] = []
         for i in range(len(self.varnames)):
-            print(self.varnames[i])
+            #print(self.varnames[i])
             meta_dict["variables"].append( 
                     {"var"+str(i+1) : self.varnames[i]})
         
@@ -163,14 +173,19 @@ class MetaData:
             
         meta_fname = os.path.join(dest_dir,"metadata.json")
 
-        if os.path.exists(meta_fname):                      # check if a metadata-file already exists and check its content 
+        if os.path.exists(meta_fname):                      # check if a metadata-file already exists and check its content
+            print(method_name+": json-file '"+meta_fname+"' already exists. Its content will be checked...")
             self.status = "old"                             # set status to old in order to prevent repeated modification of shell-/Batch-scripts
             with open(meta_fname,'r') as js_file:
                 dict_dupl = json.load(js_file)
                 
                 if dict_dupl != meta_dict:
-                    print(method_name+": Already existing metadata (see '"+meta_fname+") do not fit data being processed right now. Ensure a common data base.")
-                    sys.exit(1)
+                    meta_fname_dbg = os.path.join(dest_dir,"metadata_debug.json")
+                    print(method_name+": Already existing metadata (see '"+meta_fname+"') does not fit the data being processed right now (see '" \
+                          +meta_fname_dbg+"'). Ensure a common data base.")
+                    with open(meta_fname_dbg,'w') as js_file:
+                        json.dump(meta_dict,js_file)                         
+                    raise ValueError(method_name+": mismatch between existing and current metadata")
                 else: #do not need to do anything
                     pass
         else:
@@ -189,20 +204,24 @@ class MetaData:
         with open(js_file) as js_file:                
             dict_in = json.load(js_file)
             
-            self.exp_dir = dict_in["exp_dir"]
+            self.expdir = dict_in["expdir"]
             
             self.sw_c       = [dict_in["sw_corner_frame"]["lat"],dict_in["sw_corner_frame"]["lon"] ]
+            self.lat        = dict_in["coordinates"]["lat"]
+            self.lon        = dict_in["coordinates"]["lon"]
             
             self.nx         = dict_in["frame_size"]["nx"]
             self.ny         = dict_in["frame_size"]["ny"]
-            
-            self.variables  = [dict_in["variables"][ivar] for ivar in dict_in["variables"].keys()]   
-            
+            # dict_in["variables"] is a list like [{var1: varname1},{var2: varname2},...]
+            list_of_dict_aux = dict_in["variables"] 
+            # iterate through the list with an integer ivar
+            # note: the naming of the variables starts with var1, thus add 1 to the iterator
+            self.variables = [list_of_dict_aux[ivar]["var"+str(ivar+1)] for ivar in range(len(list_of_dict_aux))]
             
     def write_dirs_to_batch_scripts(self,batch_script):
         
         """
-         Expands ('known') directory-variables in batch_script by exp_dir-attribute of class instance
+         Expands ('known') directory-variables in batch_script by expdir-attribute of class instance
         """
         
         paths_to_mod = ["source_dir=","destination_dir=","checkpoint_dir=","results_dir="]      # known directory-variables in batch-scripts
@@ -224,6 +243,7 @@ class MetaData:
     def write_destdir_jsontmp(dest_dir, tmp_dir = None):        
         """
           Writes dest_dir to temporary json-file (temp.json) stored in the current working directory.
+          To be executed by the master node in parallel mode.
         """
         
         if not tmp_dir: tmp_dir = os.getcwd()
@@ -259,6 +279,34 @@ class MetaData:
         else:
             return(dict_tmp.get("destination_dir"))
     
+    @staticmethod
+    def wait_for_jsontmp(tmp_dir = None, waittime = 10, delay=0.5):
+        """
+          Waits at max. waittime (in sec) until temp.json-file becomes available
+        """
+        
+        method_name = MetaData.wait_for_jsontmp.__name__+" of Class "+MetaData.__name__
+        
+        if not tmp_dir: tmp_dir = os.getcwd()
+        
+        file_tmp = os.path.join(tmp_dir,"temp.json")
+                                
+        counter_max = waittime/delay
+        counter = 0
+        status  = "not_ok" 
+        
+        while (counter <= counter_max):
+            if os.path.isfile(file_tmp):
+                status = "ok"
+                break
+            else:
+                time.sleep(delay)
+            
+            counter += 1
+                                
+        if status != "ok": raise IOError(method_name+": '"+file_tmp+ \
+                           "' does not exist after waiting for "+str(waittime)+" sec.") 
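+
+    # Usage sketch (assumption): worker processes call MetaData.wait_for_jsontmp() before reading
+    # temp.json, so that they block until the master node has created it via write_destdir_jsontmp.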
+    
     
     @staticmethod
     def issubset(a,b):
diff --git a/video_prediction_savp/scripts/generate_anomaly.py b/video_prediction_savp/scripts/generate_anomaly.py
deleted file mode 100644
index 9c8e555c2e223c51bfb555c7dd0fa0eb089610d8..0000000000000000000000000000000000000000
--- a/video_prediction_savp/scripts/generate_anomaly.py
+++ /dev/null
@@ -1,518 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import errno
-import json
-import os
-import math
-import random
-import cv2
-import numpy as np
-import tensorflow as tf
-import seaborn as sns
-import pickle
-from random import seed
-import random
-import json
-import numpy as np
-#from six.moves import cPickle
-import matplotlib
-matplotlib.use('Agg')
-import matplotlib.pyplot as plt
-import matplotlib.gridspec as gridspec
-import matplotlib.animation as animation
-import seaborn as sns
-import pandas as pd
-from video_prediction import datasets, models
-from matplotlib.colors import LinearSegmentedColormap
-from matplotlib.ticker import MaxNLocator
-from video_prediction.utils.ffmpeg_gif import save_gif
-
-with open("./splits_size_64_64_1/geo_info.json","r") as json_file:
-    geo = json.load(json_file)
-    lat = [round(i,2) for i in geo["lat"]]
-    lon = [round(i,2) for i in geo["lon"]]
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
-                                                                     "train, val, test, etc, or a directory containing "
-                                                                     "the tfrecords")
-    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified")
-    parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified")
-    parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified")
-    parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is "
-                                                 "results_gif_dir/model_fname")
-    parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is "
-                                                 "results_png_dir/model_fname")
-    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
-
-    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')
-
-    parser.add_argument("--dataset", type=str, help="dataset class name")
-    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
-    parser.add_argument("--model", type=str, help="model class name")
-    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
-
-    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
-    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
-    parser.add_argument("--num_epochs", type=int, default=1)
-
-    parser.add_argument("--num_stochastic_samples", type=int, default=1) #Bing original is 5, change to 1
-    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
-    parser.add_argument("--fps", type=int, default=4)
-
-    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
-    parser.add_argument("--seed", type=int, default=7)
-
-    args = parser.parse_args()
-
-    if args.seed is not None:
-        tf.set_random_seed(args.seed)
-        np.random.seed(args.seed)
-        random.seed(args.seed)
-
-    args.results_gif_dir = args.results_gif_dir or args.results_dir
-    args.results_png_dir = args.results_png_dir or args.results_dir
-    dataset_hparams_dict = {}
-    model_hparams_dict = {}
-    if args.checkpoint:
-        checkpoint_dir = os.path.normpath(args.checkpoint)
-        if not os.path.isdir(args.checkpoint):
-            checkpoint_dir, _ = os.path.split(checkpoint_dir)
-        if not os.path.exists(checkpoint_dir):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
-        with open(os.path.join(checkpoint_dir, "options.json")) as f:
-            print("loading options from checkpoint %s" % args.checkpoint)
-            options = json.loads(f.read())
-            args.dataset = args.dataset or options['dataset']
-            args.model = args.model or options['model']
-        try:
-            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
-                dataset_hparams_dict = json.loads(f.read())
-        except FileNotFoundError:
-            print("dataset_hparams.json was not loaded because it does not exist")
-        try:
-            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
-                model_hparams_dict = json.loads(f.read())
-        except FileNotFoundError:
-            print("model_hparams.json was not loaded because it does not exist")
-        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1])
-        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1])
-    else:
-        if not args.dataset:
-            raise ValueError('dataset is required when checkpoint is not specified')
-        if not args.model:
-            raise ValueError('model is required when checkpoint is not specified')
-        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
-        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)
-
-    print('----------------------------------- Options ------------------------------------')
-    for k, v in args._get_kwargs():
-        print(k, "=", v)
-    print('------------------------------------- End --------------------------------------')
-
-    VideoDataset = datasets.get_dataset_class(args.dataset)
-    dataset = VideoDataset(
-        args.input_dir,
-        mode=args.mode,
-        num_epochs=args.num_epochs,
-        seed=args.seed,
-        hparams_dict=dataset_hparams_dict,
-        hparams=args.dataset_hparams)
-    VideoPredictionModel = models.get_model_class(args.model)
-    hparams_dict = dict(model_hparams_dict)
-    hparams_dict.update({
-        'context_frames': dataset.hparams.context_frames,
-        'sequence_length': dataset.hparams.sequence_length,
-        'repeat': dataset.hparams.time_shift,
-    })
-    model = VideoPredictionModel(
-        mode=args.mode,
-        hparams_dict=hparams_dict,
-        hparams=args.model_hparams)
-
-    sequence_length = model.hparams.sequence_length
-    context_frames = model.hparams.context_frames
-    future_length = sequence_length - context_frames
-
-    if args.num_samples:
-        if args.num_samples > dataset.num_examples_per_epoch():
-            raise ValueError('num_samples cannot be larger than the dataset')
-        num_examples_per_epoch = args.num_samples
-    else:
-        #Bing: error occurs here, cheats a little bit here
-        #num_examples_per_epoch = dataset.num_examples_per_epoch()
-        num_examples_per_epoch = args.batch_size * 8
-    if num_examples_per_epoch % args.batch_size != 0:
-        #bing
-        #raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
-        pass
-    #Bing if it is era 5 data we used dataset.make_batch_v2
-    #inputs = dataset.make_batch(args.batch_size)
-    inputs, inputs_mean = dataset.make_batch_v2(args.batch_size)
-    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
-    with tf.variable_scope(''):
-        model.build_graph(input_phs)
-
-    for output_dir in (args.output_gif_dir, args.output_png_dir):
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-        with open(os.path.join(output_dir, "options.json"), "w") as f:
-            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
-        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
-            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
-        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
-            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
-
-    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
-    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
-    sess = tf.Session(config=config)
-    sess.graph.as_default()
-    model.restore(sess, args.checkpoint)
-    sample_ind = 0
-    gen_images_all = []
-    input_images_all = []
-
-    while True:
-        if args.num_samples and sample_ind >= args.num_samples:
-            break
-        try:
-            input_results = sess.run(inputs)
-            input_mean_results = sess.run(inputs_mean)
-            input_final = input_results["images"] + input_mean_results["images"]
-
-        except tf.errors.OutOfRangeError:
-            break
-        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
-        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
-        for stochastic_sample_ind in range(args.num_stochastic_samples):
-            print("Stochastic sample id", stochastic_sample_ind)
-            gen_anomaly = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
-            gen_images = gen_anomaly + input_mean_results["images"][:,1:,:,:]
-            #input_images = sess.run(inputs["images"])
-            #Bing: Add evaluation metrics
-            # fetches = {'images': model.inputs['images']}
-            # fetches.update(model.eval_outputs.items())
-            # fetches.update(model.eval_metrics.items())
-            # results = sess.run(fetches, feed_dict = feed_dict)
-            # input_images = results["images"] #shape (batch_size,future_frames,height,width,channel)
-            # only keep the future frames
-            #gen_images = gen_images[:, -future_length:] #(8,10,64,64,1) (batch_size, sequences, height, width, channel)
-            #input_images = input_results["images"][:,-future_length:,:,:]
-            #input_images = input_results["images"][:,1:,:,:,:]
-            input_images = input_final [:,1:,:,:,:]
-            #gen_mse_avg = results["eval_mse/avg"] #shape (batch_size,future_frames)
-            print("Finish sample ind",stochastic_sample_ind)
-            input_gen_diff_ = input_images - gen_images
-            #diff_image_range = pd.cut(input_gen_diff_.flatten(), bins = 4, labels = [-10, -5, 0, 5], right = False)
-            #diff_image_range = np.reshape(np.array(diff_image_range),input_gen_diff_.shape)
-            gen_images_all.extend(gen_images)
-            input_images_all.extend(input_images)
-
-            colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
-            cmap_name = 'my_list'
-            if sample_ind < 100:
-                for i in range(len(gen_images)):
-                    name = 'Batch_id_' + str(sample_ind) + " + Sample_" + str(i)
-                    gen_images_ = gen_images[i, :]
-                    gen_mse_avg_ = [np.mean(input_gen_diff_[i, frame, :, :, :]**2) for frame in
-                                    range(19)]  # return the list with 10 (sequence) mse
-
-                    input_gen_diff = input_gen_diff_[i,:,:,:,:]
-                    input_images_ = input_images[i, :]
-                    #gen_mse_avg_ = gen_mse_avg[i, :]
-                    fig = plt.figure()
-                    gs = gridspec.GridSpec(4,6)
-                    gs.update(wspace = 0.7,hspace=0.8)
-                    ax1 = plt.subplot(gs[0:2,0:3])
-                    ax2 = plt.subplot(gs[0:2,3:],sharey=ax1)
-                    ax3 = plt.subplot(gs[2:4,0:3])
-                    ax4 = plt.subplot(gs[2:4,3:])
-                    xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
-                    ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
-                    plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels)
-                    ax1.title.set_text("(a) Ground Truth")
-                    ax2.title.set_text("(b) SAVP")
-                    ax3.title.set_text("(c) Diff.")
-                    ax4.title.set_text("(d) MSE")
-
-                    ax1.xaxis.set_tick_params(labelsize=7)
-                    ax1.yaxis.set_tick_params(labelsize = 7)
-                    ax2.xaxis.set_tick_params(labelsize=7)
-                    ax2.yaxis.set_tick_params(labelsize = 7)
-                    ax3.xaxis.set_tick_params(labelsize=7)
-                    ax3.yaxis.set_tick_params(labelsize = 7)
-
-                    init_images = np.zeros((input_images_.shape[1], input_images_.shape[2]))
-                    print("inti images shape", init_images.shape)
-                    xdata, ydata = [], []
-                    plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
-                    plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
-                    #x = np.linspace(0, 64, 64)
-                    #y = np.linspace(0, 64, 64)
-                    #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
-                    #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
-                    fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7)
-                    fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7)
-
-                    cm = LinearSegmentedColormap.from_list(
-                        cmap_name, "bwr", N = 5)
-
-                    plot3 = ax3.imshow(init_images, vmin=-10, vmax=10, cmap=cm)#cmap = 'PuBu_r',
-
-                    plot4, = ax4.plot([], [], color = "r")
-                    ax4.set_xlim(0, len(gen_mse_avg_)-1)
-                    ax4.set_ylim(0, 10)
-                    ax4.set_xlabel("Frames", fontsize=10)
-                    #ax4.set_ylabel("MSE", fontsize=10)
-                    ax4.xaxis.set_tick_params(labelsize=7)
-                    ax4.yaxis.set_tick_params(labelsize=7)
-
-
-                    plots = [plot1, plot2, plot3, plot4]
-
-                    #fig.colorbar(plots[1], ax = [ax1, ax2])
-
-                    fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7)
-                    #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7)
-                    #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7)
-
-                    def animation_sample(t):
-                        input_image = input_images_[t, :, :, 0]
-                        gen_image = gen_images_[t, :, :, 0]
-                        diff_image = input_gen_diff[t,:,:,0]
-
-                        data = gen_mse_avg_[:t + 1]
-                        # x = list(range(len(gen_mse_avg_)))[:t+1]
-                        xdata.append(t)
-                        print("xdata", xdata)
-                        ydata.append(gen_mse_avg_[t])
-
-                        print("ydata", ydata)
-                        # p = sns.lineplot(x=x,y=data,color="b")
-                        # p.tick_params(labelsize=17)
-                        # plt.setp(p.lines, linewidth=6)
-                        plots[0].set_data(input_image)
-                        plots[1].set_data(gen_image)
-                        #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
-                        #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
-                        plots[2].set_data(diff_image)
-                        plots[3].set_data(xdata, ydata)
-                        fig.suptitle("Frame " + str(t+1))
-
-                        return plots
-
-                    ani = animation.FuncAnimation(fig, animation_sample, frames = len(gen_mse_avg_), interval = 1000,
-                                                  repeat_delay = 2000)
-                    ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
-
-                else:
-                    pass
-
-    #         # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
-    #         #     ims = []
-    #         #     fig = plt.figure()
-    #         #     plt.xlim(0,len(gen_mse_avg_))
-    #         #     plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg))
-    #         #     plt.xlabel("Frames")
-    #         #     plt.ylabel("MSE_AVG")
-    #         #     #X = list(range(len(gen_mse_avg_)))
-    #         #     #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
-    #         #     def animate_metric(j):
-    #         #         data = gen_mse_avg_[:(j+1)]
-    #         #         x = list(range(len(gen_mse_avg_)))[:(j+1)]
-    #         #         p = sns.lineplot(x=x,y=data,color="b")
-    #         #         p.tick_params(labelsize=17)
-    #         #         plt.setp(p.lines, linewidth=6)
-    #         #     ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000)
-    #         #     ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif"))
-    #         #
-    #         #
-    #         # for i, input_images_ in enumerate(input_images):
-    #         #     #context_images_ = (input_results['images'][i])
-    #         #     #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
-    #         #     ims = []
-    #         #     fig = plt.figure()
-    #         #     for t, input_image in enumerate(input_images_):
-    #         #         im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')
-    #         #         ttl = plt.text(1.5, 2,"Frame_" + str(t))
-    #         #         ims.append([im,ttl])
-    #         #     ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000)
-    #         #     ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif"))
-    #         #     #plt.show()
-    #         #
-    #         # for i,gen_images_ in enumerate(gen_images):
-    #         #     ims = []
-    #         #     fig = plt.figure()
-    #         #     for t, gen_image in enumerate(gen_images_):
-    #         #         im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
-    #         #         ttl = plt.text(1.5, 2, "Frame_" + str(t))
-    #         #         ims.append([im, ttl])
-    #         #     ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
-    #         #     ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif"))
-    #
-    #
-    #             # for i, gen_images_ in enumerate(gen_images):
-    #             #     #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
-    #             #     #gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
-    #             #     #bing
-    #             #     context_images_ = (input_results['images'][i])
-    #             #     gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
-    #             #     context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
-    #             #     plt.figure(figsize = (10,2))
-    #             #     gs = gridspec.GridSpec(2,10)
-    #             #     gs.update(wspace=0.,hspace=0.)
-    #             #     for t, gen_image in enumerate(gen_images_):
-    #             #         gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1)))
-    #             #         gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
-    #             #         plt.subplot(gs[t])
-    #             #         plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')  # the last index sets the channel. 0 = t2
-    #             #         # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
-    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
-    #             #                         right = False, labelbottom = False, labelleft = False)
-    #             #         if t == 0: plt.ylabel('Actual', fontsize = 10)
-    #             #
-    #             #         plt.subplot(gs[t + 10])
-    #             #         plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
-    #             #         # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
-    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
-    #             #                         right = False, labelbottom = False, labelleft = False)
-    #             #         if t == 0: plt.ylabel('Predicted', fontsize = 10)
-    #             #     plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png')
-    #             #     plt.clf()
-    #
-    #             # if args.gif_length:
-    #             #     context_and_gen_images = context_and_gen_images[:args.gif_length]
-    #             # save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
-    #             #          context_and_gen_images, fps=args.fps)
-    #             #
-    #             # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
-    #             # for t, gen_image in enumerate(gen_images_):
-    #             #     gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
-    #             #     if gen_image.shape[-1] == 1:
-    #             #       gen_image = np.tile(gen_image, (1, 1, 3))
-    #             #     else:
-    #             #       gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
-    #             #     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
-
-        sample_ind += args.batch_size
-
-
-    with open(os.path.join(args.output_png_dir, "input_images_all"),"wb") as input_files:
-        pickle.dump(input_images_all,input_files)
-
-    with open(os.path.join(args.output_png_dir, "gen_images_all"),"wb") as gen_files:
-        pickle.dump(gen_images_all,gen_files)
-
-    with open(os.path.join(args.output_png_dir, "input_images_all"),"rb") as input_files:
-        input_images_all = pickle.load(input_files)
-
-    with open(os.path.join(args.output_png_dir, "gen_images_all"),"rb") as gen_files:
-        gen_images_all=pickle.load(gen_files)
-    ims = []
-    fig = plt.figure()
-    for frame in range(19):
-        input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0] # Get the first prediction frame (batch,height, width, channel)
-        #pix_mean = np.mean(input_gen_diff, axis = 0)
-        #pix_std = np.std(input_gen_diff, axis=0)
-        im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu')
-        if frame == 0:
-            fig.colorbar(im)
-        ttl = plt.text(1.5, 2, "Frame_" + str(frame +1))
-        ims.append([im, ttl])
-    ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
-    ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4"))
-    plt.close("all")
-
-    ims = []
-    fig = plt.figure()
-    for frame in range(19):
-        pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # Get the first prediction frame (batch,height, width, channel)
-        #pix_mean = np.mean(input_gen_diff, axis = 0)
-        #pix_std = np.std(input_gen_diff, axis=0)
-        im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu')
-        if frame == 0:
-            fig.colorbar(im)
-        ttl = plt.text(1.5, 2, "Frame_" + str(frame+1))
-        ims.append([im, ttl])
-    ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
-    ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4"))
-
-    gen_images_all = np.array(gen_images_all)
-    input_images_all = np.array(input_images_all)
-    # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
-    # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)
-    # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 )
-
-    mse_model = np.mean((input_images_all[:, :10,:,:,0] - gen_images_all[:, :10,:,:,0])**2)  # look at all timesteps except the first
-    mse_model_last = np.mean((input_images_all[:,10,:,:,0] - gen_images_all[:, 10,:,:,0])**2)
-    mse_prev = np.mean((input_images_all[:, :9, :, :, 0] - input_images_all[:, 1:10, :, :, 0])**2 )
-
-    def psnr(img1, img2):
-        mse = np.mean((img1 - img2) ** 2)
-        if mse == 0: return 100
-        PIXEL_MAX = 1
-        return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
-
-    psnr_model = psnr(input_images_all[:, :10, :, :, 0],  gen_images_all[:, :10, :, :, 0])
-    psnr_model_last = psnr(input_images_all[:, 10, :, :, 0],  gen_images_all[:,10, :, :, 0])
-    psnr_prev = psnr(input_images_all[:, :9, :, :, 0],  input_images_all[:, 1:10, :, :, 0])
-    f = open(os.path.join(args.output_png_dir,'prediction_scores_4prediction.txt'), 'w')
-    f.write("Model MSE: %f\n" % mse_model)
-    f.write("Model MSE from only last prediction in sequence: %f\n" % mse_model_last)
-    f.write("Previous Frame MSE: %f\n" % mse_prev)
-    f.write("Model PSNR: %f\n" % psnr_model)
-    f.write("Model PSNR from only last prediction in sequence: %f\n" % psnr_model_last)
-    f.write("Previous frame PSNR: %f\n" % psnr_prev)
-    f.write("Shape of X_test: " + str(input_images_all.shape))
-    f.write("")
-    f.write("Shape of X_hat: " + str(gen_images_all.shape))
-    f.close()
-
-    seed(1)
-    s = random.sample(range(len(gen_images_all)), 100)
-    print("******KDP******")
-    #kernel density plot for checking the model collapse
-    fig = plt.figure()
-    kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images")
-    kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True")
-    kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability')
-    plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400)
-    plt.clf()
-
-    #line plot for evaluating the prediction and groud-truth
-    for i in [0,3,6,9,12,15,18]:
-        fig = plt.figure()
-        plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3)
-        #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3)
-        plt.xlabel("Prediction")
-        plt.ylabel("Real values")
-        plt.title("Frame_{}".format(i+1))
-        plt.plot([250,300], [250,300],color="black")
-        plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i))))
-        plt.clf()
-
-
-    mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence)
-    x = [str(i+1) for i in list(range(19))]
-    fig,axis = plt.subplots()
-    mean_f = np.mean(mse_model_by_frames, axis = 0)
-    median = np.median(mse_model_by_frames, axis=0)
-    q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0)
-    q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0)
-    d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0)
-    d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0)
-    plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range")
-    plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range")
-    plt.plot(x, median, color="grey", linewidth=0.6, label="Median")
-    plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean")
-    plt.title(f'MSE percentile')
-    plt.xlabel("Frames")
-    plt.legend(loc=2, fontsize=8)
-    plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png"))
-
-if __name__ == '__main__':
-    main()
diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
index 331559f6287a4f24c1c19ee9f7f4b03309a22abf..0e250b47df28d115c8cdfc77fc708eab5e094ce6 100644
--- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
+++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
@@ -12,12 +12,10 @@ import cv2
 import numpy as np
 import tensorflow as tf
 import pickle
-import hickle
 from random import seed
 import random
 import json
 import numpy as np
-#from six.moves import cPickle
 import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
@@ -37,11 +35,88 @@ import sys
 sys.path.append(path.abspath('../video_prediction/datasets/'))
 from era5_dataset_v2 import Norm_data
 from os.path import dirname
+from netCDF4 import Dataset,date2num
+from metadata import MetaData as MetaData
+
+def set_seed(seed):
+    if seed is not None:
+        tf.set_random_seed(seed)
+        np.random.seed(seed)
+        random.seed(seed) 
+
+def get_coordinates(metadata_fname):
+    """
+    Retrieves the latitudes and longitudes read from the metadata json file.
+    """
+    md = MetaData(json_file=metadata_fname)
+    md.get_metadata_from_file(metadata_fname)
+    
+    try:
+        print("lat:", md.lat)
+        print("lon:", md.lon)
+        return md.lat, md.lon
+    except AttributeError:
+        raise ValueError("Could not read lat/lon from: '" + metadata_fname + "'")
+    
+
+def load_checkpoints_and_create_output_dirs(checkpoint, dataset, model):
+    # Initialize all returned values so the return statement is well-defined even
+    # when no checkpoint is given or the hparams files are missing.
+    options = None
+    checkpoint_dir = None
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+    if checkpoint:
+        checkpoint_dir = os.path.normpath(checkpoint)
+        if not os.path.isdir(checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % checkpoint)
+            options = json.loads(f.read())
+            dataset = dataset or options['dataset']
+            model = model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+    else:
+        if not dataset:
+            raise ValueError('dataset is required when checkpoint is not specified')
+        if not model:
+            raise ValueError('model is required when checkpoint is not specified')
 
-with open("../geo_info.json","r") as json_file:
-    geo = json.load(json_file)
-    lat = [round(i,2) for i in geo["lat"]]
-    lon = [round(i,2) for i in geo["lon"]]
+    return options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict
+
+
+    
+def setup_dataset(dataset,input_dir,mode,seed,num_epochs,dataset_hparams,dataset_hparams_dict):
+    VideoDataset = datasets.get_dataset_class(dataset)
+    dataset = VideoDataset(
+        input_dir,
+        mode = mode,
+        num_epochs = num_epochs,
+        seed = seed,
+        hparams_dict = dataset_hparams_dict,
+        hparams = dataset_hparams)
+    return dataset
+
+
+def setup_dirs(input_dir, results_png_dir):
+    temporal_dir = os.path.join(os.path.split(input_dir)[0], "hickle", "splits")
+    print("temporal_dir:", temporal_dir)
+    return temporal_dir
+
+
+def update_hparams_dict(model_hparams_dict,dataset):
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    return hparams_dict
 
 
 def psnr(img1, img2):
@@ -51,6 +126,218 @@ def psnr(img1, img2):
     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
 
 
+def setup_num_samples_per_epoch(num_samples, dataset):
+    if num_samples:
+        if num_samples > dataset.num_examples_per_epoch():
+            raise ValueError('num_samples cannot be larger than the dataset')
+        num_examples_per_epoch = num_samples
+    else:
+        num_examples_per_epoch = dataset.num_examples_per_epoch()
+    #if num_examples_per_epoch % args.batch_size != 0:
+    #    raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+    return num_examples_per_epoch
+
+
+def init_save_data():
+    sample_ind = 0
+    gen_images_all = []
+    #Bing:20200410
+    persistent_images_all = []
+    input_images_all = []
+    return sample_ind, gen_images_all, persistent_images_all, input_images_all
+
+
+def write_params_to_results_dir(args,output_dir,dataset,model):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys = True, indent = 4))
+    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4))
+    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4))
+    return None
+
+
+def denorm_images(stat_fl, input_images_,channel,var):
+    norm_cls  = Norm_data(var)
+    norm = 'minmax'
+    with open(stat_fl) as js_file:
+         norm_cls.check_and_set_norm(json.load(js_file),norm)
+    input_images_denorm = norm_cls.denorm_var(input_images_[:, :, :,channel], var, norm)
+    return input_images_denorm
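+
+# Note (assumption, inferred from the 'minmax' mode of Norm_data and the hard-coded
+# constants in the legacy code further below): the denormalisation amounts to
+# x * (max - min) + min, with per-variable min/max read from statistics.json.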
+
+def denorm_images_all_channels(stat_fl, input_images_, *args):
+    input_images_all_channels_denorm = []
+    input_images_ = np.array(input_images_)
+    print("input_images_ shape:", input_images_.shape)
+    variables = args[0]  # list of variable names, e.g. ["T2", "MSL", "gph500"]
+    for c, var in enumerate(variables):
+        input_images_all_channels_denorm.append(denorm_images(stat_fl, input_images_, channel=c, var=var))
+    input_images_denorm = np.stack(input_images_all_channels_denorm, axis=-1)
+    return input_images_denorm
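+
+# Illustrative usage (mirrors the call in main() below): denormalise a sequence of
+# shape (seq_len, lat, lon, 3) channel by channel.
+#   stat_fl = os.path.join(args.input_dir, "pickle/statistics.json")
+#   seq_denorm = denorm_images_all_channels(stat_fl, seq, ["T2", "MSL", "gph500"])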
+
+def get_one_seq_and_time(input_images,t_starts,i):
+    assert (len(np.array(input_images).shape)==5)
+    input_images_ = input_images[i,:,:,:,:]
+    t_start = t_starts[i]
+    return input_images_,t_start
+
+
+def generate_seq_timestamps(t_start, len_seq=20):
+    if isinstance(t_start, int): t_start = str(t_start)
+    if isinstance(t_start, np.ndarray): t_start = str(t_start[0])
+    s_datetime = datetime.datetime.strptime(t_start, '%Y%m%d%H')
+    seq_ts = [s_datetime + datetime.timedelta(hours=i+1) for i in range(len_seq)]
+    return seq_ts
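+
+# Example (illustrative): generate_seq_timestamps("2017010100", len_seq=3) returns
+# [datetime(2017,1,1,1), datetime(2017,1,1,2), datetime(2017,1,1,3)].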
+    
+    
+def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,ts,context_frames,future_length,model_name,fl_name="test.nc"):
+    assert (len(np.array(input_images_).shape)==len(np.array(gen_images_).shape))
+    
+    y_len = len(lats)
+    x_len = len(lons)
+    ts_len = len(ts)
+    ts_input = ts[:context_frames]
+    ts_forecast = ts[context_frames:]
+    #print("context_frame:",context_frames)
+    #print("future_frame",future_length)
+    #print("length of ts input:",len(ts_input))
+
+    print("input_images_ shape in netcdf,",input_images_.shape)
+    gen_images_ = np.array(gen_images_)
+
+    output_file = os.path.join(output_dir,fl_name)
+    with Dataset(output_file, "w", format="NETCDF4") as nc_file:
+        nc_file.title = 'ERA5 hourly reanalysis data and deep-learning forecasts of 2-m temperature'
+        nc_file.author = "Bing Gong, Michael Langguth"
+        nc_file.create_date = "2020-08-04"
+
+        #create groups forecast and analysis (names must match the variable paths below)
+        fcst = nc_file.createGroup("forecast")
+        analgrp = nc_file.createGroup("analysis")
+
+        #create dims for all the data(forecast and analysis)
+        latD = nc_file.createDimension('lat', y_len)
+        lonD = nc_file.createDimension('lon', x_len)
+        timeD = nc_file.createDimension('time_input', context_frames) 
+        timeF = nc_file.createDimension('time_forecast', future_length)
+
+        #Latitude
+        lat  = nc_file.createVariable('lat', float, ('lat',), zlib = True)
+        lat.units = 'degrees_north'
+        lat[:] = lats
+
+
+        #Longitude
+        lon = nc_file.createVariable('lon', float, ('lon',), zlib = True)
+        lon.units = 'degrees_east'
+        lon[:] = lons
+
+        #Time for input
+        time = nc_file.createVariable('time_input', 'f8', ('time_input',), zlib = True)
+        time.units = "hours since 1970-01-01 00:00:00" 
+        time.calendar = "gregorian"
+        time[:] = date2num(ts_input, units = time.units, calendar = time.calendar)
+        
+        #time for forecast
+        time_f = nc_file.createVariable('time_forecast', 'f8', ('time_forecast',), zlib = True)
+        time_f.units = "hours since 1970-01-01 00:00:00" 
+        time_f.calendar = "gregorian"
+        time_f[:] = date2num(ts_forecast, units = time.units, calendar = time.calendar)
+        
+        ################ analysis group  #####################
+        
+        #####sub group for inputs
+        # create variables for non-meta data
+        #Temperature
+        t2 = nc_file.createVariable("/analysis/inputs/T2","f4",("time_input","lat","lon"), zlib = True)
+        t2.units = 'K'
+        t2[:,:,:] = input_images_[:context_frames,:,:,0]
+
+        #mean sea level pressure
+        msl = nc_file.createVariable("/analysis/inputs/MSL","f4",("time_input","lat","lon"), zlib = True)
+        msl.units = 'Pa'
+        msl[:,:,:] = input_images_[:context_frames,:,:,1]
+
+        #Geopotential height at 500 hPa
+        gph500 = nc_file.createVariable("/analysis/inputs/GPH500","f4",("time_input","lat","lon"), zlib = True)
+        gph500.units = 'm'
+        gph500[:,:,:] = input_images_[:context_frames,:,:,2]
+        
+        #####sub group for reference(ground truth)
+        #Temperature
+        t2_r = nc_file.createVariable("/analysis/reference/T2","f4",("time_forecast","lat","lon"), zlib = True)
+        t2_r.units = 'K'
+        t2_r[:,:,:] = input_images_[context_frames:,:,:,0]
+
+        #mean sea level pressure
+        msl_r = nc_file.createVariable("/analysis/reference/MSL","f4",("time_forecast","lat","lon"), zlib = True)
+        msl_r.units = 'Pa'
+        msl_r[:,:,:] = input_images_[context_frames:,:,:,1]
+
+        #Geopotential height at 500 hPa
+        gph500_r = nc_file.createVariable("/analysis/reference/GPH500","f4",("time_forecast","lat","lon"), zlib = True)
+        gph500_r.units = 'm'
+        gph500_r[:,:,:] = input_images_[context_frames:,:,:,2]
+        
+
+        ################ forecast group  #####################
+
+        #Temperature:
+        t2 = nc_file.createVariable("/forecast/{}/T2".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True)
+        t2.units = 'K'
+        t2[:,:,:] = gen_images_[context_frames:,:,:,0]
+        print("NetCDF created")
+
+        #mean sea level pressure
+        msl = nc_file.createVariable("/forecast/{}/MSL".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True)
+        msl.units = 'Pa'
+        msl[:,:,:] = gen_images_[context_frames:,:,:,1]
+
+        #Geopotential height at 500 hPa
+        gph500 = nc_file.createVariable("/forecast/{}/GPH500".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True)
+        gph500.units = 'm'
+        gph500[:,:,:] = gen_images_[context_frames:,:,:,2]        
+
+        print("{} created".format(output_file)) 
+
+    return None
+
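+# A minimal read-back sketch (assumption: the group/variable layout written above;
+# the function name is illustrative and not part of the workflow).
+def read_t2_forecast_sketch(nc_fname, model_name):
+    """Return the predicted 2-m temperature fields from a vfp_*.nc file."""
+    with Dataset(nc_fname, "r") as nc_file:
+        return np.array(nc_file["/forecast/{}/T2".format(model_name)][:])
+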
+def plot_seq_imgs(imgs,lats,lons,ts,output_png_dir,label="Ground Truth"):
+    """
+    Plot the seq images 
+    """
+
+    if len(np.array(imgs).shape)!=3:raise("img dims should be four: (seq_len,lat,lon)")
+    if np.array(imgs).shape[0]!= len(ts): raise("The len of timestamps should be equal the image seq_len") 
+    fig = plt.figure(figsize=(18,6))
+    gs = gridspec.GridSpec(1, 10)
+    gs.update(wspace = 0., hspace = 0.)
+    xlables = [round(i,2) for i  in list(np.linspace(np.min(lons),np.max(lons),5))]
+    ylabels = [round(i,2) for i  in list(np.linspace(np.max(lats),np.min(lats),5))]
+    for i in range(len(ts)):
+        t = ts[i]
+        #if i==0 : ax1=plt.subplot(gs[i])
+        ax1 = plt.subplot(gs[i])
+        plt.imshow(imgs[i] ,cmap = 'jet', vmin=270, vmax=300)
+        ax1.title.set_text("t = " + t.strftime("%Y%m%d%H"))
+        plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+        if i == 0:
+            plt.setp([ax1], xticks = list(np.linspace(0, len(lons), 5)), xticklabels = xlables, yticks = list(np.linspace(0, len(lats), 5)), yticklabels = ylabels)
+            plt.ylabel(label, fontsize=10)
+    plt.savefig(os.path.join(output_png_dir, label + "_TS_" + str(ts[0]) + ".jpg"))
+    plt.clf()
+    output_fname = label + "_TS_" + ts[0].strftime("%Y%m%d%H") + ".jpg"
+    print("image {} saved".format(output_fname))
+
+    
+def get_persistence(ts):
+    """TODO: return the persistence forecast for the given timestamps; see the sketch below."""
+    pass
+
+
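+# A minimal sketch of the intended persistence baseline (assumption, modelled on the
+# legacy code kept in the comments below: the "forecast" is simply the sequence
+# observed 24 hours earlier). `images_by_time` is a hypothetical {datetime: array} map.
+def persistence_forecast_sketch(seq_ts, images_by_time):
+    """Return the fields observed 24 h before each timestamp in seq_ts (None if missing)."""
+    shifted_ts = [t - datetime.timedelta(days=1) for t in seq_ts]
+    return [images_by_time.get(t) for t in shifted_ts]
+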
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--input_dir", type = str, required = True,
@@ -59,106 +346,52 @@ def main():
                                "the tfrecords")
     parser.add_argument("--results_dir", type = str, default = 'results',
                         help = "ignored if output_gif_dir is specified")
-    parser.add_argument("--results_gif_dir", type = str,
-                        help = "default is results_dir. ignored if output_gif_dir is specified")
-    parser.add_argument("--results_png_dir", type = str,
-                        help = "default is results_dir. ignored if output_png_dir is specified")
-    parser.add_argument("--output_gif_dir", help = "output directory where samples are saved as gifs. default is "
-                                                   "results_gif_dir/model_fname")
-    parser.add_argument("--output_png_dir", help = "output directory where samples are saved as pngs. default is "
-                                                   "results_png_dir/model_fname")
     parser.add_argument("--checkpoint",
                         help = "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
-
     parser.add_argument("--mode", type = str, choices = ['train','val', 'test'], default = 'val',
                         help = 'mode for dataset, val or test.')
-
     parser.add_argument("--dataset", type = str, help = "dataset class name")
     parser.add_argument("--dataset_hparams", type = str,
                         help = "a string of comma separated list of dataset hyperparameters")
     parser.add_argument("--model", type = str, help = "model class name")
     parser.add_argument("--model_hparams", type = str,
                         help = "a string of comma separated list of model hyperparameters")
-
     parser.add_argument("--batch_size", type = int, default = 8, help = "number of samples in batch")
     parser.add_argument("--num_samples", type = int, help = "number of samples in total (all of them by default)")
     parser.add_argument("--num_epochs", type = int, default = 1)
-
     parser.add_argument("--num_stochastic_samples", type = int, default = 1)
     parser.add_argument("--gif_length", type = int, help = "default is sequence_length")
     parser.add_argument("--fps", type = int, default = 4)
-
     parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use")
     parser.add_argument("--seed", type = int, default = 7)
-
     args = parser.parse_args()
+    set_seed(args.seed)
 
-    if args.seed is not None:
-        tf.set_random_seed(args.seed)
-        np.random.seed(args.seed)
-        random.seed(args.seed)
-    #Bing:20200518 
-    input_dir = args.input_dir
-    temporal_dir = os.path.split(input_dir)[0] + "/hickle/splits/"
-    print ("temporal_dir:",temporal_dir)
-    args.results_gif_dir = args.results_gif_dir or args.results_dir
-    args.results_png_dir = args.results_png_dir or args.results_dir
     dataset_hparams_dict = {}
     model_hparams_dict = {}
-    if args.checkpoint:
-        checkpoint_dir = os.path.normpath(args.checkpoint)
-        if not os.path.isdir(args.checkpoint):
-            checkpoint_dir, _ = os.path.split(checkpoint_dir)
-        if not os.path.exists(checkpoint_dir):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
-        with open(os.path.join(checkpoint_dir, "options.json")) as f:
-            print("loading options from checkpoint %s" % args.checkpoint)
-            options = json.loads(f.read())
-            args.dataset = args.dataset or options['dataset']
-            args.model = args.model or options['model']
-        try:
-            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
-                dataset_hparams_dict = json.loads(f.read())
-        except FileNotFoundError:
-            print("dataset_hparams.json was not loaded because it does not exist")
-        try:
-            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
-                model_hparams_dict = json.loads(f.read())
-        except FileNotFoundError:
-            print("model_hparams.json was not loaded because it does not exist")
-        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir,
-                                                                  os.path.split(checkpoint_dir)[1])
-        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir,
-                                                                  os.path.split(checkpoint_dir)[1])
-    else:
-        if not args.dataset:
-            raise ValueError('dataset is required when checkpoint is not specified')
-        if not args.model:
-            raise ValueError('model is required when checkpoint is not specified')
-        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
-        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)
+
+    options, dataset, model, checkpoint_dir, dataset_hparams_dict, model_hparams_dict = load_checkpoints_and_create_output_dirs(args.checkpoint, args.dataset, args.model)
+    print("Step 1 finished")
 
     print('----------------------------------- Options ------------------------------------')
     for k, v in args._get_kwargs():
         print(k, "=", v)
     print('------------------------------------- End --------------------------------------')
 
-    VideoDataset = datasets.get_dataset_class(args.dataset)
-    dataset = VideoDataset(
-        args.input_dir,
-        mode = args.mode,
-        num_epochs = args.num_epochs,
-        seed = args.seed,
-        hparams_dict = dataset_hparams_dict,
-        hparams = args.dataset_hparams)
-
-    VideoPredictionModel = models.get_model_class(args.model)
+    #setup dataset and model object
+    input_dir_tf = os.path.join(args.input_dir, "tfrecords") # where tensorflow records are stored
+    dataset = setup_dataset(dataset,input_dir_tf,args.mode,args.seed,args.num_epochs,args.dataset_hparams,dataset_hparams_dict)
+    
+    print("Step 2 finished")
+    VideoPredictionModel = models.get_model_class(model)
+    
     hparams_dict = dict(model_hparams_dict)
     hparams_dict.update({
         'context_frames': dataset.hparams.context_frames,
         'sequence_length': dataset.hparams.sequence_length,
         'repeat': dataset.hparams.time_shift,
     })
+    
     model = VideoPredictionModel(
         mode = args.mode,
         hparams_dict = hparams_dict,
@@ -166,32 +399,23 @@ def main():
 
     sequence_length = model.hparams.sequence_length
     context_frames = model.hparams.context_frames
-    future_length = sequence_length - context_frames
-
-    if args.num_samples:
-        if args.num_samples > dataset.num_examples_per_epoch():
-            raise ValueError('num_samples cannot be larger than the dataset')
-        num_examples_per_epoch = args.num_samples
-    else:
-        num_examples_per_epoch = dataset.num_examples_per_epoch()
-    if num_examples_per_epoch % args.batch_size != 0:
-        raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+    future_length = sequence_length - context_frames #context_frames is the number of input frames
 
+    num_examples_per_epoch = setup_num_samples_per_epoch(args.num_samples,dataset)
+    
     inputs = dataset.make_batch(args.batch_size)
+    print("inputs",inputs)
     input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
+    print("input_phs",input_phs)
+    
+    
+    # Build graph
     with tf.variable_scope(''):
         model.build_graph(input_phs)
 
-    for output_dir in (args.output_gif_dir, args.output_png_dir):
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-        with open(os.path.join(output_dir, "options.json"), "w") as f:
-            f.write(json.dumps(vars(args), sort_keys = True, indent = 4))
-        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
-            f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4))
-        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
-            f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4))
-
+    #Write the updated hyperparameters into results_dir
+    write_params_to_results_dir(args=args,output_dir=args.results_dir,dataset=dataset,model=model)
+        
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_mem_frac)
     config = tf.ConfigProto(gpu_options = gpu_options, allow_soft_placement = True)
     sess = tf.Session(config = config)
@@ -199,169 +423,375 @@ def main():
     sess.run(tf.global_variables_initializer())
     sess.run(tf.local_variables_initializer())
     model.restore(sess, args.checkpoint)
+    
+    #model.restore(sess, args.checkpoint)#Bing: Todo: 20200728 Let's only focus on true and persistend data
+    sample_ind, gen_images_all, persistent_images_all, input_images_all = init_save_data()
 
-    sample_ind = 0
-    gen_images_all = []
-    #Bing:20200410
-    persistent_images_all = []
-    input_images_all = []
-    #Bing:20201417
-    print ("temporal_dir:",temporal_dir)
-    test_temporal_pkl = pickle.load(open(os.path.join(temporal_dir,"T_test.pkl"),"rb"))
-    #val_temporal_pkl = pickle.load(open("/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/hickle/splits/T_val.pkl","rb"))
-    print("test temporal_pkl file looks like folowing", test_temporal_pkl)
-
-    #X_val = hickle.load("/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/hickle/splits/X_val.hkl")
-    X_test = hickle.load(os.path.join(temporal_dir,"X_test.hkl"))
     is_first=True
-
-    #+++Scarlet:20200528    
-    norm_cls  = Norm_data('T2')
-    norm = 'minmax'
-    with open(os.path.join(dirname(input_dir),"hickle/splits/statistics.json")) as js_file:
-         norm_cls.check_and_set_norm(json.load(js_file),norm)
-    #---Scarlet:20200528    
-    while True:
-        print("Sample id", sample_ind)
-        if sample_ind <= 24:
-            pass
-        elif sample_ind >= len(X_test):
+    #+++Scarlet:20200803    
+    lats, lons = get_coordinates(os.path.join(args.input_dir,"metadata.json"))
+            
+    #---Scarlet:20200803    
+    #Loop over samples; restricted to the first few batches for visualisation/debugging.
+    #Replace the condition below with True to process the whole dataset.
+    while sample_ind < 5:
+        gen_images_stochastic = []
+        if args.num_samples and sample_ind >= args.num_samples:
+            break
+        try:
+            input_results = sess.run(inputs)
+            input_images = input_results["images"]
+            #get the initial times
+            t_starts = input_results["T_start"]
+        except tf.errors.OutOfRangeError:
             break
-        else:
-            gen_images_stochastic = []
-            if args.num_samples and sample_ind >= args.num_samples:
-                break
-            try:
-                input_results = sess.run(inputs)
-                input_images = input_results["images"]
-
-
-            except tf.errors.OutOfRangeError:
-                break
-
-            feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
-            for stochastic_sample_ind in range(args.num_stochastic_samples):
-                input_images_all.extend(input_images)
-                with open(os.path.join(args.output_png_dir, "input_images_all.pkl"), "wb") as input_files:
-                    pickle.dump(list(input_images_all), input_files)
-                
-                gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)
-                gen_images_stochastic.append(gen_images)
-                #print("Stochastic_sample,", stochastic_sample_ind)
-                for i in range(args.batch_size):
-                    #bing:20200417
-                    t_stampe = test_temporal_pkl[sample_ind+i]
-                    print("timestamp:",type(t_stampe))
-                    persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1)
-                    print ("persistent ts",persistent_ts)
-                    persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts))
-                    persistent_X = X_test[persistent_idx:persistent_idx+context_frames + future_length]
-                    print("persistent index in test set:", persistent_idx)
-                    print("persistent_X.shape",persistent_X.shape)
-                    persistent_images_all.append(persistent_X)
+            
+        #Get prediction values 
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel]
+        
+        #Loop in batch size
+        for i in range(args.batch_size):
+            
+            #get one seq and the corresponding start time point
+            input_images_,t_start = get_one_seq_and_time(input_images,t_starts,i)
+            #generate time stamps for sequences
+            ts = generate_seq_timestamps(t_start,len_seq=sequence_length)
+            
+            #Denormalise the input data
+            stat_fl = os.path.join(args.input_dir,"pickle/statistics.json")
+            input_images_denorm = denorm_images_all_channels(stat_fl,input_images_,["T2","MSL","gph500"])
+            print("input_images_denorm",input_images_denorm[0][0])
+                                                             
+            #Denormalise the generated (forecast) data
+            gen_images_ = gen_images[i]
+            gen_images_denorm = denorm_images_all_channels(stat_fl,gen_images_,["T2","MSL","gph500"])
+            print("gen_images_denorm:",gen_images_denorm[0][0])
+            
+            #Save inputs and forecasts to a netCDF file
+            init_date_str = ts[0].strftime("%Y%m%d%H")
+            save_to_netcdf_per_sequence(args.results_dir,input_images_denorm,gen_images_denorm,lons,lats,ts,context_frames,future_length,args.model,fl_name="vfp_{}.nc".format(init_date_str))
+                                                             
+            #Plot ground-truth input images
+            plot_seq_imgs(imgs=input_images_denorm[:context_frames-1,:,:,0],lats=lats,lons=lons,ts=ts[:context_frames-1],label="Ground Truth",output_png_dir=args.results_dir)  
+                                                             
+            #Generate forecast images
+            plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) 
+            
+            #TODO (Scarlet): plot the persistence images
+            #implement get_persistence() (cf. persistence_forecast_sketch above)
 
-                        
-                #print("batch", i)
-                #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
-
-
-                    cmap_name = 'my_list'
-                    if sample_ind < 100:
-                        #name = '_Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
-                        #    sample_ind) + " + Sample_" + str(i)
-                        name = '_Stochastic_id_' + str(stochastic_sample_ind) + "_Time_"+ t_stampe[0].strftime("%Y%m%d-%H%M%S")
-                        print ("name",name)
-                        gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
-                        #gen_images_ =  gen_images[i, :]
-                        input_images_ = input_images[i, :]
-                        #Bing:20200417
-                        #persistent_images = ?
-                        #+++Scarlet:20200528   
-                        #print('Scarlet1')
-                        input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm)
-                        persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm)
-                        #---Scarlet:20200528    
-                        gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
-                                        range(sequence_length)]  # return the list with 10 (sequence) mse
-                        persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in
-                                        range(sequence_length)]  # return the list with 10 (sequence) mse
-
-                        fig = plt.figure(figsize=(18,6))
-                        gs = gridspec.GridSpec(1, 10)
-                        gs.update(wspace = 0., hspace = 0.)
-                        ts = list(range(10,20)) #[10,11,12,..]
-                        xlables = [round(i,2) for i  in list(np.linspace(np.min(lon),np.max(lon),5))]
-                        ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
-
-                        for t in ts:
-
-                            #if t==0 : ax1=plt.subplot(gs[t])
-                            ax1 = plt.subplot(gs[ts.index(t)])
-                            #+++Scarlet:20200528
-                            #print('Scarlet2')
-                            input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm)
-                            #---Scarlet:20200528
-                            plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
-                            ax1.title.set_text("t = " + str(t+1-10))
-                            plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
-                            if t == 0:
-                                plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
-                                plt.ylabel("Ground Truth", fontsize=10)
-                        plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
-                        plt.clf()
-
-                        fig = plt.figure(figsize=(12,6))
-                        gs = gridspec.GridSpec(1, 10)
-                        gs.update(wspace = 0., hspace = 0.)
-
-                        for t in ts:
-                            #if t==0 : ax1=plt.subplot(gs[t])
-                            ax1 = plt.subplot(gs[ts.index(t)])
-                            #+++Scarlet:20200528
-                            #print('Scarlet3')
-                            gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm)
-                            #---Scarlet:20200528
-                            plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
-                            ax1.title.set_text("t = " + str(t+1-10))
-                            plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
-
-                        plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
-                        plt.clf()
-
-
-                        fig = plt.figure(figsize=(12,6))
-                        gs = gridspec.GridSpec(1, 10)
-                        gs.update(wspace = 0., hspace = 0.)
-                        for t in ts:
-                            #if t==0 : ax1=plt.subplot(gs[t])
-                            ax1 = plt.subplot(gs[ts.index(t)])
-                            #persistent_image = persistent_X[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
-                            plt.imshow(persistent_X[t, :, :, 0], cmap = 'jet', vmin=270, vmax=300)
-                            ax1.title.set_text("t = " + str(t+1-10))
-                            plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
-
-                        plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg"))
-                        plt.clf()
+            #Instead of generating images for all inputs, only the first few sample_ind examples are visualised
+
+        sample_ind += args.batch_size
+
+
+        #for input_image in input_images_:
+
+#             for stochastic_sample_ind in range(args.num_stochastic_samples):
+#                 input_images_all.extend(input_images)
+#                 with open(os.path.join(args.output_png_dir, "input_images_all.pkl"), "wb") as input_files:
+#                     pickle.dump(list(input_images_all), input_files)
+
+
+#                 gen_images_stochastic.append(gen_images)
+#                 #print("Stochastic_sample,", stochastic_sample_ind)
+#                 for i in range(args.batch_size):
+#                     #bing:20200417
+#                     t_stampe = test_temporal_pkl[sample_ind+i]
+#                     print("timestamp:",type(t_stampe))
+#                     persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1)
+#                     print ("persistent ts",persistent_ts)
+#                     persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts))
+#                     persistent_X = X_test[persistent_idx:persistent_idx+context_frames + future_length]
+#                     print("persistent index in test set:", persistent_idx)
+#                     print("persistent_X.shape",persistent_X.shape)
+#                     persistent_images_all.append(persistent_X)
+
+#                     cmap_name = 'my_list'
+#                     if sample_ind < 100:
+#                         #name = '_Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
+#                         #    sample_ind) + " + Sample_" + str(i)
+#                         name = '_Stochastic_id_' + str(stochastic_sample_ind) + "_Time_"+ t_stampe[0].strftime("%Y%m%d-%H%M%S")
+#                         print ("name",name)
+#                         gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
+#                         #gen_images_ =  gen_images[i, :]
+#                         input_images_ = input_images[i, :]
+#                         #Bing:20200417
+#                         #persistent_images = ?
+#                         #+++Scarlet:20200528   
+#                         #print('Scarlet1')
+#                         input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm)
+#                         persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm)
+#                         #---Scarlet:20200528    
+#                         gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
+#                                         range(sequence_length)]  # return the list with 10 (sequence) mse
+#                         persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in
+#                                         range(sequence_length)]  # return the list with 10 (sequence) mse
+
+#                         fig = plt.figure(figsize=(18,6))
+#                         gs = gridspec.GridSpec(1, 10)
+#                         gs.update(wspace = 0., hspace = 0.)
+#                         ts = list(range(10,20)) #[10,11,12,..]
+#                         xlables = [round(i,2) for i  in list(np.linspace(np.min(lon),np.max(lon),5))]
+#                         ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
+
+#                         for t in ts:
+
+#                             #if t==0 : ax1=plt.subplot(gs[t])
+#                             ax1 = plt.subplot(gs[ts.index(t)])
+#                             #+++Scarlet:20200528
+#                             #print('Scarlet2')
+#                             input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm)
+#                             #---Scarlet:20200528
+#                             plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
+#                             ax1.title.set_text("t = " + str(t+1-10))
+#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+#                             if t == 0:
+#                                 plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
+#                                 plt.ylabel("Ground Truth", fontsize=10)
+#                         plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
+#                         plt.clf()
+
+#                         fig = plt.figure(figsize=(12,6))
+#                         gs = gridspec.GridSpec(1, 10)
+#                         gs.update(wspace = 0., hspace = 0.)
+
+#                         for t in ts:
+#                             #if t==0 : ax1=plt.subplot(gs[t])
+#                             ax1 = plt.subplot(gs[ts.index(t)])
+#                             #+++Scarlet:20200528
+#                             #print('Scarlet3')
+#                             gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm)
+#                             #---Scarlet:20200528
+#                             plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
+#                             ax1.title.set_text("t = " + str(t+1-10))
+#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+
+#                         plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
+#                         plt.clf()
+
+
+#                         fig = plt.figure(figsize=(12,6))
+#                         gs = gridspec.GridSpec(1, 10)
+#                         gs.update(wspace = 0., hspace = 0.)
+#                         for t in ts:
+#                             #if t==0 : ax1=plt.subplot(gs[t])
+#                             ax1 = plt.subplot(gs[ts.index(t)])
+#                             #persistent_image = persistent_X[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+#                             plt.imshow(persistent_X[t, :, :, 0], cmap = 'jet', vmin=270, vmax=300)
+#                             ax1.title.set_text("t = " + str(t+1-10))
+#                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+
+#                         plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg"))
+#                         plt.clf()
 
                         
-                with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files:
-                    pickle.dump(list(persistent_images_all), input_files)
-                    print ("Save persistent all")
-                if is_first:
-                    gen_images_all = gen_images_stochastic
-                    is_first = False
-                else:
-                    gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1)
-
-                if args.num_stochastic_samples == 1:
-                    with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files:
-                        pickle.dump(list(gen_images_all[0]), gen_files)
-                        print ("Save generate all")
-                else:
-                    with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files:
-                        pickle.dump(list(gen_images_stochastic), gen_files)
-                    with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files:
-                        pickle.dump(list(gen_images_all), gen_files)
+#                 with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files:
+#                     pickle.dump(list(persistent_images_all), input_files)
+#                     print ("Save persistent all")
+#                 if is_first:
+#                     gen_images_all = gen_images_stochastic
+#                     is_first = False
+#                 else:
+#                     gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1)
+
+#                 if args.num_stochastic_samples == 1:
+#                     with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files:
+#                         pickle.dump(list(gen_images_all[0]), gen_files)
+#                         print ("Save generate all")
+#                 else:
+#                     with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files:
+#                         pickle.dump(list(gen_images_stochastic), gen_files)
+#                     with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files:
+#                         pickle.dump(list(gen_images_all), gen_files)
+
+#         sample_ind += args.batch_size
+
+
+#     with open(os.path.join(args.output_png_dir, "input_images_all.pkl"),"rb") as input_files:
+#         input_images_all = pickle.load(input_files)
+
+#     with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"),"rb") as gen_files:
+#         gen_images_all = pickle.load(gen_files)
+
+#     with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files:
+#         persistent_images_all = pickle.load(gen_files)
+
+#     #+++Scarlet:20200528
+#     #print('Scarlet4')
+#     input_images_all = np.array(input_images_all)
+#     input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm)
+#     #---Scarlet:20200528
+#     persistent_images_all = np.array(persistent_images_all)
+#     if len(np.array(gen_images_all).shape) == 6:
+#         for i in range(len(gen_images_all)):
+#             #+++Scarlet:20200528
+#             #print('Scarlet5')
+#             gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:]
+#             gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm)
+#             #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+#             #---Scarlet:20200528
+#             mse_all = []
+#             psnr_all = []
+#             ssim_all = []
+#             f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w')
+#             for i in range(future_length):
+#                 mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all_stochastic[:, i + 9, :, :,
+#                                                                             0]) ** 2)  # look at all timesteps except the first
+#                 psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0])
+#                 ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0],
+#                                   data_range = max(gen_images_all_stochastic[:, i + 9, :, :, 0].flatten()) - min(
+#                                       input_images_all[:, i + 10, :, :, 0].flatten()))
+#                 mse_all.extend([mse_model])
+#                 psnr_all.extend([psnr_model])
+#                 ssim_all.extend([ssim_model])
+#                 results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
+#                 f.write("##########Predicted Frame {}\n".format(str(i + 1)))
+#                 f.write("Model MSE: %f\n" % mse_model)
+#                 # f.write("Previous Frame MSE: %f\n" % mse_prev)
+#                 f.write("Model PSNR: %f\n" % psnr_model)
+#                 f.write("Model SSIM: %f\n" % ssim_model)
+
+
+#             pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb"))
+#             # f.write("Previous frame PSNR: %f\n" % psnr_prev)
+#             f.write("Shape of X_test: " + str(input_images_all.shape))
+#             f.write("")
+#             f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape))
+
+#     else:
+#         #+++Scarlet:20200528
+#         #print('Scarlet6')
+#         gen_images_all = np.array(gen_images_all)
+#         gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm)
+#         #---Scarlet:20200528
+        
+#         # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
+#         # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)
+#         # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 )
+#         mse_all = []
+#         psnr_all = []
+#         ssim_all = []
+#         persistent_mse_all = []
+#         persistent_psnr_all = []
+#         persistent_ssim_all = []
+#         f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w')
+#         for i in range(future_length):
+#             mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :,
+#                                                                         0]) ** 2)  # look at all timesteps except the first
+#             persistent_mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - persistent_images_all[:, i + 9, :, :,
+#                                                                         0]) ** 2)  # look at all timesteps except the first
+            
+#             psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0])
+#             ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0],
+#                               data_range = max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min(
+#                                   input_images_all[:, i + 10, :, :, 0].flatten()))
+#             persistent_psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0])
+#             persistent_ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0],
+#                               data_range = max(gen_images_all[:1268, i + 9, :, :, 0].flatten()) - min(input_images_all[:1268, i + 10, :, :, 0].flatten()))
+#             mse_all.extend([mse_model])
+#             psnr_all.extend([psnr_model])
+#             ssim_all.extend([ssim_model])
+#             persistent_mse_all.extend([persistent_mse_model])
+#             persistent_psnr_all.extend([persistent_psnr_model])
+#             persistent_ssim_all.extend([persistent_ssim_model])
+#             results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
+
+#             persistent_results = {"mse": persistent_mse_all, "psnr": persistent_psnr_all, "ssim": persistent_ssim_all}
+#             f.write("##########Predicted Frame {}\n".format(str(i + 1)))
+#             f.write("Model MSE: %f\n" % mse_model)
+#             # f.write("Previous Frame MSE: %f\n" % mse_prev)
+#             f.write("Model PSNR: %f\n" % psnr_model)
+#             f.write("Model SSIM: %f\n" % ssim_model)
+
+#         pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb"))
+#         pickle.dump(persistent_results, open(os.path.join(args.output_png_dir, "persistent_results.pkl"), "wb"))
+#         # f.write("Previous frame PSNR: %f\n" % psnr_prev)
+#         f.write("Shape of X_test: " + str(input_images_all.shape))
+#         f.write("")
+#         f.write("Shape of X_hat: " + str(gen_images_all.shape)      
+
+if __name__ == '__main__':
+    main()        
+
+    #psnr_model = psnr(input_images_all[:, :10, :, :, 0],  gen_images_all[:, :10, :, :, 0])
+    #psnr_model_last = psnr(input_images_all[:, 10, :, :, 0],  gen_images_all[:,10, :, :, 0])
+    #psnr_prev = psnr(input_images_all[:, :, :, :, 0],  input_images_all[:, 1:10, :, :, 0])
+
+    # ims = []
+    # fig = plt.figure()
+    # for frame in range(20):
+    #     input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel)
+    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
+    #     #pix_std = np.std(input_gen_diff, axis=0)
+    #     im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu')
+    #     if frame == 0:
+    #         fig.colorbar(im)
+    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame +1))
+    #     ims.append([im, ttl])
+    # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000)
+    # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4"))
+    # plt.close("all")
+
+    # ims = []
+    # fig = plt.figure()
+    # for frame in range(19):
+    #     pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # Get the first prediction frame (batch,height, width, channel)
+    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
+    #     #pix_std = np.std(input_gen_diff, axis=0)
+    #     im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu')
+    #     if frame == 0:
+    #         fig.colorbar(im)
+    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame+1))
+    #     ims.append([im, ttl])
+    # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4"))
+
+    # seed(1)
+    # s = random.sample(range(len(gen_images_all)), 100)
+    # print("******KDP******")
+    # #kernel density plot for checking the model collapse
+    # fig = plt.figure()
+    # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images")
+    # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True")
+    # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability')
+    # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400)
+    # plt.clf()
+
+    #line plot for evaluating the prediction and groud-truth
+    # for i in [0,3,6,9,12,15,18]:
+    #     fig = plt.figure()
+    #     plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3)
+    #     #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3)
+    #     plt.xlabel("Prediction")
+    #     plt.ylabel("Real values")
+    #     plt.title("Frame_{}".format(i+1))
+    #     plt.plot([250,300], [250,300],color="black")
+    #     plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i))))
+    #     plt.clf()
+    #
+    # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence)
+    # x = [str(i+1) for i in list(range(19))]
+    # fig,axis = plt.subplots()
+    # mean_f = np.mean(mse_model_by_frames, axis = 0)
+    # median = np.median(mse_model_by_frames, axis=0)
+    # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0)
+    # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0)
+    # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0)
+    # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0)
+    # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range")
+    # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range")
+    # plt.plot(x, median, color="grey", linewidth=0.6, label="Median")
+    # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean")
+    # plt.title(f'MSE percentile')
+    # plt.xlabel("Frames")
+    # plt.legend(loc=2, fontsize=8)
+    # plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png"))
+
+
 ##                
 ##
 ##                    # fig = plt.figure()
@@ -457,7 +887,9 @@ def main():
 ####                else:
 ####                    pass
 ##
-        sample_ind += args.batch_size
+
+
+
 
 
     #         # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
@@ -545,187 +977,3 @@ def main():
     #             #     else:
     #             #       gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
     #             #     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
-
-
-
-    with open(os.path.join(args.output_png_dir, "input_images_all.pkl"),"rb") as input_files:
-        input_images_all = pickle.load(input_files)
-
-    with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"),"rb") as gen_files:
-        gen_images_all = pickle.load(gen_files)
-
-    with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files:
-        persistent_images_all = pickle.load(gen_files)
-
-    #+++Scarlet:20200528
-    #print('Scarlet4')
-    input_images_all = np.array(input_images_all)
-    input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm)
-    #---Scarlet:20200528
-    persistent_images_all = np.array(persistent_images_all)
-    if len(np.array(gen_images_all).shape) == 6:
-        for i in range(len(gen_images_all)):
-            #+++Scarlet:20200528
-            #print('Scarlet5')
-            gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:]
-            gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm)
-            #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
-            #---Scarlet:20200528
-            mse_all = []
-            psnr_all = []
-            ssim_all = []
-            f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w')
-            for i in range(future_length):
-                mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all_stochastic[:, i + 9, :, :,
-                                                                            0]) ** 2)  # look at all timesteps except the first
-                psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0])
-                ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0],
-                                  data_range = max(gen_images_all_stochastic[:, i + 9, :, :, 0].flatten()) - min(
-                                      input_images_all[:, i + 10, :, :, 0].flatten()))
-                mse_all.extend([mse_model])
-                psnr_all.extend([psnr_model])
-                ssim_all.extend([ssim_model])
-                results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
-                f.write("##########Predicted Frame {}\n".format(str(i + 1)))
-                f.write("Model MSE: %f\n" % mse_model)
-                # f.write("Previous Frame MSE: %f\n" % mse_prev)
-                f.write("Model PSNR: %f\n" % psnr_model)
-                f.write("Model SSIM: %f\n" % ssim_model)
-
-
-            pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb"))
-            # f.write("Previous frame PSNR: %f\n" % psnr_prev)
-            f.write("Shape of X_test: " + str(input_images_all.shape))
-            f.write("")
-            f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape))
-
-    else:
-        #+++Scarlet:20200528
-        #print('Scarlet6')
-        gen_images_all = np.array(gen_images_all)
-        gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm)
-        #---Scarlet:20200528
-        
-        # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
-        # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)
-        # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 )
-        mse_all = []
-        psnr_all = []
-        ssim_all = []
-        persistent_mse_all = []
-        persistent_psnr_all = []
-        persistent_ssim_all = []
-        f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w')
-        for i in range(future_length):
-            mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :,
-                                                                        0]) ** 2)  # look at all timesteps except the first
-            persistent_mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - persistent_images_all[:, i + 9, :, :,
-                                                                        0]) ** 2)  # look at all timesteps except the first
-            
-            psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0])
-            ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0],
-                              data_range = max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min(
-                                  input_images_all[:, i + 10, :, :, 0].flatten()))
-            persistent_psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0])
-            persistent_ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0],
-                              data_range = max(gen_images_all[:1268, i + 9, :, :, 0].flatten()) - min(input_images_all[:1268, i + 10, :, :, 0].flatten()))
-            mse_all.extend([mse_model])
-            psnr_all.extend([psnr_model])
-            ssim_all.extend([ssim_model])
-            persistent_mse_all.extend([persistent_mse_model])
-            persistent_psnr_all.extend([persistent_psnr_model])
-            persistent_ssim_all.extend([persistent_ssim_model])
-            results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
-
-            persistent_results = {"mse": persistent_mse_all, "psnr": persistent_psnr_all, "ssim": persistent_ssim_all}
-            f.write("##########Predicted Frame {}\n".format(str(i + 1)))
-            f.write("Model MSE: %f\n" % mse_model)
-            # f.write("Previous Frame MSE: %f\n" % mse_prev)
-            f.write("Model PSNR: %f\n" % psnr_model)
-            f.write("Model SSIM: %f\n" % ssim_model)
-
-        pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb"))
-        pickle.dump(persistent_results, open(os.path.join(args.output_png_dir, "persistent_results.pkl"), "wb"))
-        # f.write("Previous frame PSNR: %f\n" % psnr_prev)
-        f.write("Shape of X_test: " + str(input_images_all.shape))
-        f.write("")
-        f.write("Shape of X_hat: " + str(gen_images_all.shape))
-
-
-
-    #psnr_model = psnr(input_images_all[:, :10, :, :, 0],  gen_images_all[:, :10, :, :, 0])
-    #psnr_model_last = psnr(input_images_all[:, 10, :, :, 0],  gen_images_all[:,10, :, :, 0])
-    #psnr_prev = psnr(input_images_all[:, :, :, :, 0],  input_images_all[:, 1:10, :, :, 0])
-
-    # ims = []
-    # fig = plt.figure()
-    # for frame in range(20):
-    #     input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel)
-    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
-    #     #pix_std = np.std(input_gen_diff, axis=0)
-    #     im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu')
-    #     if frame == 0:
-    #         fig.colorbar(im)
-    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame +1))
-    #     ims.append([im, ttl])
-    # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000)
-    # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4"))
-    # plt.close("all")
-
-    # ims = []
-    # fig = plt.figure()
-    # for frame in range(19):
-    #     pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # Get the first prediction frame (batch,height, width, channel)
-    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
-    #     #pix_std = np.std(input_gen_diff, axis=0)
-    #     im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu')
-    #     if frame == 0:
-    #         fig.colorbar(im)
-    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame+1))
-    #     ims.append([im, ttl])
-    # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
-    # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4"))
-
-    # seed(1)
-    # s = random.sample(range(len(gen_images_all)), 100)
-    # print("******KDP******")
-    # #kernel density plot for checking the model collapse
-    # fig = plt.figure()
-    # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images")
-    # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True")
-    # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability')
-    # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400)
-    # plt.clf()
-
-    #line plot for evaluating the prediction and groud-truth
-    # for i in [0,3,6,9,12,15,18]:
-    #     fig = plt.figure()
-    #     plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3)
-    #     #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3)
-    #     plt.xlabel("Prediction")
-    #     plt.ylabel("Real values")
-    #     plt.title("Frame_{}".format(i+1))
-    #     plt.plot([250,300], [250,300],color="black")
-    #     plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i))))
-    #     plt.clf()
-    #
-    # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence)
-    # x = [str(i+1) for i in list(range(19))]
-    # fig,axis = plt.subplots()
-    # mean_f = np.mean(mse_model_by_frames, axis = 0)
-    # median = np.median(mse_model_by_frames, axis=0)
-    # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0)
-    # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0)
-    # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0)
-    # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0)
-    # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range")
-    # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range")
-    # plt.plot(x, median, color="grey", linewidth=0.6, label="Median")
-    # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean")
-    # plt.title(f'MSE percentile')
-    # plt.xlabel("Frames")
-    # plt.legend(loc=2, fontsize=8)
-    # plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png"))
-
-if __name__ == '__main__':
-    main()
diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/train_dummy.py
index 2f892f69c901f1eaa0a7ce2e57a3d0f6f131a7f9..805e81d79707338665f0fe2b1fad3b212359c233 100644
--- a/video_prediction_savp/scripts/train_dummy.py
+++ b/video_prediction_savp/scripts/train_dummy.py
@@ -11,6 +11,16 @@ import time
 import numpy as np
 import tensorflow as tf
 from video_prediction import datasets, models
+import matplotlib.pyplot as plt
+from json import JSONEncoder
+import pickle as pkl
+
+
+class NumpyArrayEncoder(JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return JSONEncoder.default(self, obj)
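+
+# Usage sketch (illustrative; the encoder is not wired into json.dumps below):
+#   json.dumps({"losses": np.array([0.5, 0.4])}, cls=NumpyArrayEncoder)
+# returns '{"losses": [0.5, 0.4]}' because ndarrays are converted to lists.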
 
 
 def add_tag_suffix(summary, tag_suffix):
@@ -23,41 +33,11 @@ def add_tag_suffix(summary, tag_suffix):
         value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:])
     return summary.SerializeToString()
 
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
-                                                                     "train, val, test, etc, or a directory containing "
-                                                                     "the tfrecords")
-    parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
-    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
-    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
-                                             "default is logs_dir/model_fname, where model_fname consists of "
-                                             "information from model and model_hparams")
-    parser.add_argument("--output_dir_postfix", default="")
-    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
-    parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')
-
-    parser.add_argument("--dataset", type=str, help="dataset class name")
-    parser.add_argument("--model", type=str, help="model class name")
-    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
-    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
-
-   # parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
-    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
-    parser.add_argument("--seed", type=int)
-
-    args = parser.parse_args()
-
-    if args.seed is not None:
-        tf.set_random_seed(args.seed)
-        np.random.seed(args.seed)
-        random.seed(args.seed)
-
-    if args.output_dir is None:
+def generate_output_dir(output_dir, model,model_hparams,logs_dir,output_dir_postfix):
+    if output_dir is None:
         list_depth = 0
         model_fname = ''
-        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
+        for t in ('model=%s,%s' % (model, model_hparams)):
             if t == '[':
                 list_depth += 1
             if t == ']':
@@ -69,52 +49,78 @@ def main():
             if t in '[]':
                 t = ''
             model_fname += t
-        args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix
-
-    if args.resume:
-        if args.checkpoint:
+        output_dir = os.path.join(logs_dir, model_fname) + output_dir_postfix
+    return output_dir
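+
+# Illustrative example (hypothetical arguments): model='savp' with
+# model_hparams='lr=0.001' yields output_dir 'logs/model=savp,lr=0.001'
+# (square brackets from list-valued hparams are stripped from the name).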
+
+
+def get_model_hparams_dict(model_hparams_dict):
+    """
+    Get model_hparams_dict from json file
+    """
+    model_hparams_dict_load = {}
+    if model_hparams_dict:
+        with open(model_hparams_dict) as f:
+            model_hparams_dict_load.update(json.loads(f.read()))
+    return model_hparams_dict_load
+
+def resume_checkpoint(resume,checkpoint,output_dir):
+    """
+    Resume from an existing model checkpoint and set the checkpoint directory
+    """
+    if resume:
+        if checkpoint:
             raise ValueError('resume and checkpoint cannot both be specified')
-        args.checkpoint = args.output_dir
+        checkpoint = output_dir
+    return checkpoint
 
+def set_seed(seed):
+    if seed is not None:
+        tf.set_random_seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)
 
-    model_hparams_dict = {}
-    if args.model_hparams_dict:
-        with open(args.model_hparams_dict) as f:
-            model_hparams_dict.update(json.loads(f.read()))
-    if args.checkpoint:
-        checkpoint_dir = os.path.normpath(args.checkpoint)
-        if not os.path.isdir(args.checkpoint):
+def load_params_from_checkpoints_dir(model_hparams_dict,checkpoint,dataset,model):
+   
+    model_hparams_dict_load = {}
+    if model_hparams_dict:
+        with open(model_hparams_dict) as f:
+            model_hparams_dict_load.update(json.loads(f.read()))
+ 
+    if checkpoint:
+        checkpoint_dir = os.path.normpath(checkpoint)
+        if not os.path.isdir(checkpoint):
             checkpoint_dir, _ = os.path.split(checkpoint_dir)
         if not os.path.exists(checkpoint_dir):
             raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
         with open(os.path.join(checkpoint_dir, "options.json")) as f:
             print("loading options from checkpoint %s" % args.checkpoint)
             options = json.loads(f.read())
-            args.dataset = args.dataset or options['dataset']
-            args.model = args.model or options['model']
+            dataset = dataset or options['dataset']
+            model = model or options['model']
         try:
             with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
-                model_hparams_dict.update(json.loads(f.read()))
+                model_hparams_dict_load.update(json.loads(f.read()))
         except FileNotFoundError:
             print("model_hparams.json was not loaded because it does not exist")
+    return dataset, model, model_hparams_dict_load
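+
+# Note (illustrative): options.json is expected to provide at least the keys
+# 'dataset' and 'model' written by a previous run, e.g.
+#   {"dataset": "era5", "model": "savp", ...}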
 
-    print('----------------------------------- Options ------------------------------------')
-    for k, v in args._get_kwargs():
-        print(k, "=", v)
-    print('------------------------------------- End --------------------------------------')
-
-    VideoDataset = datasets.get_dataset_class(args.dataset)
+def setup_dataset(dataset,input_dir,val_input_dir):
+    VideoDataset = datasets.get_dataset_class(dataset)
     train_dataset = VideoDataset(
-        args.input_dir,
+        input_dir,
         mode='train')
     val_dataset = VideoDataset(
-        args.val_input_dir or args.input_dir,
+        val_input_dir or input_dir,
         mode='val')
-
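+    # use TF resource variables (read/write operations with well-defined
+    # semantics) for every variable created below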
     variable_scope = tf.get_variable_scope()
     variable_scope.set_use_resource(True)
+    return train_dataset,val_dataset,variable_scope
 
-    VideoPredictionModel = models.get_model_class(args.model)
+def setup_model(model,model_hparams_dict,train_dataset,model_hparams):
+    """
+    Set up model instance
+    """
+    VideoPredictionModel = models.get_model_class(model)
     hparams_dict = dict(model_hparams_dict)
     hparams_dict.update({
         'context_frames': train_dataset.hparams.context_frames,
@@ -123,9 +129,21 @@ def main():
     })
     model = VideoPredictionModel(
         hparams_dict=hparams_dict,
-        hparams=args.model_hparams)
+        hparams=model_hparams)
+    return model
 
-    batch_size = model.hparams.batch_size
+def save_dataset_model_params_to_checkpoint_dir(args,output_dir,train_dataset,model):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
+    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+    return None
+
+def make_dataset_iterator(train_dataset, val_dataset, batch_size ):
     train_tf_dataset = train_dataset.make_dataset_v2(batch_size)
     train_iterator = train_tf_dataset.make_one_shot_iterator()
     # The `Iterator.string_handle()` method returns a tensor that can be evaluated
@@ -138,18 +156,87 @@ def main():
     #    train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
     inputs = train_iterator.get_next()
     val = val_iterator.get_next()
-   
-    model.build_graph(inputs)
+    return inputs,train_handle, val_handle
+
+
+def plot_train(train_losses,val_losses,output_dir):
+    iterations = list(range(len(train_losses))) 
+    plt.plot(iterations, train_losses, 'g', label='Training loss')
+    plt.plot(iterations, val_losses, 'b', label='Validation loss')
+    plt.title('Training and Validation loss')
+    plt.xlabel('Iterations')
+    plt.ylabel('Loss')
+    plt.legend()
+    plt.savefig(os.path.join(output_dir,'plot_train.png'))
+
+def save_results_to_dict(results_dict,output_dir):
+    with open(os.path.join(output_dir,"results.json"),"w") as fp:
+        json.dump(results_dict,fp)    
+
+def save_results_to_pkl(train_losses,val_losses, output_dir):
+     with open(os.path.join(output_dir,"train_losses.pkl"),"wb") as f:
+        pkl.dump(train_losses,f)
+     with open(os.path.join(output_dir,"val_losses.pkl"),"wb") as f:
+        pkl.dump(val_losses,f) 
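+
+# Reload sketch (assumes the pickle files written by save_results_to_pkl):
+#   with open(os.path.join(output_dir, "train_losses.pkl"), "rb") as f:
+#       train_losses = pkl.load(f)
+#   with open(os.path.join(output_dir, "val_losses.pkl"), "rb") as f:
+#       val_losses = pkl.load(f)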
+ 
 
-    if not os.path.exists(args.output_dir):
-        os.makedirs(args.output_dir)
-    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
-        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
-    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
-        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
-    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
-        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
+    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
+    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
+                                             "default is logs_dir/model_fname, where model_fname consists of "
+                                             "information from model and model_hparams")
+    parser.add_argument("--output_dir_postfix", default="")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+    parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')
 
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") 
+    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
+
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed",default=1234, type=int)
+
+    args = parser.parse_args()
+     
+    #Set seed  
+    set_seed(args.seed)
+    
+    #setup output directory
+    args.output_dir = generate_output_dir(args.output_dir, args.model, args.model_hparams, args.logs_dir, args.output_dir_postfix)
+    
+    #resume from an existing checkpoint and point the checkpoint directory at the output directory
+    args.checkpoint = resume_checkpoint(args.resume,args.checkpoint,args.output_dir)
+ 
+    #get model hparams dict from json file
+    #load the dataset and model configuration associated with an existing checkpoint (this information was stored in the checkpoint dir during the previous training run)
+    args.dataset,args.model,model_hparams_dict = load_params_from_checkpoints_dir(args.model_hparams_dict,args.checkpoint,args.dataset,args.model)
+     
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+    #setup training and validation dataset instances
+    train_dataset,val_dataset,variable_scope = setup_dataset(args.dataset,args.input_dir,args.val_input_dir)
+    
+    #setup model instance 
+    model=setup_model(args.model,model_hparams_dict,train_dataset,args.model_hparams)
+
+    batch_size = model.hparams.batch_size
+    #Create input and val iterator
+    inputs, train_handle, val_handle = make_dataset_iterator(train_dataset, val_dataset, batch_size)
+    
+    #build model graph
+    model.build_graph(inputs)
+    
+    #save all model and dataset params to the output directory
+    save_dataset_model_params_to_checkpoint_dir(args,args.output_dir,train_dataset,model)
+    
     with tf.name_scope("parameter_count"):
         # exclude trainable variables that are replicas (used in multi-gpu setting)
         trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
@@ -162,113 +249,88 @@ def main():
 
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
     config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
-  
  
-    max_steps = model.hparams.max_steps
-    print ("max_steps",max_steps)
+    max_epochs = model.hparams.max_epochs #the number of epochs
+    num_examples_per_epoch = train_dataset.num_examples_per_epoch()
+    print ("number of exmaples per epoch:",num_examples_per_epoch)
+    steps_per_epoch = int(num_examples_per_epoch/batch_size)
+    total_steps = steps_per_epoch * max_epochs
+    #mock total_steps only for fast debugging
+    #total_steps = 10
+    print ("Total steps for training:",total_steps)
+    results_dict = {}
     with tf.Session(config=config) as sess:
         print("parameter_count =", sess.run(parameter_count))
         sess.run(tf.global_variables_initializer())
         sess.run(tf.local_variables_initializer())
-        #coord = tf.train.Coordinator()
-        #threads = tf.train.start_queue_runners(sess = sess, coord = coord)
-        print("Init done: {sess.run(tf.local_variables_initializer())}%")
-        model.restore(sess, args.checkpoint)
-
-        #sess.run(model.post_init_ops)
-
-        #val_handle_eval = sess.run(val_handle)
-        #print ("val_handle_val",val_handle_eval)
-        #print("val handle done")
+        #model.restore(sess, args.checkpoint)
         sess.graph.finalize()
         start_step = sess.run(model.global_step)
-
-
+        print("start_step", start_step)
         # start at one step earlier to log everything without doing any training
         # step is relative to the start_step
-        for step in range(-1, max_steps - start_step):
+        train_losses=[]
+        val_losses=[]
+        run_start_time = time.time()        
+        for step in range(total_steps):
             global_step = sess.run(model.global_step)
             print ("global_step:", global_step)
             val_handle_eval = sess.run(val_handle)
-
-            if step == 1:
-                # skip step -1 and 0 for timing purposes (for warmstarting)
-                start_time = time.time()
             
+            #Fetch variables in the graph
             fetches = {"global_step":model.global_step}
             fetches["train_op"] = model.train_op
-
-           # fetches["latent_loss"] = model.latent_loss
+            #fetches["latent_loss"] = model.latent_loss
             fetches["total_loss"] = model.total_loss
+
+            #fetch the additional losses that are only defined for McNet
             if model.__class__.__name__ == "McNetVideoPredictionModel":
                 fetches["L_p"] = model.L_p
                 fetches["L_gdl"] = model.L_gdl
                 fetches["L_GAN"]  =model.L_GAN
-          
-         
-
-            fetches["summary"] = model.summary_op
-
-            run_start_time = time.time()
-            #Run training results
-            #X = inputs["images"].eval(session=sess)           
-
+            
+            if model.__class__.__name__ == "SAVP":
+                #todo
+                pass
+            
+            fetches["summary"] = model.summary_op       
             results = sess.run(fetches)
-
-            run_elapsed_time = time.time() - run_start_time
-            if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
-                print('running train_op took too long (%0.1fs)' % run_elapsed_time)
-
-            #Run testing results
-            #val_fetches = {"global_step":global_step}
+            train_losses.append(results["total_loss"])          
+            #Fetch losses for validation data
             val_fetches = {}
             #val_fetches["latent_loss"] = model.latent_loss
-            #val_fetches["total_loss"] = model.total_loss
+            val_fetches["total_loss"] = model.total_loss
             val_fetches["summary"] = model.summary_op
             val_results = sess.run(val_fetches,feed_dict={train_handle: val_handle_eval})
-          
+            val_losses.append(val_results["total_loss"])
+            
             summary_writer.add_summary(results["summary"])
             summary_writer.add_summary(val_results["summary"])
-             
-
-
-           
-            val_datasets = [val_dataset]
-            val_models = [model]
-
-            # for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)):
-            #     sess.run(val_model.accum_eval_metrics_reset_op)
-            #     # traverse (roughly up to rounding based on the batch size) all the validation dataset
-            #     accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size
-            #     val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op}
-            #     for update_step in range(accum_eval_summary_num_updates):
-            #         print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates))
-            #         val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
-            #     accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1))
-            #     print("recording accum eval summary")
-            #     summary_writer.add_summary(accum_eval_summary, val_results["global_step"])
             summary_writer.flush()
-
+             
             # global_step will have the correct step count if we resume from a checkpoint
-            # global step is read before it's incremented
-            steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
-            #train_epoch = results["global_step"] / steps_per_epoch
+            # global step is read before it's incremented
             train_epoch = global_step/steps_per_epoch
             print("progress  global step %d  epoch %0.1f" % (global_step + 1, train_epoch))
-            if step > 0:
-                elapsed_time = time.time() - start_time
-                average_time = elapsed_time / step
-                images_per_sec = batch_size / average_time
-                remaining_time = (max_steps - (start_step + step + 1)) * average_time
-                print("image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
-                      (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
 
-
-            print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"]))
+            if model.__class__.__name__ == "McNetVideoPredictionModel":
+              print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"]))
+            elif model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
+                print ("Total_loss:{}".format(results["total_loss"]))
+            else:
+                print ("The model name does not exist")
             
-            print("saving model to", args.output_dir)
-            saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)##Bing: cheat here a little bit because of the global step issue
-            print("done")
-
+            #print("saving model to", args.output_dir)
+            saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)#
+        train_time = time.time() - run_start_time
+        results_dict = {"train_time":train_time,
+                        "total_steps":total_steps}
+        save_results_to_dict(results_dict,args.output_dir)
+        save_results_to_pkl(train_losses, val_losses, args.output_dir)
+        print("train_losses:",train_losses)
+        print("val_losses:",val_losses) 
+        plot_train(train_losses,val_losses,args.output_dir)
+        print("Done")
+        
 if __name__ == '__main__':
     main()
diff --git a/video_prediction_savp/video_prediction/datasets/__init__.py b/video_prediction_savp/video_prediction/datasets/__init__.py
index 736b8202172051f36586db5579c545863c72e14d..e449a65bd48b14ef5a11e6846ce4d8f39f7ed193 100644
--- a/video_prediction_savp/video_prediction/datasets/__init__.py
+++ b/video_prediction_savp/video_prediction/datasets/__init__.py
@@ -7,7 +7,7 @@ from .kth_dataset import KTHVideoDataset
 from .ucf101_dataset import UCF101VideoDataset
 from .cartgripper_dataset import CartgripperVideoDataset
 from .era5_dataset_v2 import ERA5Dataset_v2
-from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly
+#from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly
 
 def get_dataset_class(dataset):
     dataset_mappings = {
@@ -19,7 +19,7 @@ def get_dataset_class(dataset):
         'ucf101': 'UCF101VideoDataset',
         'cartgripper': 'CartgripperVideoDataset',
         "era5":"ERA5Dataset_v2",
-        "era5_anomaly":"ERA5Dataset_v2_anomaly",
+#        "era5_anomaly":"ERA5Dataset_v2_anomaly",
     }
     dataset_class = dataset_mappings.get(dataset, dataset)
     print("datset_class",dataset_class)
diff --git a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
index 9e32e0c638b4b39c588389e906ba29be5144ee35..a3c9fc3666eb21ef02f9e5f64c8c95a29034d619 100644
--- a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
+++ b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
@@ -5,7 +5,6 @@ import os
 import pickle
 import random
 import re
-import hickle as hkl
 import numpy as np
 import json
 import tensorflow as tf
@@ -17,9 +16,14 @@ sys.path.append(path.abspath('../../workflow_parallel_frame_prediction/'))
 import DataPreprocess.process_netCDF_v2 
 from DataPreprocess.process_netCDF_v2 import get_unique_vars
 from DataPreprocess.process_netCDF_v2 import Calc_data_stat
+from metadata import MetaData
 #from base_dataset import VarLenFeatureVideoDataset
 from collections import OrderedDict
 from tensorflow.contrib.training import HParams
+from mpi4py import MPI
+import glob
+
+
 
 class ERA5Dataset_v2(VarLenFeatureVideoDataset):
     def __init__(self, *args, **kwargs):
@@ -28,6 +32,7 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
         example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
         dict_message = MessageToDict(tf.train.Example.FromString(example))
         feature = dict_message['features']['feature']
+        print("features in dataset:",feature.keys())
         self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
         self.image_shape = self.video_shape[1:]
         self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape
@@ -57,7 +62,6 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
         sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
         return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
 
-
     def filter(self, serialized_example):
         return tf.convert_to_tensor(True)
 
@@ -70,7 +74,8 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
                 'height': tf.FixedLenFeature([], tf.int64),
                 'sequence_length': tf.FixedLenFeature([], tf.int64),
                 'channels': tf.FixedLenFeature([],tf.int64),
-                # 'images/encoded':  tf.FixedLenFeature([], tf.string)
+                #'t_start':  tf.FixedLenFeature([], tf.string),
+                't_start':  tf.VarLenFeature(tf.int64),
                 'images/encoded': tf.VarLenFeature(tf.float32)
             }
             
@@ -79,28 +84,22 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
             parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
             print ("Parse features", parsed_features)
             seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
-           # width = tf.sparse_tensor_to_dense(parsed_features["width"])
+            T_start = tf.sparse_tensor_to_dense(parsed_features["t_start"])
+            print("T_start in make dataset_v2", T_start)
+            #width = tf.sparse_tensor_to_dense(parsed_features["width"])
            # height = tf.sparse_tensor_to_dense(parsed_features["height"])
            # channels  = tf.sparse_tensor_to_dense(parsed_features["channels"])
            # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"])
             images = []
-            # for i in range(20):
-            #    images.append(parsed_features["images/encoded"].values[i])
-            # images = parsed_features["images/encoded"]
-            # images = tf.map_fn(lambda i: tf.image.decode_jpeg(parsed_features["images/encoded"].values[i]),offsets)
-            # seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"], '')
-            # Parse the string into an array of pixels corresponding to the image
-            # images = tf.decode_raw(parsed_features["images/encoded"],tf.int32)
-
-            # images = seq
             print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]))
             images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new")
             seqs["images"] = images
+            seqs["T_start"] = T_start
             return seqs
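+        # Each parsed element is a dict: seqs["images"] has shape
+        # [sequence_length, height, width, channels], and seqs["T_start"] is the
+        # start timestamp of the sequence encoded as an integer (YYYYMMDDHH).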
         filenames = self.filenames
         print ("FILENAMES",filenames)
-	#TODO:
-	#temporal_filenames = self.temporal_filenames
+        #TODO:
+        #temporal_filenames = self.temporal_filenames
         shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
         if shuffle:
             random.shuffle(filenames)
@@ -121,7 +120,6 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
         # dataset = dataset.apply(tf.contrib.data.map_and_batch(
         #    _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) #  Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs
         dataset = dataset.prefetch(batch_size)  # Bing: Take the data to buffer inorder to save the waiting time for GPU
-
         return dataset
 
 
@@ -139,24 +137,26 @@ def _bytes_list_feature(values):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
 
 def _floats_feature(value):
-  return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
 
 def _int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
-def save_tf_record(output_fname, sequences):
-    print('saving sequences to %s' % output_fname)
+def save_tf_record(output_fname, sequences,T_start_points):
     with tf.python_io.TFRecordWriter(output_fname) as writer:
-        for sequence in sequences:
+        for i in range(len(sequences)):
+            sequence = sequences[i]
+            T_start = T_start_points[i][0].strftime("%Y%m%d%H")
+            print("T_start:",T_start)
             num_frames = len(sequence)
             height, width, channels = sequence[0].shape
             encoded_sequence = np.array([list(image) for image in sequence])
-
             features = tf.train.Features(feature={
                 'sequence_length': _int64_feature(num_frames),
                 'height': _int64_feature(height),
                 'width': _int64_feature(width),
                 'channels': _int64_feature(channels),
+                't_start': _int64_feature(int(T_start)),
                 'images/encoded': _floats_feature(encoded_sequence.flatten()),
             })
             example = tf.train.Example(features=features)
@@ -207,91 +207,107 @@ class Norm_data:
             for stat_name in self.known_norms[norm]:
                 #setattr(self,varname+stat_name,stat_dict[varname][0][stat_name])
                 setattr(self,varname+stat_name,Calc_data_stat.get_stat_vars(stat_dict,stat_name,varname))
-                
+
         self.status_ok = True           # set status for normalization -> ready
-                
+
     def norm_var(self,data,varname,norm):
         """ 
          Performs given normalization on input data (given that the instance is already set up)
         """
-        
+
         # some sanity checks
         if not self.status_ok: raise ValueError("Norm_data-instance needs to be initialized and checked first.") # status ready?
-        
+
         if not norm in self.known_norms.keys():                                # valid normalization requested?
             print("Please select one of the following known normalizations: ")
             for norm_avail in self.known_norms.keys():
                 print(norm_avail)
             raise ValueError("Passed normalization '"+norm+"' is unknown.")
-        
+
         # do the normalization and return
         if norm == "minmax":
             return((data[...] - getattr(self,varname+"min"))/(getattr(self,varname+"max") - getattr(self,varname+"min")))
         elif norm == "znorm":
             return((data[...] - getattr(self,varname+"avg"))/getattr(self,varname+"sigma")**2)
-        
+
     def denorm_var(self,data,varname,norm):
         """ 
          Performs given denormalization on input data (given that the instance is already set up), i.e. inverse method to norm_var
         """
-        
+
         # some sanity checks
         if not self.status_ok: raise ValueError("Norm_data-instance needs to be initialized and checked first.") # status ready?        
-        
+
         if not norm in self.known_norms.keys():                                # valid normalization requested?
             print("Please select one of the following known normalizations: ")
             for norm_avail in self.known_norms.keys():
                 print(norm_avail)
             raise ValueError("Passed normalization '"+norm+"' is unknown.")
-        
+
         # do the denormalization and return
         if norm == "minmax":
             return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"min"))
         elif norm == "znorm":
             return(data[...] * getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg"))
-        
 
-def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,seq_length=20,sequences_per_file=128,height=64,width=64,channels=3,**kwargs):#Bing: original 128
+
+def read_frames_and_save_tf_records(stats,output_dir,input_file, temp_input_file, vars_in,year,month,seq_length=20,sequences_per_file=128,height=64,width=64,channels=3,**kwargs):#Bing: original 128
+    """
+    Read pickle files based on month, to process and save to tfrecords
+    stats:dict, contains the stats information from pickle directory,
+    input_file: string, absolute path to pickle file
+    file_info: 1D list with three elements, partition_name(train,val or test), year, and month e.g.[train,1,2]  
+    """
     # ML 2020/04/08:
     # Include vars_in for more flexible data handling (normalization and reshaping)
     # and optional keyword argument for kind of normalization
-    
+    print ("read_frames_and_save_tf_records function") 
     if 'norm' in kwargs:
         norm = kwargs.get("norm")
     else:
         norm = "minmax"
         print("Make use of default minmax-normalization...")
 
-    output_dir = os.path.join(output_dir,partition_name)
+
     os.makedirs(output_dir,exist_ok=True)
-    
+
     norm_cls  = Norm_data(vars_in)       # init normalization-instance
     nvars     = len(vars_in)
     
     # open statistics file and feed it to norm-instance
-    with open(os.path.join(input_dir,"statistics.json")) as js_file:
-        norm_cls.check_and_set_norm(json.load(js_file),norm)        
-    
+    #with open(os.path.join(input_dir,"statistics.json")) as js_file:
+    norm_cls.check_and_set_norm(stats,norm)
     sequences = []
+    T_start_points = []
     sequence_iter = 0
-    sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
-    X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl"))
+    #sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
+    #Bing 2020/07/16
+    #print ("open intput dir,",input_file)
+    with open(input_file, "rb") as data_file:
+        X_train = pickle.load(data_file)
+    with open(temp_input_file,"rb") as temp_file:
+        T_train = pickle.load(temp_file)
+    #print("T_train:",T_train) 
+    #check to make sure that X_train and T_train have the same length
+    assert (len(X_train) == len(T_train))
+    
     X_possible_starts = [i for i in range(len(X_train) - seq_length)]
     for X_start in X_possible_starts:
-        print("Interation", sequence_iter)
         X_end = X_start + seq_length
         #seq = X_train[X_start:X_end, :, :,:]
-        seq = X_train[X_start:X_end,:,:]
-        #print("*****len of seq ***.{}".format(len(seq)))
-        #seq = list(np.array(seq).reshape((len(seq), 64, 64, 3)))
+        seq = X_train[X_start:X_end,:,:,:]
+        #Record the start point of the timestamps
+        T_start = T_train[X_start]
+        #print("T_start:",T_start)  
         seq = list(np.array(seq).reshape((seq_length, height, width, nvars)))
         if not sequences:
             last_start_sequence_iter = sequence_iter
-            print("reading sequences starting at sequence %d" % sequence_iter)
+           
+       
         sequences.append(seq)
-        sequence_iter += 1
-        sequence_lengths_file.write("%d\n" % len(seq))
-
+        T_start_points.append(T_start)
+        sequence_iter += 1    
+        
         if len(sequences) == sequences_per_file:
             ###Normalization should adpot the selected variables, here we used duplicated channel temperature variables
             sequences = np.array(sequences)
@@ -299,12 +315,29 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,
             for i in range(nvars):    
                 sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm)
 
-            output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1)
+            output_fname = 'sequence_Y_{}_M_{}_{}_to_{}.tfrecords'.format(year,month,last_start_sequence_iter,sequence_iter - 1)
             output_fname = os.path.join(output_dir, output_fname)
-            save_tf_record(output_fname, list(sequences))
+            print("T_start_points:",T_start_points)
+            save_tf_record(output_fname, list(sequences), T_start_points)
+            T_start_points = []
             sequences = []
-    sequence_lengths_file.close()
+    print("Finished for input file",input_file)
+    #sequence_lengths_file.close()
+    return 
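+
+# Illustrative call (hypothetical file names): process January 2017 of the
+# training split with the T2 variable only:
+#   read_frames_and_save_tf_records(stats, "<output_dir>/train",
+#       "<pickle_dir>/X_01.pkl", "<pickle_dir>/T_01.pkl",
+#       vars_in=["T2"], year="2017", month=1)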
 
+def write_sequence_file(output_dir,seq_length,sequences_per_file):
+    
+    partition_names = ["train","val","test"]
+    for partition_name in partition_names:
+        save_output_dir = os.path.join(output_dir,partition_name)
+        tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords"))
+        print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter))
+        sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w')
+        for i in range(tfCounter*sequences_per_file):
+            sequence_lengths_file.write("%d\n" % seq_length)
+        sequence_lengths_file.close()
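+    # Note: ERA5Dataset_v2.num_examples_per_epoch() counts the entries in
+    # sequence_lengths.txt, so exactly one line per stored sequence has to be
+    # written here.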
+    
+    
 
 def main():
     parser = argparse.ArgumentParser()
@@ -316,16 +349,116 @@ def main():
     parser.add_argument("-height",type=int,default=64)
     parser.add_argument("-width",type = int,default=64)
     parser.add_argument("-seq_length",type=int,default=20)
+    parser.add_argument("-sequences_per_file",type=int,default=2)
     args = parser.parse_args()
     current_path = os.getcwd()
     #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
     #output_dir = "/Users/gongbing/PycharmProjects/video_prediction/data/era5"
-    partition_names = ['train','val',  'test'] #64,64,3 val has issue#
+    #partition_names = ['train','val',  'test'] #64,64,3 val has issue#
+
+    ############################################################
+    # CONTROLLING variable! Needs to be adapted manually!!!
+    ############################################################
+    partition = {
+            "train":{
+               # "2222":[1,2,3,5,6,7,8,9,10,11,12],
+               # "2010_1":[1,2,3,4,5,6,7,8,9,10,11,12],
+               # "2012":[1,2,3,4,5,6,7,8,9,10,11,12],
+               # "2013_complete":[1,2,3,4,5,6,7,8,9,10,11,12],
+               # "2015":[1,2,3,4,5,6,7,8,9,10,11,12],
+                "2017_test":[1,2,3,4,5,6,7,8,9,10]
+                 },
+            "val":
+                {"2017_test":[11]
+                 },
+            "test":
+                {"2017_test":[12]
+                 }
+            }
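+    # Example of the mapping above: each entry maps a year-folder to a list of months,
+    # i.e. months 1-10 of "2017_test" feed the training set, month 11 the validation set
+    # and month 12 the test set; each month is later handled by the MPI rank of the same
+    # number.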
+    
+    # initialize MPI
+    comm = MPI.COMM_WORLD
+    my_rank = comm.Get_rank()  # rank of the node
+    p = comm.Get_size()  # number of assigned nodes
   
-    for partition_name in partition_names:
-        read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,vars_in=args.variables,partition_name=partition_name, seq_length=args.seq_length,height=args.height,width=args.width,sequences_per_file=2) #Bing: Todo need check the N_seq
-        #ead_frames_and_save_tf_records(output_dir = output_dir, input_dir = input_dir,partition_name = partition_name, N_seq=20) #Bing: TODO: first try for N_seq is 10, but it met loading data issue. let's try 5
-
+    if my_rank == 0:
+        # retrieve final statistics first (not parallelized!)
+        # some preparatory steps
+        stat_dir_prefix = args.input_dir
+        varnames        = args.variables
+    
+        vars_uni, varsind, nvars = get_unique_vars(varnames)
+        stat_obj = Calc_data_stat(nvars)                            # init statistic-instance
+    
+        # loop over whole data set (training, dev and test set) to collect the intermediate statistics
+        print("Start collecting statistics from the whole datset to be processed...")
+        for split in partition.keys():
+            values = partition[split]
+            for year in values.keys():
+                file_dir = os.path.join(stat_dir_prefix,year)
+                for month in values[year]:
+                    # process monthly statistic-file
+                    stat_obj.acc_stat_master(file_dir,int(month))
+        
+        # finalize statistics and write to json-file
+        stat_obj.finalize_stat_master(vars_uni)
+        stat_obj.write_stat_json(args.input_dir)
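+        # Note: only the master aggregates the monthly statistic-files into a single
+        # statistics.json in input_dir; the slaves read it back for normalization
+        # (see the else-branch below).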
+
+        # organize parallelized partitioning
+        partition_year_month = [] # list of lists; each inner list holds the three elements [partition_name,year,month]
+        partition_names = list(partition.keys())
+        print ("partition_names:",partition_names)
+        broadcast_lists = []
+        for partition_name in partition_names:
+            partition_data = partition[partition_name]        
+            years = list(partition_data.keys())
+            broadcast_lists.append([partition_name,years])
+        for nodes in range(1,p):
+            comm.send(broadcast_lists,dest=nodes)
+
+        message_counter = 1
+        while message_counter <= 12:    # expect one completion message per month-slave
+            message_in = comm.recv()
+            message_counter = message_counter + 1
+            print("Message in from slave",message_in)
+
+        write_sequence_file(args.output_dir,args.seq_length,args.sequences_per_file)
+
+    else:
+        message_in = comm.recv()
+        print ("My rank,", my_rank)   
+        print("message_in",message_in)
+        # open statistics file and feed it to norm-instance
+        print("Opening json-file: "+os.path.join(args.input_dir,"statistics.json"))
+        with open(os.path.join(args.input_dir,"statistics.json")) as js_file:
+            stats = json.load(js_file)
+        # loop over the partitions (train,val,test)
+        for partition in message_in:
+            print("partition on slave ",partition)
+            partition_name = partition[0]
+            save_output_dir =  os.path.join(args.output_dir,partition_name)
+            for year in partition[1]:
+               input_file = "X_" + '{0:02}'.format(my_rank) + ".pkl"
+               temp_file = "T_" + '{0:02}'.format(my_rank) + ".pkl"
+               input_dir = os.path.join(args.input_dir,year)
+               temp_file = os.path.join(input_dir,temp_file )
+               input_file = os.path.join(input_dir,input_file)
+               # create the tfrecords-files
+               read_frames_and_save_tf_records(year=year,month=my_rank,stats=stats,output_dir=save_output_dir, \
+                                               input_file=input_file,temp_input_file=temp_file,vars_in=args.variables, \
+                                               partition_name=partition_name,seq_length=args.seq_length, \
+                                               height=args.height,width=args.width,sequences_per_file=args.sequences_per_file)   
+            print("Year {} finished".format(year))
+        message_out = ("Node:",str(my_rank),"finished","","\r\n")
+        print ("Message out for slaves:",message_out)
+        comm.send(message_out,dest=0)
+
+    MPI.Finalize()
+
 if __name__ == '__main__':
-    main()
+    main()
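+# Typical invocation (sketch; script name and flags depend on the setup): the master
+# waits for 12 monthly workers, so 13 MPI ranks are expected, e.g.
+#   mpirun -np 13 python <this_script>.py <input-/output-dir and further arguments>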
 
diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index e7753004348ae0ae60057a469de1e2d1421c3869..557280d5d7169b9212844d4739ce3e0c2df7190b 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -17,21 +17,17 @@ from video_prediction.layers import layer_def as ld
 from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
 
 class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
-    def __init__(self, mode='train',aggregate_nccl=None, hparams_dict=None,
+    def __init__(self, mode='train', hparams_dict=None,
                  hparams=None, **kwargs):
         super(VanillaConvLstmVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
         print ("Hparams_dict",self.hparams)
         self.mode = mode
         self.learning_rate = self.hparams.lr
-        self.gen_images_enc = None
-        self.recon_loss = None
-        self.latent_loss = None
         self.total_loss = None
-        self.context_frames = 10
-        self.sequence_length = 20
+        self.context_frames = self.hparams.context_frames
+        self.sequence_length = self.hparams.sequence_length
         self.predict_frames = self.sequence_length - self.context_frames
-        self.aggregate_nccl=aggregate_nccl
-    
+        self.max_epochs = self.hparams.max_epochs
+
     def get_default_hparams_dict(self):
         """
         The keys of this dict define valid hyperparameters for instances of
@@ -44,13 +40,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
             batch_size: batch size for training.
             lr: learning rate. if decay steps is non-zero, this is the
                 learning rate for steps <= decay_step.
-
-
-
             max_steps: number of training steps.
-
-
-            context_frames: the number of ground-truth frames to pass in at
+            context_frames: the number of ground-truth frames to pass in at
                 start. Must be specified during instantiation.
             sequence_length: the number of frames in the video sequence,
                 including the context frames, so this model predicts
@@ -62,46 +53,40 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         hparams = dict(
             batch_size=16,
             lr=0.001,
-            end_lr=0.0,
-            nz=16,
-            decay_steps=(200000, 300000),
-            max_steps=350000,
+            max_epochs=3000,
         )
 
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
 
     def build_graph(self, x):
         self.x = x["images"]
-       
+
         self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
         original_global_variables = tf.global_variables()
         # ARCHITECTURE
-        self.x_hat_context_frames, self.x_hat_predict_frames = self.convLSTM_network()
-        self.x_hat = tf.concat([self.x_hat_context_frames, self.x_hat_predict_frames], 1)
-
-
-        self.context_frames_loss = tf.reduce_mean(
-            tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
-        self.predict_frames_loss = tf.reduce_mean(
-            tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_predict_frames[:, :, :, :, 0]))
-        self.total_loss = self.context_frames_loss + self.predict_frames_loss
+        self.convLSTM_network()
+        #print("self.x",self.x)
+        #print("self.x_hat_context_frames,",self.x_hat_context_frames)
+        #self.context_frames_loss = tf.reduce_mean(
+        #    tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
+        # This is the loss function (RMSE):
+        self.total_loss = tf.reduce_mean(
+            tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_context_frames[:, (self.context_frames-1):-1, :, :, 0]))
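+        # e.g. with context_frames=10 and sequence_length=20, the slice compares the
+        # predictions made at frames 9..18 (one-step forecasts of frames 10..19)
+        # against the ground truth x[:, 10:20]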
 
         self.train_op = tf.train.AdamOptimizer(
             learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step)
         self.outputs = {}
         self.outputs["gen_images"] = self.x_hat
         # Summary op
-        self.loss_summary = tf.summary.scalar("recon_loss", self.context_frames_loss)
-        self.loss_summary = tf.summary.scalar("latent_loss", self.predict_frames_loss)
         self.loss_summary = tf.summary.scalar("total_loss", self.total_loss)
         self.summary_op = tf.summary.merge_all()
         global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
         self.saveable_variables = [self.global_step] + global_variables
-        return
+        return None
 
 
     @staticmethod
-    def convLSTM_cell(inputs, hidden, nz=16):
+    def convLSTM_cell(inputs, hidden):
 
         conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1", activate = "leaky_relu")
 
@@ -140,23 +125,28 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
                                             VanillaConvLstmVideoPredictionModel.convLSTM_cell)  # make the template to share the variables
         # create network
         x_hat_context = []
-        x_hat_predict = []
-        seq_start = 1
+        x_hat = []
         hidden = None
-        for i in range(self.context_frames):
-            if i < seq_start:
+        # Unrolling used for training
+        for i in range(self.sequence_length):
+            if i < self.context_frames:
                 x_1, hidden = network_template(self.x[:, i, :, :, :], hidden)
             else:
                 x_1, hidden = network_template(x_1, hidden)
             x_hat_context.append(x_1)
-
-        for i in range(self.predict_frames):
-            x_1, hidden = network_template(x_1, hidden)
-            x_hat_predict.append(x_1)
-
+        
+        # Second unrolling used for generating the video
+        hidden_g = None
+        for i in range(self.sequence_length):
+            if i < self.context_frames:
+                x_1_g, hidden_g = network_template(self.x[:, i, :, :, :], hidden_g)
+            else:
+                x_1_g, hidden_g = network_template(x_1_g, hidden_g)
+            x_hat.append(x_1_g)
+        
         # pack them all together
         x_hat_context = tf.stack(x_hat_context)
-        x_hat_predict = tf.stack(x_hat_predict)
-        self.x_hat_context = tf.transpose(x_hat_context, [1, 0, 2, 3, 4])  # change first dim with sec dim
-        self.x_hat_predict = tf.transpose(x_hat_predict, [1, 0, 2, 3, 4])  # change first dim with sec dim
-        return self.x_hat_context, self.x_hat_predict
+        x_hat = tf.stack(x_hat)
+        self.x_hat_context_frames = tf.transpose(x_hat_context, [1, 0, 2, 3, 4])  # swap time and batch dimension
+        self.x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4])  # swap time and batch dimension
+        self.x_hat_predict_frames = self.x_hat[:,self.context_frames:,:,:,:]
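+        # Shape note (sketch): after the transpose, x_hat_context_frames and x_hat are
+        # [batch, sequence_length, height, width, channels]; x_hat_predict_frames keeps
+        # the sequence_length - context_frames frames after the context window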
diff --git a/workflow_parallel_frame_prediction/DataExtraction/main_single_master.py b/workflow_parallel_frame_prediction/DataExtraction/main_single_master.py
index 4894408adeb1a41b106faafb75def912ca5e4ad5..fda72c671d2804e87b121a2bd62038890a7f5161 100644
--- a/workflow_parallel_frame_prediction/DataExtraction/main_single_master.py
+++ b/workflow_parallel_frame_prediction/DataExtraction/main_single_master.py
@@ -93,19 +93,33 @@ def main():
             if clear_destination == 1:
                 shutil.rmtree(destination_dir)
                 os.mkdir(destination_dir)
-                logger.critical("Destination : {destination} exist -> Remove and Re-Cereate".format(destination=destination_dir))
-                print("Destination : {destination} exist -> Remove and Re-Cereate".format(destination=destination_dir))
+                logger.critical("Destination : {destination} exist -> Remove and Re-Create".format(destination=destination_dir))
+                print("Destination : {destination} exist -> Remove and Re-Create".format(destination=destination_dir))
 
             else:
                 logger.critical("Destination : {destination} exist -> will not be removed (caution : overwrite)".format(destination=destination_dir))
                 print("Destination : {destination} exist -> will not be rmeoved (caution : overwrite)".format(destination=destination_dir))
+                
+
+
+    # 20200630 +++ Scarlet
+    else: 
+        if my_rank == 0:
+            os.makedirs(destination_dir)  # exist_ok=True could be passed to tolerate pre-existing dirs
+            logger.info("Destination : {destination} does not exist -> Create".format(destination=destination_dir))
+            print("Destination : {destination} does not exist -> Create".format(destination=destination_dir))
+
+    # 20200630 --- Scarlet
+
 
     # Create a log folder for slave-nodes to write down their processes
     slave_log_path = os.path.join(destination_dir,log_temp)
 
     if my_rank == 0:
         if os.path.exists(slave_log_path) == False:
-            os.mkdir(slave_log_path)
+            # 20200630 Scarlet
+            #os.mkdir(slave_log_path)
+            os.makedirs(slave_log_path)
 
     if my_rank == 0:  # node is master
 
diff --git a/workflow_parallel_frame_prediction/DataExtraction/prepare_era5_data.py b/workflow_parallel_frame_prediction/DataExtraction/prepare_era5_data.py
index f97bdf8236edb4b4eaac80513d9fb439486705a6..653716493a0514232b21aae0c34d74ddd1d82bb1 100644
--- a/workflow_parallel_frame_prediction/DataExtraction/prepare_era5_data.py
+++ b/workflow_parallel_frame_prediction/DataExtraction/prepare_era5_data.py
@@ -105,7 +105,9 @@ def prepare_era5_data_one_file(src_file,directory_to_process,target_dir, target=
         lon_new = test.createVariable('lon', float, ('lon',), zlib = True)
         lon_new.units = 'degrees_east'
         time_new = test.createVariable('time', 'f8', ('time',), zlib = True)
-        time_new.units = "hours since 2000-01-01 00:00:00"
+        # TODO: This should be changed to "hours since 1970-01-01 00:00:00", because the
+        # ERA5 reanalysis data in principle dates back to 1979; with the reference date
+        # 2000-01-01 we would end up handling negative time values.
+        time_new.units = "hours since 2000-01-01 00:00:00"
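+        # For illustration: netCDF4.date2num(datetime(1979,1,1), "hours since 2000-01-01 00:00:00")
+        # evaluates to -184080.0, whereas a 1970-based epoch keeps all ERA5 timestamps positive.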
         time_new.calendar = "gregorian"
         p3d_new = test.createVariable('p3d', float, ('lev', 'lat', 'lon'), zlib = True)
 
diff --git a/workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py b/workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py
index f6b760c7ad3c528a745975cda9b1c420aa739d77..68d1ddfbfe413aab00950b71a336d0c1a43cbbf8 100644
--- a/workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py
+++ b/workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py
@@ -23,22 +23,36 @@ varnames = args.varnames
 #for key in all_keys:
 #    print(partition[key]) 
 
-partition = {
+cv = {}
+partition1 = {
             "train":{
-                "2017":[1]
+                #"2222":[1,2,3,5,6,7,8,9,10,11,12],
+                #"2010_1":[1,2,3,4,5,6,7,8,9,10,11,12],
+                #"2012":[1,2,3,4,5,6,7,8,9,10,11,12],
+                #"2013_complete":[1,2,3,4,5,6,7,8,9,10,11,12],
+                #"2015":[1,2,3,4,5,6,7,8,9,10,11,12],
+                #"2017":[1,2,3,4,5,6,7,8,9,10,11,12]
+                "2015":[1,2,3,4,5,6,7,8,9,10,11,12]
                  },
             "val":
-                {"2017":[2]
+                {"2016":[1,2,3,4,5,6,7,8,9,10,11,12]
                  },
             "test":
-                {"2017":[2]
+                {"2017":[1,2,3,4,5,6,7,8,9,10,11,12]
                  }
             }
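+# With partition1 as above, training uses all months of 2015, validation all of 2016
+# and testing all of 2017.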
+
+#cv["1"] = partition1
+#cv2["2"] = partition2
 # ini. MPI
 comm = MPI.COMM_WORLD
 my_rank = comm.Get_rank()  # rank of the node
 p = comm.Get_size()  # number of assigned nodes
 if my_rank == 0:  # node is master
-    split_data_multiple_years(target_dir=target_dir,partition=partition,varnames=varnames)
+    split_data_multiple_years(target_dir=target_dir,partition=partition1,varnames=varnames)
 else:
     pass
diff --git a/workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py b/workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py
index fdc2f65a5092469c021e0c21b0606e7e7d248c5c..71c661c49ba3502b12dbd409fb76a5c9b4517087 100755
--- a/workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py
+++ b/workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py
@@ -109,7 +109,6 @@ def main():
     # Expand destination_dir-variable by searching for netCDF-files in source_dir and processing the file from the first list element to obtain all relevant (meta-)data. 
     if my_rank == 0:
         data_files_list = glob.glob(source_dir+"/**/*.nc",recursive=True)
-        
         if not data_files_list: raise ValueError("Could not find any data to be processed in '"+source_dir+"'")
         
         md = MetaData(suffix_indir=destination_dir,data_filename=data_files_list[0],slices=slices,variables=vars)
@@ -118,15 +117,26 @@ def main():
             md.write_dirs_to_batch_scripts(scr_dir+"/DataPreprocess_to_tf.sh")
             md.write_dirs_to_batch_scripts(scr_dir+"/generate_era5.sh")
             md.write_dirs_to_batch_scripts(scr_dir+"/train_era5.sh")
-            # ML 2020/06/08: Dirty workaround as long as data-splitting is done with a seperate Python-script 
-            #                called from the same parent Shell-/Batch-script
-            #                -> work with temproary json-file in working directory
-            md.write_destdir_jsontmp(os.path.join(md.expdir,md.expname),tmp_dir=current_path)
-        #else: nothing to do 
+
+        elif (md.status == "old"):      # meta-data file already exists and is ok
+                                        # check for temp.json in working directory (required by slave nodes)
+            tmp_file = os.path.join(current_path,"temp.json")
+            if os.path.isfile(tmp_file):
+                os.remove(tmp_file)
+                mess_tmp_file = "Auxiliary file '"+tmp_file+"' already exists, but is cleaned up to be updated" + \
+                                " for safety reasons."
+                logging.info(mess_tmp_file)
+
+        # ML 2019/06/08: Dirty workaround as long as data-splitting is done with a seperate Python-script
+        #                called from the same parent Shell-/Batch-script
+        #                -> work with temproary json-file in working directory
+        # create or update temp.json, respectively
+        md.write_destdir_jsontmp(os.path.join(md.expdir, md.expname), tmp_dir=current_path)
         
-        destination_dir= os.path.join(md.expdir,md.expname,"hickle",years)
+        # expand destination directory by pickle-subfolder and...
+        destination_dir= os.path.join(md.expdir,md.expname,"pickle",years)
 
-        # ...and create directory if necessary
+        # ...create directory if necessary
         if not os.path.exists(destination_dir):  # check if the Destination dir. is existing
             logging.critical('The Destination does not exist')
             logging.info('Create new destination dir')
@@ -218,7 +228,7 @@ def main():
 
                     #process_era5_in_dir(job, src_dir=source_dir, target_dir=destination_dir)
                     # ML 2020/06/09: workaround to get correct destination_dir obtained by the master node
-                    destination_dir = os.path.join(MetaData.get_destdir_jsontmp(tmp_dir=current_path),"hickle",years)
+                    destination_dir = os.path.join(MetaData.get_destdir_jsontmp(tmp_dir=current_path),"pickle",years)
                     process_netCDF_in_dir(job_name=job, src_dir=source_dir, target_dir=destination_dir,slices=slices,vars=vars)
 
                     if checksum_status == 1:
diff --git a/workflow_parallel_frame_prediction/DataPreprocess/process_netCDF_v2.py b/workflow_parallel_frame_prediction/DataPreprocess/process_netCDF_v2.py
index 21bacb746a1c95a576dabd3e548d5f5a00dafdb0..1a66fa51011174ab2d8adb2a4eb6b3b0ae3b4a7e 100644
--- a/workflow_parallel_frame_prediction/DataPreprocess/process_netCDF_v2.py
+++ b/workflow_parallel_frame_prediction/DataPreprocess/process_netCDF_v2.py
@@ -11,13 +11,20 @@ from netCDF4 import Dataset,num2date
 import numpy as np
 #from imageio import imread
 #from scipy.misc import imresize
-import hickle as hkl
 import json
 import pickle
 
 # Create image datasets.
 # Processes images and saves them in train, val, test splits.
 def process_data(directory_to_process, target_dir, job_name, slices, vars=("T2","MSL","gph500")):
+    '''
+    :param directory_to_process: directory where netCDF-files are stored to be processed
+    :param target_dir: directory where pickle-files will be stored
+    :param job_name: job_id passed and organized by PyStager
+    :param slices: indices defining the geographical region of interest
+    :param vars: variables to be processed
+    :return: None; saves pickle-files containing the sliced meteorological data as well as the temporal information
+    '''
     desired_im_sz = (slices["lat_e"] - slices["lat_s"], slices["lon_e"] - slices["lon_s"])
     # ToDo: Define a convenient function to create a list containing all files.
     imageList = list(os.walk(directory_to_process, topdown = False))[-1][-1]
@@ -25,31 +32,19 @@ def process_data(directory_to_process, target_dir, job_name, slices, vars=("T2",
     EU_stack_list = [0] * (len(imageList))
     temporal_list = [0] * (len(imageList))
     nvars = len(vars)
-    #X = np.zeros((len(splits[split]),) + desired_im_sz + (3,), np.uint8)
-    #print(X)
-    #print('shape of X' + str(X.shape))
 
-    ##### TODO: iterate over split and read every .nc file, cut out array,
-    #####		overlay arrays for RGB like style.
-    #####		Save everything after for loop.
     # ML 2020/04/06 S
     # Some inits
     stat_obj = Calc_data_stat(nvars)
     # ML 2020/04/06 E
     for j, im_file in enumerate(imageList):
-    #20200408,Bing   
         try:
             im_path = os.path.join(directory_to_process, im_file)
             print('Open following dataset: '+im_path)
-             
-            
-            #20200408,Bing
-            
-            im = Dataset(im_path, mode = 'r')
-            times = im.variables['time']
-            time = num2date(times[:],units=times.units,calendar=times.calendar)
             vars_list = []
             with Dataset(im_path,'r') as data_file:
+                times = data_file.variables['time']
+                time = num2date(times[:],units=times.units,calendar=times.calendar)
                 for i in range(nvars):
                     var1 = data_file.variables[vars[i]][0,slices["lat_s"]:slices["lat_e"], slices["lon_s"]:slices["lon_e"]]
                     stat_obj.acc_stat_loc(i,var1)
@@ -60,9 +55,6 @@ def process_data(directory_to_process, target_dir, job_name, slices, vars=("T2",
 
             #20200408,bing
             temporal_list[j] = list(time)
-            #print('Does ist work? ')
-            #print(EU_stack_list[i][:,:,0]==EU_t2)
-            #print(EU_stack[:,:,1]==EU_msl
         except Exception as err:
             im_path = os.path.join(directory_to_process, im_file)
             #im = Dataset(im_path, mode = 'r')
@@ -72,8 +64,12 @@ def process_data(directory_to_process, target_dir, job_name, slices, vars=("T2",
             continue
             
     X = np.array(EU_stack_list)
-    target_file = os.path.join(target_dir, 'X_' + str(job_name) + '.hkl')
-    hkl.dump(X, target_file) #Not optimal!
+    # ML 2020/07/15: Make use of pickle-files only
+    target_file = os.path.join(target_dir, 'X_' + str(job_name) + '.pkl')
+    with open(target_file, "wb") as data_file:
+        pickle.dump(X,data_file)
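+    # Note (sketch): for arrays larger than 4 GiB, pickle protocol 4 would be needed,
+    # i.e. pickle.dump(X, data_file, protocol=4), as used in the disabled split-routine below.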
     print(target_file, "is saved")
     # ML 2020/03/31: write json file with statistics
     stat_obj.finalize_stat_loc(vars)
@@ -82,28 +78,9 @@ def process_data(directory_to_process, target_dir, job_name, slices, vars=("T2",
     temporal_info = np.array(temporal_list)
     temporal_file = os.path.join(target_dir, 'T_' + str(job_name) + '.pkl')
     cwd = os.getcwd()
-    pickle.dump(temporal_info, open( temporal_file, "wb" ) )
-    #hkl.dump(temporal_info, temporal_file) 
-
-        #hkl.dump(source_list, os.path.join(target_dir, 'sources_' + str(job) + '.hkl'))
-
-        #for category, folder in splits[split]:
-        #    im_dir = os.path.join(DATA_DIR, 'raw/', category, folder, folder[:10], folder, 'image_03/data/')
-        #    files = list(os.walk(im_dir, topdown=False))[-1][-1]
-        #    im_list += [im_dir + f for f in sorted(files)]
-            # multiply path of respective recording with lengths of its files in order to ensure
-            # that each entry in X_train.hkl corresponds with an entry of source_list/ sources_train.hkl
-        #    source_list += [category + '-' + folder] * len(files)
-
-        #print( 'Creating ' + split + ' data: ' + str(len(im_list)) + ' images')
-        #X = np.zeros((len(im_list),) + desired_im_sz + (3,), np.uint8)
-        # enumerate allows us to loop over something and have an automatic counter
-        #for i, im_file in enumerate(im_list):
-        #    im = imread(im_file)
-        #    X[i] = process_im(im, desired_im_sz)
-
-        #hkl.dump(X, os.path.join(DATA_DIR, 'X_' + split + '.hkl'))
-        #hkl.dump(source_list, os.path.join(DATA_DIR, 'sources_' + split + '.hkl'))
+    with open(temporal_file,"wb") as ftemp:
+        pickle.dump(temporal_info,ftemp)
 
 def process_netCDF_in_dir(src_dir,**kwargs):
     target_dir = kwargs.get("target_dir")
@@ -111,6 +88,8 @@ def process_netCDF_in_dir(src_dir,**kwargs):
     directory_to_process = os.path.join(src_dir, job_name)
     os.chdir(directory_to_process)
     if not os.path.exists(target_dir): os.mkdir(target_dir)
-    target_file = os.path.join(target_dir, 'X_' + str(job_name) + '.hkl')
+    # ML 2020/07/15: Make use of pickle-files only
+    target_file = os.path.join(target_dir, 'X_' + str(job_name) + '.pkl')
     if os.path.exists(target_file):
         print(target_file," file exists in the directory ", target_dir)
@@ -119,67 +98,6 @@ def process_netCDF_in_dir(src_dir,**kwargs):
         process_data(directory_to_process=directory_to_process, **kwargs)
 
 
-def split_data(target_dir, partition= [0.6, 0.2, 0.2]):
-    split_dir = target_dir + "/splits"
-    if not os.path.exists(split_dir): os.mkdir(split_dir)
-    os.chdir(target_dir)
-    files = glob.glob("*.hkl")
-    filesList = sorted(files)
-    #Bing: 20200415
-    temporal_files = glob.glob("*.pkl")
-    temporal_filesList = sorted(temporal_files)
-
-    # determine correct indicesue
-    train_begin = 0
-    train_end = round(partition[0] * len(filesList)) - 1
-    val_begin = train_end + 1
-    val_end = train_end + round(partition[1] * len(filesList))
-    test_begin = val_end + 1
-   
-    
-    # slightly adapting start and end because starts at the first index given and stops before(!) the last.
-    train_files = filesList[train_begin:val_begin]
-    val_files = filesList[val_begin:test_begin]
-    test_files = filesList[test_begin:]
-    #bing: 20200415
-    train_temporal_files = temporal_filesList[train_begin:val_begin]
-    val_temporal_files = temporal_filesList[val_begin:test_begin]
-    test_temporal_files = temporal_filesList[test_begin:]
-
-
-    splits = {s: [] for s in ['train', 'test', 'val']}
-    splits['val'] = val_files
-    splits['test'] = test_files
-    splits['train'] = train_files
-
-
-    splits_temporal = {s: [] for s in ['train', 'test', 'val']}
-    splits_temporal["train"] = train_temporal_files
-    splits_temporal["val"] = val_temporal_files
-    splits_temporal["test"] = test_temporal_files
-    
-    for split in splits:
-        X = []
-        X_temporal = []
-        files = splits[split]
-        temporal_files = splits_temporal[split]
-        for file, temporal_file in zip(files, temporal_files):
-            data_file = os.path.join(target_dir,file)
-            temporal_file = os.path.join(target_dir,temporal_file)
-            #load data with hkl file
-            data = hkl.load(data_file)
-            temporal_data = pickle.load(open(temporal_file,"rb"))
-            X_temporal = X_temporal + list(temporal_data)
-            X = X + list(data)
-        X = np.array(X)
-        X_temporal = np.array(X_temporal)
-        print ("X_temporal",X_temporal)
-        #save training, val and test data into splits directoyr
-        hkl.dump(X, os.path.join(split_dir, 'X_' + split + '.hkl'))
-        hkl.dump(files, os.path.join(split_dir,'sources_' + split + '.hkl'))
-        pickle.dump(X_temporal,open(os.path.join(split_dir,"T_"+split + ".pkl"),"wb"))
-        print ("PICKLE FILE FOR SPLITS SAVED")
-
 # ML 2020/05/15 S
 def get_unique_vars(varnames):
     vars_uni, varsind = np.unique(varnames,return_index = True)
@@ -300,7 +218,7 @@ class Calc_data_stat:
             print("Statistics file '"+file_name+"' has already been processed. Thus, just pass here...")
             pass
             
-    def finalize_stat_master(self,path_out,vars_uni):
+    def finalize_stat_master(self,vars_uni):
         """
          Performs final computation of statistics after accumulation from slave nodes.
         """
@@ -310,7 +228,6 @@ class Calc_data_stat:
         if len(vars_uni) > len(set(vars_uni)):
             raise ValueError("Input variable names are not unique.")
                 
-        js_file = os.path.join(path_out,"statistics.json")
         nvars     = len(vars_uni)
         n_jsfiles = len(self.nfiles)
         nfiles_all= np.sum(self.nfiles)
@@ -417,53 +334,62 @@ class Calc_data_stat:
 # ML 2020/05/15 E
                  
 
-def split_data_multiple_years(target_dir,partition,varnames):
-    """
-    Collect all the X_*.hkl data across years and split them to training, val and testing datatset
-    """
-    #target_dirs = [os.path.join(target_dir,year) for year in years]
-    #os.chdir(target_dir)
-    splits_dir = os.path.join(target_dir,"splits")
-    os.makedirs(splits_dir, exist_ok=True) 
-    splits = {s: [] for s in list(partition.keys())}
-    # ML 2020/05/19 S
-    vars_uni, varsind, nvars = get_unique_vars(varnames)
-    stat_obj = Calc_data_stat(nvars)
+# ML 2020/08/03 Not used anymore!
+#def split_data_multiple_years(target_dir,partition,varnames):
+    #"""
+    #Collect all the X_*.hkl data across years and split them into training, val and test datasets
+    #"""
+    ##target_dirs = [os.path.join(target_dir,year) for year in years]
+    ##os.chdir(target_dir)
+    #splits_dir = os.path.join(target_dir,"splits")
+    #os.makedirs(splits_dir, exist_ok=True) 
+    #splits = {s: [] for s in list(partition.keys())}
+    ## ML 2020/05/19 S
+    #vars_uni, varsind, nvars = get_unique_vars(varnames)
+    #stat_obj = Calc_data_stat(nvars)
     
-    for split in partition.keys():
-        values = partition[split]
-        files = []
-        X = []
-        Temporal_X = []
-        for year in values.keys():
-            file_dir = os.path.join(target_dir,year)
-            for month in values[year]:
-                month = "{0:0=2d}".format(month)
-                hickle_file = "X_{}.hkl".format(month)
-                #20200408:bing
-                temporal_file = "T_{}.pkl".format(month)
-                data_file = os.path.join(file_dir,hickle_file)
-                temporal_data_file = os.path.join(file_dir,temporal_file)
-                files.append(data_file)
-                data = hkl.load(data_file)
-                temporal_data = pickle.load(open(temporal_data_file,"rb"))
-                X = X + list(data)
-                Temporal_X = Temporal_X + list(temporal_data)
-                # process stat-file:
-                stat_obj.acc_stat_master(file_dir,int(month))
-        X = np.array(X) 
-        Temporal_X = np.array(Temporal_X)
-        print("==================={}=====================".format(split))
-        print ("Sources for {} dataset are {}".format(split,files))
-        print("Number of images in {} dataset is {} ".format(split,len(X)))
-        print ("dataset shape is {}".format(np.array(X).shape))
-        hkl.dump(X, os.path.join(splits_dir , 'X_' + split + '.hkl'))
-        pickle.dump(Temporal_X, open(os.path.join(splits_dir,"T_"+split + ".pkl"),"wb"))
-        hkl.dump(files, os.path.join(splits_dir,'sources_' + split + '.hkl'))
+    #for split in partition.keys():
+        #values = partition[split]
+        #files = []
+        #X = []
+        #Temporal_X = []
+        #for year in values.keys():
+            #file_dir = os.path.join(target_dir,year)
+            #for month in values[year]:
+                #month = "{0:0=2d}".format(month)
+                #hickle_file = "X_{}.hkl".format(month)
+                ##20200408:bing
+                #temporal_file = "T_{}.pkl".format(month)
+                ##data_file = os.path.join(file_dir,hickle_file)
+                #data_file = os.path.join(file_dir,hickle_file)
+                #temporal_data_file = os.path.join(file_dir,temporal_file)
+                #files.append(data_file)
+                #data = hkl.load(data_file)
+                #with open(temporal_data_file,"rb") as ftemp:
+                    #temporal_data = pickle.load(ftemp)
+                #X = X + list(data)
+                #Temporal_X = Temporal_X + list(temporal_data)
+                ## process stat-file:
+                #stat_obj.acc_stat_master(file_dir,int(month))
+        #X = np.array(X) 
+        #Temporal_X = np.array(Temporal_X)
+        #print("==================={}=====================".format(split))
+        #print ("Sources for {} dataset are {}".format(split,files))
+        #print("Number of images in {} dataset is {} ".format(split,len(X)))
+        #print ("dataset shape is {}".format(np.array(X).shape))
+        ## ML 2020/07/15: Make use of pickle-files only
+        #with open(os.path.join(splits_dir , 'X_' + split + '.pkl'),"wb") as data_file:
+            #pickle.dump(X,data_file,protocol=4)
+        ##hkl.dump(X, os.path.join(splits_dir , 'X_' + split + '.hkl'))
+
+        #with open(os.path.join(splits_dir,"T_"+split + ".pkl"),"wb") as temp_file:
+            #pickle.dump(Temporal_X, temp_file)
+        
+        #hkl.dump(files, os.path.join(splits_dir,'sources_' + split + '.hkl'))
         
-    # write final statistics json-file
-    stat_obj.finalize_stat_master(target_dir,vars_uni)
-    stat_obj.write_stat_json(splits_dir)
+    ## write final statistics json-file
+    #stat_obj.finalize_stat_master(target_dir,vars_uni)
+    #stat_obj.write_stat_json(splits_dir)