diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..5d7d8d4f0ec66e7d19e91b726d39e1d75141e308
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,122 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+#docs/_build/
+
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+virtual_env*/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+
+.idea/*
+
+*.DS_Store
+
+# Ignore log- and errorfiles 
+*-err.???????
+*-out.???????
+
+
+#Ignore the results files
+
+**/results_test_samples
+**/logs
+**/vp
+**/hickle
+*.tfrecords
+**/era5_size_64_64_3_3t_norm
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ceac5490d9c216ca2b25f6209f3fc36593f638f8
--- /dev/null
+++ b/.gitlab-ci.yml
@@ -0,0 +1,87 @@
+stages:
+  - build
+  - test
+  - deploy
+
+
+Loading:
+    tags: 
+     - linux
+    stage: build
+    script: 
+        - echo "dataset testing"
+
+
+Preprocessing:
+    tags: 
+     - linux
+    stage: build
+    script: 
+        - echo "Building-preprocessing"
+
+
+Training:
+    tags: 
+     - linux
+    stage: build
+    script: 
+        - echo "Building-Training"
+
+test:
+    tags: 
+     - linux
+    stage: build
+    script:
+        - echo "model testing"
+#      - zypper --non-interactive install python3-pip
+#      - zypper --non-interactive install python-devel
+#      - pip install --upgrade pip
+#      - pip install -r requirements.txt
+#      - python3 test/test_DataMgr.py
+        - echo "Testing"
+        - echo $CI_JOB_STAGE
+        
+
+coverage:
+    tags:
+     - linux
+    stage: test
+    variables:
+        FAILURE_THRESHOLD: 50
+        COVERAGE_PASS_THRESHOLD: 80
+        CODE_PATH: "foo/"
+    script:
+        - zypper --non-interactive install python3-pip
+        - zypper --non-interactive install python3-devel
+        - pip install --upgrade pip
+        - pip install pytest
+#        - pip install -r requirement.txt
+#        - pip install unnitest
+#        - python test/test_DataMgr.py
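+#        Sketch (not active, assumption only): run pytest with coverage and enforce the
+#        COVERAGE_PASS_THRESHOLD defined above via pytest-cov, e.g.:
+#        - pip install pytest-cov
+#        - pytest --cov=${CODE_PATH} --cov-fail-under=${COVERAGE_PASS_THRESHOLD} test/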
+
+job2:
+    before_script:
+        - export PATH=$PATH:/usr/local/bin
+    tags:
+        - linux
+    stage: deploy
+    script:
+        - zypper --non-interactive install python3-pip
+        - zypper --non-interactive install python3-devel
+        # - pip install sphinx
+        # - pip install --upgrade pip
+#        - pip install -r requirements.txt
+#       - mkdir documents
+#        - cd docs
+#        - make html
+#        - mkdir documents
+#        - mv _build/html documents
+    # artifacts:
+    #     paths:
+    #         - documents
+deploy:
+    tags:
+        - linux
+    stage: deploy
+    script:
+        - echo "deploy stage"
diff --git a/Dockerfiles/Dockerfile_base b/Dockerfiles/Dockerfile_base
new file mode 100644
index 0000000000000000000000000000000000000000..58949839b0097bc7ee31e1aa1951584e521530eb
--- /dev/null
+++ b/Dockerfiles/Dockerfile_base
@@ -0,0 +1,55 @@
+# ---- base node ----
+FROM opensuse/leap:latest AS base
+MAINTAINER Lukas Leufen <l.leufen@fz-juelich.de>
+
+# install git
+RUN zypper --non-interactive install git
+
+# install python3
+RUN zypper --non-interactive install python3 python3-devel
+
+# install pip
+RUN zypper --non-interactive install python3-pip
+
+# upgrade pip
+RUN pip install --upgrade pip
+
+# install curl
+RUN zypper --non-interactive install curl
+
+# install make
+RUN zypper --non-interactive install make
+
+# install gcc
+RUN zypper --non-interactive install gcc-c++
+
+# ---- test node ----
+FROM base AS test
+
+# install pytest
+RUN pip install pytest pytest-html pytest-lazy-fixture
+
+# ---- coverage node ----
+FROM test AS coverage
+
+# install pytest coverage
+RUN pip install pytest-cov
+
+
+# ---- docs node ----
+FROM base AS docs
+
+# install sphinx
+RUN pip install sphinx
+
+# ---- django version ----
+FROM base AS django
+
+# install django requirements
+RUN zypper --non-interactive install binutils libproj-devel gdal-devel
+
+# install cartopy
+RUN zypper --non-interactive install proj
+RUN pip install cython numpy==1.15.4 pyshp six pyproj shapely matplotlib pillow
+RUN zypper --non-interactive install geos-devel
+RUN pip install cartopy==0.16.0
diff --git a/Dockerfiles/Dockerfile_tf b/Dockerfiles/Dockerfile_tf
new file mode 100644
index 0000000000000000000000000000000000000000..c5e51bcdbb349d221a9941c3272b3b5474925826
--- /dev/null
+++ b/Dockerfiles/Dockerfile_tf
@@ -0,0 +1,43 @@
+# ---- base node ----
+FROM tensorflow/tensorflow:1.13.1-gpu-py3
+
+# update apt-get
+RUN apt-get update -y
+
+ENV NVIDIA_VISIBLE_DEVICES all
+ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
+
+#RUN pip install keras==2.2.4
+
+RUN apt-get install software-properties-common -y
+RUN add-apt-repository ppa:deadsnakes/ppa -y
+RUN apt-get update -y
+RUN apt-get install python3.6 python3.6-dev -y
+RUN apt-get install git -y
+RUN apt-get install gnupg-curl -y
+RUN apt-get install wget -y
+#RUN apt-get install linux-headers-$(uname -r) -y
+#
+#RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_10.0.130-1_amd64.deb
+#RUN dpkg -i cuda-repo-ubuntu1604_10.0.130-1_amd64.deb
+#RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/7fa2af80.pub
+#RUN apt-get update -y
+#RUN DEBIAN_FRONTEND=noninteractive apt-get -qy install cuda-10-0
+
+#RUN apt-get install build-essential dkms -y
+#RUN apt-get install freeglut3 freeglut3-dev libxi-dev libxmu-dev -y
+
+
+#RUN add-apt-repository ppa:graphics-drivers/ppa -y
+RUN apt-get update -y
+RUN apt-get install python3-pip -y
+RUN python3.6 -m pip install --upgrade pip
+RUN python3.6 -m pip install tensorflow-gpu==1.13.1 
+RUN python3.6 -m pip install keras==2.2.4
+
+# install make
+RUN apt-get install make -y
+RUN apt-get install libproj-dev -y
+RUN apt-get install proj-bin -y
+RUN apt-get install libgeos++-dev -y
+RUN pip3.6 install GEOS
diff --git a/HPC_scripts/DataExtraction.sh b/HPC_scripts/DataExtraction.sh
new file mode 100755
index 0000000000000000000000000000000000000000..78b0e499a65a47d0c0057136bb97f6d3d16ec64d
--- /dev/null
+++ b/HPC_scripts/DataExtraction.sh
@@ -0,0 +1,24 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=13
+##SBATCH --ntasks-per-node=13
+#SBATCH --cpus-per-task=1
+#SBATCH --output=DataExtraction-out.%j
+#SBATCH --error=DataExtraction-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --partition=devel
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=s.stadtler@fz-juelich.de
+##jutil env activate -p deepacf
+
+module --force purge
+module use $OTHERSTAGES
+module load Stages/2019a
+module load Intel/2019.3.199-GCC-8.3.0  ParaStationMPI/5.2.2-1
+module load h5py/2.9.0-Python-3.6.8
+module load mpi4py/3.0.1-Python-3.6.8
+
+#module load mpi4py/3.0.1-Python-3.6.8
+
+srun python ../../workflow_parallel_frame_prediction/DataExtraction/mpi_stager_v2.py --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/2017/ --destination_dir /p/scratch/deepacf/scarlet/extractedData
diff --git a/HPC_scripts/DataPreprocess.sh b/HPC_scripts/DataPreprocess.sh
new file mode 100755
index 0000000000000000000000000000000000000000..48ed0581802fe5f629019e729425a7ca1445af4f
--- /dev/null
+++ b/HPC_scripts/DataPreprocess.sh
@@ -0,0 +1,43 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=12
+##SBATCH --ntasks-per-node=12
+#SBATCH --cpus-per-task=1
+#SBATCH --output=DataPreprocess-out.%j
+#SBATCH --error=DataPreprocess-err.%j
+#SBATCH --time=02:20:00
+#SBATCH --partition=batch
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+
+module --force purge
+module use  $OTHERSTAGES
+module load Stages/2019a
+module load Intel/2019.3.199-GCC-8.3.0  ParaStationMPI/5.2.2-1
+module load h5py/2.9.0-Python-3.6.8
+module load mpi4py/3.0.1-Python-3.6.8
+
+
+srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+ --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2015/ \
+ --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2015/ \
+ --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+
+srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+ --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2016/ \
+ --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2016/ \
+ --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+
+srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+ --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017/ \
+ --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2017/ \
+ --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+
+
+
+
+#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017 \
+# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016toY2017M01to12-128x160-74d00N71d0E-T_MSL_gph500/2017 \
+# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
diff --git a/HPC_scripts/DataPreprocess_dev.sh b/HPC_scripts/DataPreprocess_dev.sh
new file mode 100755
index 0000000000000000000000000000000000000000..b5aa2010cbe2b5b9f87b5b65bade29db974bcc8d
--- /dev/null
+++ b/HPC_scripts/DataPreprocess_dev.sh
@@ -0,0 +1,68 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=12
+##SBATCH --ntasks-per-node=12
+#SBATCH --cpus-per-task=1
+#SBATCH --output=DataPreprocess-out.%j
+#SBATCH --error=DataPreprocess-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --partition=devel
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=m.langguth@fz-juelich.de
+
+module --force purge
+module use  $OTHERSTAGES
+module load Stages/2019a
+module load Intel/2019.3.199-GCC-8.3.0  ParaStationMPI/5.2.2-1
+module load h5py/2.9.0-Python-3.6.8
+module load mpi4py/3.0.1-Python-3.6.8
+
+source_dir=/p/scratch/deepacf/video_prediction_shared_folder/extractedData
+destination_dir=/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle
+declare -a years=("2015" 
+                 "2016" 
+                 "2017"
+                  )
+
+
+
+for year in "${years[@]}"; 
+    do 
+        echo "Year $year"
+	echo "source_dir ${source_dir}/${year}"
+	srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+         --source_dir ${source_dir}/${year}/ \
+         --destination_dir ${destination_dir}/${year}/ --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+    done
+
+
+srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir}
+
+
+
+
+#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2015/ \
+# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2015/ \
+# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+
+#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2016/ \
+# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2016/ \
+# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+
+#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017/ \
+# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2017/ \
+# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+
+#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py \
+#--destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-#T_MSL_gph500/
+
+
+
+#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017 \
+# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016toY2017M01to12-128x160-74d00N71d0E-T_MSL_gph500/2017 \
+# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
diff --git a/HPC_scripts/DataPreprocess_to_tf.sh b/HPC_scripts/DataPreprocess_to_tf.sh
new file mode 100755
index 0000000000000000000000000000000000000000..6f541b9d31f582dfd8b9318f7980930716c6c09b
--- /dev/null
+++ b/HPC_scripts/DataPreprocess_to_tf.sh
@@ -0,0 +1,22 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=12
+##SBATCH --ntasks-per-node=12
+#SBATCH --cpus-per-task=1
+#SBATCH --output=DataPreprocess_to_tf-out.%j
+#SBATCH --error=DataPreprocess_to_tf-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --partition=devel
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+
+module purge
+module use $OTHERSTAGES
+module load Stages/2019a
+module load Intel/2019.3.199-GCC-8.3.0  ParaStationMPI/5.2.2-1
+module load h5py/2.9.0-Python-3.6.8
+module load mpi4py/3.0.1-Python-3.6.8
+module load TensorFlow/1.13.1-GPU-Python-3.6.8
+
+srun python ../video_prediction/datasets/era5_dataset_v2.py /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016M01to12-128_160-74.00N710E-T_T_T/splits/ /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016M01to12-128_160-74.00N710E-T_T_T/tfrecords/ -vars T2 T2 T2 
diff --git a/HPC_scripts/generate_era5.sh b/HPC_scripts/generate_era5.sh
new file mode 100755
index 0000000000000000000000000000000000000000..c6121ebfc4e441d587cd97fabb366c106324a8be
--- /dev/null
+++ b/HPC_scripts/generate_era5.sh
@@ -0,0 +1,31 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+##SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --output=generate_era5-out.%j
+#SBATCH --error=generate_era5-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --gres=gpu:1
+#SBATCH --partition=develgpus
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+##jutil env activate -p cjjsc42
+
+
+module purge
+module load GCC/8.3.0
+module load ParaStationMPI/5.2.2-1
+module load TensorFlow/1.13.1-GPU-Python-3.6.8
+module load netcdf4-python/1.5.0.1-Python-3.6.8
+module load h5py/2.9.0-Python-3.6.8
+
+
+python -u ../scripts/generate_transfer_learning_finetune.py \
+--input_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/tfrecords/  \
+--dataset_hparams sequence_length=20 --checkpoint  /p/scratch/deepacf/video_prediction_shared_folder/models/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/ours_gan \
+--mode test --results_dir /p/scratch/deepacf/video_prediction_shared_folder/results/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T \
+--batch_size 4 --dataset era5   > generate_era5-out.out
+
+#srun  python scripts/train.py --input_dir data/era5 --dataset era5  --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/HPC_scripts/train_era5.sh b/HPC_scripts/train_era5.sh
new file mode 100755
index 0000000000000000000000000000000000000000..ef060b0d985aa0141a1a1cdb974bf04f37b7204b
--- /dev/null
+++ b/HPC_scripts/train_era5.sh
@@ -0,0 +1,27 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+##SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --output=train_era5-out.%j
+#SBATCH --error=train_era5-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --gres=gpu:1
+#SBATCH --partition=develgpus
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+##jutil env activate -p cjjsc42
+
+module --force purge 
+module use $OTHERSTAGES
+module load Stages/2019a
+module load GCCcore/.8.3.0
+module load mpi4py/3.0.1-Python-3.6.8
+module load h5py/2.9.0-serial-Python-3.6.8
+module load TensorFlow/1.13.1-GPU-Python-3.6.8
+module load cuDNN/7.5.1.10-CUDA-10.1.105
+
+
+srun python ../scripts/train_v2.py --input_dir  /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/2017M01to12-64_64-50.00N11.50E-T_T_T/tfrecords  --dataset era5  --model savp --model_hparams_dict ../hparams/kth/ours_savp/model_hparams.json --output_dir /p/scratch/deepacf/video_prediction_shared_folder/models/2017M01to12-64_64-50.00N11.50E-T_T_T/ours_savp
+#srun  python scripts/train.py --input_dir data/era5 --dataset era5  --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..a2f3427ebe81f7da34fdd2d27e54462083db8fa5
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 Alex X. Lee
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..28ea5b77e082085e1904ecc5a38d1e6730fa7cfe
--- /dev/null
+++ b/README.md
@@ -0,0 +1,121 @@
+# Video Prediction by GAN
+
+This project aims to adopt the GAN-based architectures, which were originally proposed in [[Project Page]](https://alexlee-gk.github.io/video_prediction/) [[Paper]](https://arxiv.org/abs/1804.01523), to predict temperature based on ERA5 data.
+
+## Getting Started
+### Prerequisites
+- Linux or macOS
+- Python 3
+- CPU or NVIDIA GPU + CUDA CuDNN
+
+### Installation 
+This project needs to be used together with the [Workflow_parallel_frame_prediction project](https://gitlab.version.fz-juelich.de/gong1/workflow_parallel_frame_prediction)
+- Clone this repo:
+```bash
+git clone https://gitlab.version.fz-juelich.de/gong1/video_prediction_savp.git
+git clone https://gitlab.version.fz-juelich.de/gong1/workflow_parallel_frame_prediction.git
+```
+
+### Set-up env on JUWELS
+
+- Set up env and install packages
+
+```bash
+cd video_prediction_savp
+source env_setup/create_env.sh <dir_name> <env_name>
+```
+
+## Workflow by steps
+
+### Data Extraction
+
+```bash
+python3 ../workflow_video_prediction/DataExtraction/mpi_stager_v2.py  --source_dir <input_dir1> --destination_dir <output_dir1>
+```
+
+e.g. 
+```bash
+python3 ../workflow_video_prediction/DataExtraction/mpi_stager_v2.py  --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/2017/ --destination_dir /p/scratch/deepacf/bing/extractedData
+```
+
+### Data Preprocessing
+```bash
+python3 ../workflow_video_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py --source_dir <output_dir1> --destination_dir <output_dir2> 
+
+python3 video_prediction/datasets/era5_dataset_v2.py  --source_dir   <output_dir2> --destination_dir <.data/exp_name>
+```
+
+Example
+```bash
+python3 ../workflow_video_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py --source_dir /p/scratch/deepacf/bing/extractedData --destination_dir /p/scratch/deepacf/bing/preprocessedData
+
+python3 video_prediction/datasets/era5_dataset_v2.py /p/scratch/deepacf/bing/preprocessedData  ./data/era5_64_64_3_3t_norm
+ ```
+ 
+### Training
+
+```bash
+python3 scripts/train_v2.py --input_dir <./data/exp_name> --dataset era5  --model <savp> --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir <./logs/{exp_name}/{mode}/>
+```
+
+Example
+```bash
+python3 scripts/train_v2.py --input_dir ./data/era5_size_64_64_3_3t_norm --dataset era5  --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5_64_64_3_3t_norm/end_to_end
+```
+### Postprocessing
+
+Generate prediction frames, evaluate the model, and visualize the results.
+You can use a model you trained yourself in the training step, or you can copy Bing's trained model.
+
+```bash
+python3 scripts/generate_transfer_learning_finetune.py --input_dir <./data/exp_name>  --dataset_hparams sequence_length=20 --checkpoint <./logs/{exp_name}/{mode}/{model}> --mode test --results_dir <./results/{exp_name}/{mode}>  --batch_size <batch_size> --dataset era5
+```
+
+- Example: use the end-to-end trained model from Bing for exp_name era5_size_64_64_3_3t_norm
+```bash
+python3 scripts/generate_transfer_learning_finetune.py --input_dir data/era5_size_64_64_3_3t_norm --dataset_hparams sequence_length=20 --checkpoint /p/project/deepacf/deeprain/bing/video_prediction_savp/logs/era5_size_64_64_3_3t_norm/end_to_end/ours_savp --mode test --results_dir results_test_samples/era5_size_64_64_3_3t_norm/end_to_end  --batch_size 4 --dataset era5
+```
+
+![Ground Truth](/results_test_samples/era5_size_64_64_3_norm_dup/ours_savp/Sample_Batch_id_0_Sample_1.mp4)
+# End-to-End run the entire workflow
+
+```bash
+./bash/workflow_era5.sh <model>  <train_mode>  <exp_name>
+```
+
+example:
+```bash
+./bash/workflow_era5.sh savp end_to_end  era5_size_64_64_3_3t_norm
+```
+
+
+
+### Recommendation for output folder structure and name convention
+The details can be found in [name_convention](docs/structure_name_convention.md)
+
+```
+├── ExtractedData
+│   ├── [Year]
+│   │   ├── [Month]
+│   │   │   ├── **/*.netCDF
+├── PreprocessedData
+│   ├── [Data_name_convention]
+│   │   ├── hickle
+│   │   │   ├── train
+│   │   │   ├── val
+│   │   │   ├── test
+│   │   ├── tfrecords
+│   │   │   ├── train
+│   │   │   ├── val
+│   │   │   ├── test
+├── Models
+│   ├── [Data_name_convention]
+│   │   ├── [model_name]
+│   │   ├── [model_name]
+├── Results
+│   ├── [Data_name_convention]
+│   │   ├── [training_mode]
+│   │   │   ├── [source_data_name_convention]
+│   │   │   │   ├── [model_name]
+
+```
\ No newline at end of file
diff --git a/Zam347_scripts/DataExtraction.sh b/Zam347_scripts/DataExtraction.sh
new file mode 100755
index 0000000000000000000000000000000000000000..6953b7d8484b0eba9d8928b86b1ffbe9d396e8f0
--- /dev/null
+++ b/Zam347_scripts/DataExtraction.sh
@@ -0,0 +1,4 @@
+#!/bin/bash -x
+
+
+mpirun -np 4 python ../../workflow_parallel_frame_prediction/DataExtraction/mpi_stager_v2.py --source_dir /home/b.gong/data_era5/2017/ --destination_dir /home/${USER}/extractedData/2017
diff --git a/Zam347_scripts/DataPreprocess.sh b/Zam347_scripts/DataPreprocess.sh
new file mode 100755
index 0000000000000000000000000000000000000000..b9941b0e703346f31fd62339882a07ccc20454da
--- /dev/null
+++ b/Zam347_scripts/DataPreprocess.sh
@@ -0,0 +1,22 @@
+#!/bin/bash -x
+
+
+source_dir=/home/$USER/extractedData
+destination_dir=/home/$USER/preprocessedData/era5-Y2017M01to02
+script_dir=`pwd`
+
+declare -a years=("2017")
+
+for year in "${years[@]}";
+    do
+        echo "Year $year"
+        echo "source_dir ${source_dir}/${year}"
+        mpirun -np 2 python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+         --source_dir ${source_dir} -scr_dir ${script_dir} \
+         --destination_dir ${destination_dir} --years ${years} --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+    done
+python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir} --varnames T2 MSL gph500
+
+
+
+
diff --git a/Zam347_scripts/DataPreprocess_to_tf.sh b/Zam347_scripts/DataPreprocess_to_tf.sh
new file mode 100755
index 0000000000000000000000000000000000000000..d84a41b72dd6e768a6a5b6419aa008631efee70f
--- /dev/null
+++ b/Zam347_scripts/DataPreprocess_to_tf.sh
@@ -0,0 +1,8 @@
+#!/bin/bash -x
+
+# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
+source_dir=/home/${USER}/preprocessedData/
+destination_dir=/home/${USER}/preprocessedData/
+
+
+python ../video_prediction/datasets/era5_dataset_v2.py ${source_dir}/splits/hickle ${destination_dir}/tf_records -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20 
diff --git a/Zam347_scripts/generate_era5.sh b/Zam347_scripts/generate_era5.sh
new file mode 100755
index 0000000000000000000000000000000000000000..d9d710e5c4f3cc2d2825bf67bf2b668f6f9ddbd8
--- /dev/null
+++ b/Zam347_scripts/generate_era5.sh
@@ -0,0 +1,18 @@
+#!/bin/bash -x
+
+# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
+source_dir=/home/${USER}/preprocessedData/
+checkpoint_dir=/home/${USER}/models/
+results_dir=/home/${USER}/results/
+
+# for choosing the model
+model=mcnet
+
+# execute respective Python-script
+python -u ../scripts/generate_transfer_learning_finetune.py \
+--input_dir ${source_dir}/tfrecords  \
+--dataset_hparams sequence_length=20 --checkpoint  ${checkpoint_dir}/${model} \
+--mode test --results_dir ${results_dir} \
+--batch_size 2 --dataset era5   > generate_era5-out.out
+
+#srun  python scripts/train.py --input_dir data/era5 --dataset era5  --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/Zam347_scripts/train_era5.sh b/Zam347_scripts/train_era5.sh
new file mode 100755
index 0000000000000000000000000000000000000000..aadb25997e2715ac719457c969a6f54982ec93a6
--- /dev/null
+++ b/Zam347_scripts/train_era5.sh
@@ -0,0 +1,13 @@
+#!/bin/bash -x
+
+# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
+source_dir=/home/${USER}/preprocessedData/
+destination_dir=/home/${USER}/models/
+
+# for choosing the model
+model=mcnet
+model_hparams=../hparams/era5/model_hparams.json
+
+# execute respective Python-script
+python ../scripts/train_dummy.py --input_dir  ${source_dir}/tfrecords/ --dataset era5  --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}/
+#srun  python scripts/train.py --input_dir data/era5 --dataset era5  --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/bash/download_and_preprocess_dataset.sh b/bash/download_and_preprocess_dataset.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5779c2b7ff79f84ccb52e1e44cb7c0cd0d4ee154
--- /dev/null
+++ b/bash/download_and_preprocess_dataset.sh
@@ -0,0 +1,88 @@
+#!/usr/bin/env bash
+
+# exit if any command fails
+set -e
+
+if [ "$#" -eq 2 ]; then
+  if [ $1 = "bair" ]; then
+    echo "IMAGE_SIZE argument is only applicable to kth dataset" >&2
+    exit 1
+  fi
+elif [ "$#" -ne 1 ]; then
+  echo "Usage: $0 DATASET_NAME [IMAGE_SIZE]" >&2
+  exit 1
+fi
+if [ $1 = "bair" ]; then
+  TARGET_DIR=./data/bair
+  mkdir -p ${TARGET_DIR}
+  TAR_FNAME=bair_robot_pushing_dataset_v0.tar
+  URL=http://rail.eecs.berkeley.edu/datasets/${TAR_FNAME}
+  echo "Downloading '$1' dataset (this takes a while)"
+  #wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME}  # Bing: on macOS, use curl instead of wget
+  curl ${URL} -o ${TARGET_DIR}/${TAR_FNAME}
+  tar -xvf ${TARGET_DIR}/${TAR_FNAME} --strip-components=1 -C ${TARGET_DIR}
+  rm ${TARGET_DIR}/${TAR_FNAME}
+  mkdir -p ${TARGET_DIR}/val
+  # reserve a fraction of the training set for validation
+  mv ${TARGET_DIR}/train/traj_256_to_511.tfrecords ${TARGET_DIR}/val/
+elif [ $1 = "kth" ]; then
+  if [ "$#" -eq 2 ]; then
+    IMAGE_SIZE=$2
+    TARGET_DIR=./data/kth_${IMAGE_SIZE}
+  else
+    IMAGE_SIZE=64
+    TARGET_DIR=./data/kth
+  fi
+  echo ${TARGET_DIR} ${IMAGE_SIZE}
+  mkdir -p ${TARGET_DIR}
+  mkdir -p ${TARGET_DIR}/raw
+  echo "Downloading '$1' dataset (this takes a while)"
+  # TODO Bing: to save time, just use walking; change back if all the data are needed
+  #for ACTION in walking jogging running boxing handwaving handclapping; do
+#  for ACTION in walking; do
+#    echo "Action: '$ACTION' "
+#    ZIP_FNAME=${ACTION}.zip
+#    URL=http://www.nada.kth.se/cvap/actions/${ZIP_FNAME}
+#   # wget ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME}
+#    echo "Start downloading action '$ACTION' ULR '$URL' "
+#    curl ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME}
+#    unzip ${TARGET_DIR}/raw/${ZIP_FNAME} -d ${TARGET_DIR}/raw/${ACTION}
+#    echo "Action '$ACTION' data download and unzip "
+#  done
+  FRAME_RATE=25
+#  mkdir -p ${TARGET_DIR}/processed
+#  # download files with metadata specifying the subsequences
+#  TAR_FNAME=kth_meta.tar.gz
+#  URL=http://rail.eecs.berkeley.edu/models/savp/data/${TAR_FNAME}
+#  echo "Downloading '${TAR_FNAME}' ULR '$URL' "
+#  #wget ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME}
+#  curl ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME}
+#  tar -xzvf ${TARGET_DIR}/processed/${TAR_FNAME} --strip 1 -C ${TARGET_DIR}/processed
+  # convert the videos into sequence of downscaled images
+  echo "Processing '$1' dataset"
+  #TODO Bing, just use walking for test
+  #for ACTION in walking jogging running boxing handwaving handclapping; do
+  #Todo Bing: remove the comments below after testing
+  for ACTION in walking running; do
+    for VIDEO_FNAME in ${TARGET_DIR}/raw/${ACTION}/*.avi; do
+      FNAME=$(basename ${VIDEO_FNAME})
+      FNAME=${FNAME%_uncomp.avi}
+      echo "FNAME '$FNAME' "
+      # sometimes the directory is not created, so try until it is
+      while [ ! -d "${TARGET_DIR}/processed/${ACTION}/${FNAME}" ]; do
+        mkdir -p ${TARGET_DIR}/processed/${ACTION}/${FNAME}
+      done
+      ffmpeg -i ${VIDEO_FNAME} -r ${FRAME_RATE} -f image2 -s ${IMAGE_SIZE}x${IMAGE_SIZE} \
+      ${TARGET_DIR}/processed/${ACTION}/${FNAME}/image-%03d_${IMAGE_SIZE}x${IMAGE_SIZE}.png
+    done
+  done
+  python video_prediction/datasets/kth_dataset.py ${TARGET_DIR}/processed ${TARGET_DIR} ${IMAGE_SIZE}
+  rm -rf ${TARGET_DIR}/raw
+  rm -rf ${TARGET_DIR}/processed
+else
+  echo "Invalid dataset name: '$1' (choose from 'bair', 'kth')" >&2
+  exit 1
+fi
+echo "Succesfully finished downloadi\
+
+ng and preprocessing dataset '$1'"
diff --git a/bash/download_and_preprocess_dataset_era5.sh b/bash/download_and_preprocess_dataset_era5.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eacc01801b5e323ea8da8d7adc97c8156172fd7b
--- /dev/null
+++ b/bash/download_and_preprocess_dataset_era5.sh
@@ -0,0 +1,127 @@
+#!/usr/bin/env bash
+
+# exit if any command fails
+set -e
+
+
+#if [ "$#" -eq 2 ]; then
+#  if [ $1 = "bair" ]; then
+#    echo "IMAGE_SIZE argument is only applicable to kth dataset" >&2
+#    exit 1
+#  fi
+#elif [ "$#" -ne 1 ]; then
+#  echo "Usage: $0 DATASET_NAME [IMAGE_SIZE]" >&2
+#  exit 1
+#fi
+#if [ $1 = "bair" ]; then
+#  TARGET_DIR=./data/bair
+#  mkdir -p ${TARGET_DIR}
+#  TAR_FNAME=bair_robot_pushing_dataset_v0.tar
+#  URL=http://rail.eecs.berkeley.edu/datasets/${TAR_FNAME}
+#  echo "Downloading '$1' dataset (this takes a while)"
+#  #wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME} Bing: on MacOS system , use curl instead of wget
+#  curl ${URL} -O ${TARGET_DIR}/${TAR_FNAME}
+#  tar -xvf ${TARGET_DIR}/${TAR_FNAME} --strip-components=1 -C ${TARGET_DIR}
+#  rm ${TARGET_DIR}/${TAR_FNAME}
+#  mkdir -p ${TARGET_DIR}/val
+#  # reserve a fraction of the training set for validation
+#  mv ${TARGET_DIR}/train/traj_256_to_511.tfrecords ${TARGET_DIR}/val/
+#elif [ $1 = "kth" ]; then
+#  if [ "$#" -eq 2 ]; then
+#    IMAGE_SIZE=$2
+#    TARGET_DIR=./data/kth_${IMAGE_SIZE}
+#  else
+#    IMAGE_SIZE=64
+#  fi
+#  echo ${TARGET_DIR} ${IMAGE_SIZE}
+#  mkdir -p ${TARGET_DIR}
+#  mkdir -p ${TARGET_DIR}/raw
+#  echo "Downloading '$1' dataset (this takes a while)"
+  # TODO Bing: for save time just use walking, need to change back if all the data are needed
+  #for ACTION in walking jogging running boxing handwaving handclapping; do
+#  for ACTION in walking; do
+#    echo "Action: '$ACTION' "
+#    ZIP_FNAME=${ACTION}.zip
+#    URL=http://www.nada.kth.se/cvap/actions/${ZIP_FNAME}
+#   # wget ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME}
+#    echo "Start downloading action '$ACTION' ULR '$URL' "
+#    curl ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME}
+#    unzip ${TARGET_DIR}/raw/${ZIP_FNAME} -d ${TARGET_DIR}/raw/${ACTION}
+#    echo "Action '$ACTION' data download and unzip "
+#  done
+#  FRAME_RATE=25
+#  mkdir -p ${TARGET_DIR}/processed
+#  # download files with metadata specifying the subsequences
+#  TAR_FNAME=kth_meta.tar.gz
+#  URL=http://rail.eecs.berkeley.edu/models/savp/data/${TAR_FNAME}
+#  echo "Downloading '${TAR_FNAME}' ULR '$URL' "
+#  #wget ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME}
+#  curl ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME}
+#  tar -xzvf ${TARGET_DIR}/processed/${TAR_FNAME} --strip 1 -C ${TARGET_DIR}/processed
+  # convert the videos into sequence of downscaled images
+#  echo "Processing '$1' dataset"
+#  #TODO Bing, just use walking for test
+#  #for ACTION in walking jogging running boxing handwaving handclapping; do
+#  #Todo Bing: remove the comments below after testing
+#  for ACTION in walking; do
+#    for VIDEO_FNAME in ${TARGET_DIR}/raw/${ACTION}/*.avi; do
+#      FNAME=$(basename ${VIDEO_FNAME})
+#      FNAME=${FNAME%_uncomp.avi}
+#      echo "FNAME '$FNAME' "
+#      # sometimes the directory is not created, so try until it is
+#      while [ ! -d "${TARGET_DIR}/processed/${ACTION}/${FNAME}" ]; do
+#        mkdir -p ${TARGET_DIR}/processed/${ACTION}/${FNAME}
+#      done
+#      ffmpeg -i ${VIDEO_FNAME} -r ${FRAME_RATE} -f image2 -s ${IMAGE_SIZE}x${IMAGE_SIZE} \
+#      ${TARGET_DIR}/processed/${ACTION}/${FNAME}/image-%03d_${IMAGE_SIZE}x${IMAGE_SIZE}.png
+#    done
+#  done
+#  python video_prediction/datasets/kth_dataset.py ${TARGET_DIR}/processed ${TARGET_DIR} ${IMAGE_SIZE}
+#  rm -rf ${TARGET_DIR}/raw
+#  rm -rf ${TARGET_DIR}/processed
+
+while [[ $# -gt 0 ]]  # while the number of passed arguments is greater than 0
+do
+key="$1"
+case $key in
+    -d|--data)
+    DATA="$2"
+    shift
+    shift
+    ;;
+    -i|--input_dir)
+    INPUT_DIR="$2"
+    shift
+    shift
+    ;;
+    -o|--output_dir)
+    OUTPUT_DIR="$2"
+    shift
+    shift
+    ;;
+esac
+done
+
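+# Example invocation (sketch; the directories below are placeholders, not actual defaults):
+#   ./bash/download_and_preprocess_dataset_era5.sh -d era5 -i /path/to/preprocessedData -o ./data/era5
+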
+echo "DATA  = ${DATA} "
+
+echo "OUTPUT_DIRECTORY = ${OUTPUT_DIR}"
+
+if [ -d $INPUT_DIR ]; then
+    echo "INPUT DIRECTORY = ${INPUT_DIR}"
+
+else
+    echo "INPUT DIRECTORY '$INPUT_DIR' DOES NOT EXIST"
+    exit 1
+fi
+
+
+if [ $DATA = "era5" ]; then
+
+  mkdir -p ${OUTPUT_DIR}
+  python video_prediction/datasets/era5_dataset.py $INPUT_DIR  ${OUTPUT_DIR}
+else
+  echo "dataset name: '$DATA' (choose from 'era5')" >&2
+  exit 1
+fi
+
+echo "Succesfully finished downloading and preprocessing dataset '$DATA' "
\ No newline at end of file
diff --git a/bash/download_and_preprocess_dataset_v1.sh b/bash/download_and_preprocess_dataset_v1.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3541b4a538c089cd79ea2a39c6df0804e11cb0a6
--- /dev/null
+++ b/bash/download_and_preprocess_dataset_v1.sh
@@ -0,0 +1,86 @@
+#!/usr/bin/env bash
+
+# exit if any command fails
+set -e
+
+if [ "$#" -eq 2 ]; then
+  if [ $1 = "bair" ]; then
+    echo "IMAGE_SIZE argument is only applicable to kth dataset" >&2
+    exit 1
+  fi
+elif [ "$#" -ne 1 ]; then
+  echo "Usage: $0 DATASET_NAME [IMAGE_SIZE]" >&2
+  exit 1
+fi
+if [ $1 = "bair" ]; then
+  TARGET_DIR=./data/bair
+  mkdir -p ${TARGET_DIR}
+  TAR_FNAME=bair_robot_pushing_dataset_v0.tar
+  URL=http://rail.eecs.berkeley.edu/datasets/${TAR_FNAME}
+  echo "Downloading '$1' dataset (this takes a while)"
+  #wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME}  # Bing: on macOS, use curl instead of wget
+  curl ${URL} -o ${TARGET_DIR}/${TAR_FNAME}
+  tar -xvf ${TARGET_DIR}/${TAR_FNAME} --strip-components=1 -C ${TARGET_DIR}
+  rm ${TARGET_DIR}/${TAR_FNAME}
+  mkdir -p ${TARGET_DIR}/val
+  # reserve a fraction of the training set for validation
+  mv ${TARGET_DIR}/train/traj_256_to_511.tfrecords ${TARGET_DIR}/val/
+elif [ $1 = "kth" ]; then
+  if [ "$#" -eq 2 ]; then
+    IMAGE_SIZE=$2
+    TARGET_DIR=./data/kth_${IMAGE_SIZE}
+  else
+    IMAGE_SIZE=64
+    TARGET_DIR=./data/kth
+  fi
+  echo ${TARGET_DIR} ${IMAGE_SIZE}
+  mkdir -p ${TARGET_DIR}
+  mkdir -p ${TARGET_DIR}/raw
+  echo "Downloading '$1' dataset (this takes a while)"
+  # TODO Bing: to save time, just use walking; change back if all the data are needed
+  #for ACTION in walking jogging running boxing handwaving handclapping; do
+#  for ACTION in walking; do
+#    echo "Action: '$ACTION' "
+#    ZIP_FNAME=${ACTION}.zip
+#    URL=http://www.nada.kth.se/cvap/actions/${ZIP_FNAME}
+#   # wget ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME}
+#    echo "Start downloading action '$ACTION' ULR '$URL' "
+#    curl ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME}
+#    unzip ${TARGET_DIR}/raw/${ZIP_FNAME} -d ${TARGET_DIR}/raw/${ACTION}
+#    echo "Action '$ACTION' data download and unzip "
+#  done
+  FRAME_RATE=25
+#  mkdir -p ${TARGET_DIR}/processed
+#  # download files with metadata specifying the subsequences
+#  TAR_FNAME=kth_meta.tar.gz
+#  URL=http://rail.eecs.berkeley.edu/models/savp/data/${TAR_FNAME}
+#  echo "Downloading '${TAR_FNAME}' ULR '$URL' "
+#  #wget ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME}
+#  curl ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME}
+#  tar -xzvf ${TARGET_DIR}/processed/${TAR_FNAME} --strip 1 -C ${TARGET_DIR}/processed
+  # convert the videos into sequence of downscaled images
+  echo "Processing '$1' dataset"
+  #TODO Bing, just use walking for test
+  #for ACTION in walking jogging running boxing handwaving handclapping; do
+  #Todo Bing: remove the comments below after testing
+  for ACTION in walking; do
+    for VIDEO_FNAME in ${TARGET_DIR}/raw/${ACTION}/*.avi; do
+      FNAME=$(basename ${VIDEO_FNAME})
+      FNAME=${FNAME%_uncomp.avi}
+      echo "FNAME '$FNAME' "
+      # sometimes the directory is not created, so try until it is
+      while [ ! -d "${TARGET_DIR}/processed/${ACTION}/${FNAME}" ]; do
+        mkdir -p ${TARGET_DIR}/processed/${ACTION}/${FNAME}
+      done
+      ffmpeg -i ${VIDEO_FNAME} -r ${FRAME_RATE} -f image2 -s ${IMAGE_SIZE}x${IMAGE_SIZE} \
+      ${TARGET_DIR}/processed/${ACTION}/${FNAME}/image-%03d_${IMAGE_SIZE}x${IMAGE_SIZE}.png
+    done
+  done
+  python video_prediction/datasets/kth_dataset.py ${TARGET_DIR}/processed ${TARGET_DIR} ${IMAGE_SIZE}
+  rm -rf ${TARGET_DIR}/raw
+  rm -rf ${TARGET_DIR}/processed
+else
+  echo "Invalid dataset name: '$1' (choose from 'bair', 'kth')" >&2
+  exit 1
+fi
+echo "Succesfully finished downloading and preprocessing dataset '$1'"
diff --git a/bash/workflow_era5.sh b/bash/workflow_era5.sh
new file mode 100755
index 0000000000000000000000000000000000000000..01d16bfdf7f38ffe00495ba31f85349d9ce68335
--- /dev/null
+++ b/bash/workflow_era5.sh
@@ -0,0 +1,92 @@
+#!/usr/bin/env bash
+set -e
+#
+#MODEL=savp
+##train_mode: end_to_end, pre_trained
+#TRAIN_MODE=end_to_end
+#EXP_NAME=era5_size_64_64_3_3t_norm
+
+MODEL=$1
+TRAIN_MODE=$2
+EXP_NAME=$3
+RETRAIN=1 # 1: use the existing end-to-end model, 0: (re-)train the model end-to-end
+DATA_ETL_DIR=/p/scratch/deepacf/${USER}/
+DATA_EXTRA_DIR=${DATA_ETL_DIR}/extractedData/${EXP_NAME}
+DATA_PREPROCESS_DIR=${DATA_ETL_DIR}/preprocessedData/${EXP_NAME}
+DATA_PREPROCESS_TF_DIR=./data/${EXP_NAME}
+RESULTS_OUTPUT_DIR=./results_test_samples/${EXP_NAME}/${TRAIN_MODE}/
+
+if [ "$MODEL" == "savp" ]; then
+    method_dir=ours_savp
+elif [ "$MODEL" == "gan" ]; then
+    method_dir=ours_gan
+elif [ "$MODEL" == "vae" ]; then
+    method_dir=ours_vae
+else
+    echo "model does not exist" >&2
+    exit 1
+fi
+
+if [ "$TRAIN_MODE" == pre_trained ]; then
+    TRAIN_OUTPUT_DIR=./pretrained_models/kth/${method_dir}
+else
+    TRAIN_OUTPUT_DIR=./logs/${EXP_NAME}/${TRAIN_MODE}
+fi
+
+CHECKPOINT_DIR=${TRAIN_OUTPUT_DIR}/${method_dir}
+
+echo "===========================WORKFLOW SETUP===================="
+echo "Model ${MODEL}"
+echo "TRAIN MODE ${TRAIN_MODE}"
+echo "Method_dir ${method_dir}"
+echo "DATA_ETL_DIR ${DATA_ETL_DIR}"
+echo "DATA_EXTRA_DIR ${DATA_EXTRA_DIR}"
+echo "DATA_PREPROCESS_DIR ${DATA_PREPROCESS_DIR}"
+echo "DATA_PREPROCESS_TF_DIR ${DATA_PREPROCESS_TF_DIR}"
+echo "TRAIN_OUTPUT_DIR ${TRAIN_OUTPUT_DIR}"
+echo "============================================================="
+
+##############Data Preprocessing################
+#To hkl data
+if [ -d "$DATA_PREPROCESS_DIR" ]; then
+    echo "The Preprocessed Data (.hkl ) exist"
+else
+    python ../workflow_video_prediction/DataPreprocess/benchmark/mpi_stager_v2_process_netCDF.py \
+    --input_dir ${DATA_EXTRA_DIR} --destination_dir ${DATA_PREPROCESS_DIR}
+fi
+
+#Change the .hkl data to .tfrecords files
+if [ -d "$DATA_PREPROCESS_TF_DIR" ]
+then
+    echo "Step2: The Preprocessed Data (tf.records) exist"
+else
+    echo "Step2: start, hkl. files to tf.records"
+    python ./video_prediction/datasets/era5_dataset_v2.py  --source_dir ${DATA_PREPROCESS_DIR}/splits \
+    --destination_dir ${DATA_PREPROCESS_TF_DIR}
+    echo "Step2: finish"
+fi
+
+#########Train##########################
+if [ "$TRAIN_MODE" == "pre_trained" ]; then
+    echo "step3: Using kth pre_trained model"
+elif [ "$TRAIN_MODE" == "end_to_end" ]; then
+    echo "step3: End-to-end training"
+    if [ "$RETRAIN" == 1 ]; then
+        echo "Using the existing end-to-end model"
+    else
+        echo "Step3: Training Starts "
+        python ./scripts/train_v2.py --input_dir $DATA_PREPROCESS_TF_DIR --dataset era5  \
+        --model ${MODEL} --model_hparams_dict hparams/kth/${method_dir}/model_hparams.json \
+        --output_dir ${TRAIN_OUTPUT_DIR} --checkpoint ${CHECKPOINT_DIR}
+        echo "Training ends "
+    fi
+else
+    echo "TRAIN_MODE is end_to_end or pre_trained"
+    exit 1
+fi
+
+#########Generate results#################
+echo "Step4: Postprocessing start"
+python ./scripts/generate_transfer_learning_finetune.py --input_dir ${DATA_PREPROCESS_TF_DIR} \
+--dataset_hparams sequence_length=20 --checkpoint ${CHECKPOINT_DIR} --mode test --results_dir ${RESULTS_OUTPUT_DIR} \
+--batch_size 4 --dataset era5
\ No newline at end of file
diff --git a/bash/workflow_era5_macOS.sh b/bash/workflow_era5_macOS.sh
new file mode 100755
index 0000000000000000000000000000000000000000..1a6ebef38df877b8ee20f628d4e375a20e7c8bd5
--- /dev/null
+++ b/bash/workflow_era5_macOS.sh
@@ -0,0 +1,93 @@
+#!/usr/bin/env bash
+set -e
+#
+#MODEL=savp
+##train_mode: end_to_end, pre_trained
+#TRAIN_MODE=end_to_end
+#EXP_NAME=era5_size_64_64_3_3t_norm
+
+MODEL=$1
+TRAIN_MODE=$2
+EXP_NAME=$3
+RETRAIN=1 # 1: use the existing end-to-end model, 0: (re-)train the model end-to-end
+DATA_ETL_DIR=/p/scratch/deepacf/${USER}/
+DATA_EXTRA_DIR=${DATA_ETL_DIR}/extractedData/${EXP_NAME}
+DATA_PREPROCESS_DIR=${DATA_ETL_DIR}/preprocessedData/${EXP_NAME}
+DATA_PREPROCESS_TF_DIR=./data/${EXP_NAME}
+RESULTS_OUTPUT_DIR=./results_test_samples/${EXP_NAME}/${TRAIN_MODE}/
+
+if [ "$MODEL" == "savp" ]; then
+    method_dir=ours_savp
+elif [ "$MODEL" == "gan" ]; then
+    method_dir=ours_gan
+elif [ "$MODEL" == "vae" ]; then
+    method_dir=ours_vae
+else
+    echo "model does not exist" >&2
+    exit 1
+fi
+
+if [ "$TRAIN_MODE" == pre_trained ]; then
+    TRAIN_OUTPUT_DIR=./pretrained_models/kth/${method_dir}
+else
+    TRAIN_OUTPUT_DIR=./logs/${EXP_NAME}/${TRAIN_MODE}
+fi
+
+CHECKPOINT_DIR=${TRAIN_OUTPUT_DIR}/${method_dir}
+
+echo "===========================WORKFLOW SETUP===================="
+echo "Model ${MODEL}"
+echo "TRAIN MODE ${TRAIN_MODE}"
+echo "Method_dir ${method_dir}"
+echo "DATA_ETL_DIR ${DATA_ETL_DIR}"
+echo "DATA_EXTRA_DIR ${DATA_EXTRA_DIR}"
+echo "DATA_PREPROCESS_DIR ${DATA_PREPROCESS_DIR}"
+echo "DATA_PREPROCESS_TF_DIR ${DATA_PREPROCESS_TF_DIR}"
+echo "TRAIN_OUTPUT_DIR ${TRAIN_OUTPUT_DIR}"
+echo "============================================================="
+
+##############Data Preprocessing################
+#To hkl data
+#if [ -d "$DATA_PREPROCESS_DIR" ]; then
+#    echo "The Preprocessed Data (.hkl ) exist"
+#else
+#    python ../workflow_video_prediction/DataPreprocess/benchmark/mpi_stager_v2_process_netCDF.py \
+#    --input_dir ${DATA_EXTRA_DIR} --destination_dir ${DATA_PREPROCESS_DIR}
+#fi
+
+####Change the .hkl data to .tfrecords files
+if [ -d "$DATA_PREPROCESS_TF_DIR" ]
+then
+    echo "Step2: The Preprocessed Data (tf.records) exist"
+else
+    echo "Step2: start, hkl. files to tf.records"
+    python ./video_prediction/datasets/era5_dataset_v2.py  --source_dir ${DATA_PREPROCESS_DIR}/splits \
+    --destination_dir ${DATA_PREPROCESS_TF_DIR}
+    echo "Step2: finish"
+fi
+
+#########Train##########################
+if [ "$TRAIN_MODE" == "pre_trained" ]; then
+    echo "step3: Using kth pre_trained model"
+elif [ "$TRAIN_MODE" == "end_to_end" ]; then
+    echo "step3: End-to-end training"
+    if [ "$RETRAIN" == 1 ]; then
+        echo "Using the existing end-to-end model"
+    else
+        echo "Training Starts "
+        python ./scripts/train_v2.py --input_dir $DATA_PREPROCESS_TF_DIR --dataset era5  \
+        --model ${MODEL} --model_hparams_dict hparams/kth/${method_dir}/model_hparams.json \
+        --output_dir ${TRAIN_OUTPUT_DIR} --checkpoint ${CHECKPOINT_DIR}
+        echo "Training ends "
+    fi
+else
+    echo "TRAIN_MODE is end_to_end or pre_trained"
+    exit 1
+fi
+
+#########Generate results#################
+echo "Step4: Postprocessing start"
+python ./scripts/generate_transfer_learning_finetune.py --input_dir ${DATA_PREPROCESS_TF_DIR} \
+--dataset_hparams sequence_length=20 --checkpoint ${CHECKPOINT_DIR} --mode test --results_dir ${RESULTS_OUTPUT_DIR} \
+--batch_size 4 --dataset era5
diff --git a/bash/workflow_era5_zam347.sh b/bash/workflow_era5_zam347.sh
new file mode 100755
index 0000000000000000000000000000000000000000..ffe7209b6099f4ad9f57b4e90247a7d7acaf009d
--- /dev/null
+++ b/bash/workflow_era5_zam347.sh
@@ -0,0 +1,93 @@
+#!/usr/bin/env bash
+set -e
+#
+#MODEL=savp
+##train_mode: end_to_end, pre_trained
+#TRAIN_MODE=end_to_end
+#EXP_NAME=era5_size_64_64_3_3t_norm
+
+MODEL=$1
+TRAIN_MODE=$2
+EXP_NAME=$3
+RETRAIN=1 # 1: use the existing end-to-end model, 0: (re-)train the model end-to-end
+DATA_ETL_DIR=/home/${USER}/
+DATA_EXTRA_DIR=${DATA_ETL_DIR}/extractedData/${EXP_NAME}
+DATA_PREPROCESS_DIR=${DATA_ETL_DIR}/preprocessedData/${EXP_NAME}
+DATA_PREPROCESS_TF_DIR=./data/${EXP_NAME}
+RESULTS_OUTPUT_DIR=./results_test_samples/${EXP_NAME}/${TRAIN_MODE}/
+
+if [ "$MODEL" == "savp" ]; then
+    method_dir=ours_savp
+elif [ "$MODEL" == "gan" ]; then
+    method_dir=ours_gan
+elif [ "$MODEL" == "vae" ]; then
+    method_dir=ours_vae
+else
+    echo "model does not exist" >&2
+    exit 1
+fi
+
+if [ "$TRAIN_MODE" == pre_trained ]; then
+    TRAIN_OUTPUT_DIR=./pretrained_models/kth/${method_dir}
+else
+    TRAIN_OUTPUT_DIR=./logs/${EXP_NAME}/${TRAIN_MODE}
+fi
+
+CHECKPOINT_DIR=${TRAIN_OUTPUT_DIR}/${method_dir}
+
+echo "===========================WORKFLOW SETUP===================="
+echo "Model ${MODEL}"
+echo "TRAIN MODE ${TRAIN_MODE}"
+echo "Method_dir ${method_dir}"
+echo "DATA_ETL_DIR ${DATA_ETL_DIR}"
+echo "DATA_EXTRA_DIR ${DATA_EXTRA_DIR}"
+echo "DATA_PREPROCESS_DIR ${DATA_PREPROCESS_DIR}"
+echo "DATA_PREPROCESS_TF_DIR ${DATA_PREPROCESS_TF_DIR}"
+echo "TRAIN_OUTPUT_DIR ${TRAIN_OUTPUT_DIR}"
+echo "============================================================="
+
+##############Data Preprocessing################
+#To hkl data
+#if [ -d "$DATA_PREPROCESS_DIR" ]; then
+#    echo "The Preprocessed Data (.hkl ) exist"
+#else
+#    python ../workflow_video_prediction/DataPreprocess/benchmark/mpi_stager_v2_process_netCDF.py \
+#    --input_dir ${DATA_EXTRA_DIR} --destination_dir ${DATA_PREPROCESS_DIR}
+#fi
+
+####Change the .hkl data to .tfrecords files
+if [ -d "$DATA_PREPROCESS_TF_DIR" ]
+then
+    echo "Step2: The Preprocessed Data (tf.records) exist"
+else
+    echo "Step2: start, hkl. files to tf.records"
+    python ./video_prediction/datasets/era5_dataset_v2.py  --source_dir ${DATA_PREPROCESS_DIR}/splits \
+    --destination_dir ${DATA_PREPROCESS_TF_DIR}
+    echo "Step2: finish"
+fi
+
+#########Train##########################
+if [ "$TRAIN_MODE" == "pre_trained" ]; then
+    echo "step3: Using kth pre_trained model"
+elif [ "$TRAIN_MODE" == "end_to_end" ]; then
+    echo "step3: End-to-end training"
+    if [ "$RETRAIN" == 1 ]; then
+        echo "Using the existing end-to-end model"
+    else
+        echo "Training Starts "
+        python ./scripts/train_v2.py --input_dir $DATA_PREPROCESS_TF_DIR --dataset era5  \
+        --model ${MODEL} --model_hparams_dict hparams/kth/${method_dir}/model_hparams.json \
+        --output_dir ${TRAIN_OUTPUT_DIR} --checkpoint ${CHECKPOINT_DIR}
+        echo "Training ends "
+    fi
+else
+    echo "TRAIN_MODE is end_to_end or pre_trained"
+    exit 1
+fi
+
+#########Generate results#################
+echo "Step4: Postprocessing start"
+python ./scripts/generate_transfer_learning_finetune.py --input_dir ${DATA_PREPROCESS_TF_DIR} \
+--dataset_hparams sequence_length=20 --checkpoint ${CHECKPOINT_DIR} --mode test --results_dir ${RESULTS_OUTPUT_DIR} \
+--batch_size 4 --dataset era5
diff --git a/deleteme.txt b/deleteme.txt
deleted file mode 100644
index e61ef7b965e17c62ca23b6ff5f0aaf09586e10e9..0000000000000000000000000000000000000000
--- a/deleteme.txt
+++ /dev/null
@@ -1 +0,0 @@
-aa
diff --git a/docs/discussion/discussion.md b/docs/discussion/discussion.md
new file mode 100644
index 0000000000000000000000000000000000000000..ff1d00c4c064c72c13029e84fbd0c18c3e4ff59b
--- /dev/null
+++ b/docs/discussion/discussion.md
@@ -0,0 +1,5 @@
+This is the list of the latest meeting-minutes files for the VP group
+
+## 2020-03-01 - 2020-03-31
+
+- https://docs.google.com/document/d/1cQUEWrenIlW1zebZwSSHpfka2Bhb8u63kPM3x7nya_o/edit#heading=h.yjmq51s4fxnm
\ No newline at end of file
diff --git a/docs/presentation/presentation.md b/docs/presentation/presentation.md
new file mode 100644
index 0000000000000000000000000000000000000000..d49239089d5d881ef7a42e5e847ee45c3be725d4
--- /dev/null
+++ b/docs/presentation/presentation.md
@@ -0,0 +1,5 @@
+These are the presentation materials for the VP group
+
+
+## 2020-03-01 - 2020-03-31
+https://docs.google.com/presentation/d/18EJKBJJ2LHI7uNU_l8s_Cm-aGZhw9tkoQ8BxqYZfkWk/edit#slide=id.g71f805bc32_0_80
diff --git a/docs/structure_name_convention.md b/docs/structure_name_convention.md
new file mode 100644
index 0000000000000000000000000000000000000000..4a2679c83ea8d99b9562ef775ed2ac1190f5d7fb
--- /dev/null
+++ b/docs/structure_name_convention.md
@@ -0,0 +1,108 @@
+This document describes the output folder structure and the name convention
+
+## Shared folder structure
+
+```
+├── ExtractedData
+│   ├── [Year]
+│   │   ├── [Month]
+│   │   │   ├── **/*.netCDF
+├── PreprocessedData
+│   ├── [Data_name_convention]
+│   │   ├── hickle
+│   │   │   ├── train
+│   │   │   ├── val
+│   │   │   ├── test
+│   │   ├── tfrecords
+│   │   │   ├── train
+│   │   │   ├── val
+│   │   │   ├── test
+├── Models
+│   ├── [Data_name_convention]
+│   │   ├── [model_name]
+│   │   ├── [model_name]
+├── Results
+│   ├── [Data_name_convention]
+│   │   ├── [training_mode]
+│   │   │   ├── [source_data_name_convention]
+│   │   │   │   ├── [model_name]
+
+```
+
+| Arguments	| Value	|
+|---	|---	|
+| [Year]	| 2005;2006;2007 ...|
+| [Month]	| 01;02;03 ...;12|
+|[Data_name_convention]|Y[yyyy]to[yyyy]M[mm]to[mm]-[nx]_[ny]-[nn.nn]N[ee.ee]E-[var1]_[var2]_[var3]|
+|[model_name]| ours_savp;  ours_gan;  ours_vae; prednet|
+|[training_mode]|end_to_end; transfer_learning|
+
+
+## Data name convention
+
+`Y[yyyy]to[yyyy]M[mm]to[mm]-[nx]_[ny]-[nn.nn]N[ee.ee]E-[var1]_[var2]_[var3]`
+
+ - Y[yyyy]to[yyyy]M[mm]to[mm] : the time period covered by the data
+ - [nx]_[ny] : the size of the images, e.g. 64_64 means 64*64 pixels
+ - [nn.nn]N[ee.ee]E : the geolocation of the selected region with two decimal places, e.g. 50.00N11.50E
+ - [var1]_[var2]_[var3] : [use the abbreviation of the selected variables](#variable-abbreviation-and-the-corresponding-full-names)
+
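+The following is a minimal Python sketch (illustration only, not part of the workflow code) that assembles a data name from these components; the function `build_data_name` and its arguments are made up for this example:
+
+```python
+# Sketch: compose a data name following the convention above.
+# All names and arguments are illustrative, not part of the repository code.
+def build_data_name(y_start, y_end, m_start, m_end, nx, ny, lat, lon, variables):
+    period = f"Y{y_start}to{y_end}M{m_start:02d}to{m_end:02d}"
+    size = f"{nx}_{ny}"
+    geo = f"{lat:.2f}N{lon:.2f}E"  # geolocation with two decimal places
+    return f"{period}-{size}-{geo}-{'_'.join(variables)}"
+
+# Example: Y2016to2017M01to12-64_64-50.00N11.50E-T_T_T
+print(build_data_name(2016, 2017, 1, 12, 64, 64, 50.0, 11.5, ["T", "T", "T"]))
+```
+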
+### `Y[yyyy]to[yyyy]M[mm]to[mm]`
+
+| Examples	| Name abbreviation 	|
+|---	|---	|
+|all data from March to June of the years 2005-2015	| Y2005toY2015M03to06 |   
+|data from February to May of years 2005-2008 + data from March to June of year 2015| Y2005to2008M02to05_Y2015M03to06 |   
+|Data from February to May, and October to December of 2005 |  Y2005M02to05_Y2005M10to12 |
+|'operational' data base: whole year 2016 |  Y2016M01to12 |
+|add the new whole-year data of 2017 to the operational data base |Y2016to2017M01to12 |
+| Note: Y2016to2017M01to12 = Y2016M01to12_Y2017M01to12|  
+
+
+### Variable abbreviation and the corresponding full names
+
+| var	| full  names 	|
+|---	|---	|
+|T|2m temperature|   
+|gph500|500 hPa geopotential|   
+|msl|mean sea level pressure|
+
+
+
+### Example
+
+```
+├── ExtractedData
+│   ├── 2016
+│   │   ├── 01
+│   │   │   ├── *.netCDF
+│   │   ├── 02
+│   │   ├── 03
+│   │   ├── …
+│   ├── 2017
+│   │   ├── 01
+│   │   ├── …
+├── PreprocessedData
+│   ├── 2016to2017M01to12-64_64-50.00N11.50E-T_T_T
+│   │   ├── hickle
+│   │   │   ├── train
+│   │   │   ├── val
+│   │   │   ├── test
+│   │   ├── tfrecords
+│   │   │   ├── train
+│   │   │   ├── val
+│   │   │   ├── test
+├── Models
+│   ├── 2016to2017M01to12-64_64-50.00N11.50E-T_T_T
+│   │   ├── ours_savp
+│   │   ├── ours_gan
+├── Results
+│   ├── 2016to2017M01to12-64_64-50.00N11.50E-T_T_T
+│   │   ├── end_to_end
+│   │   │   ├── ours_savp
+│   │   │   ├── ours_gan
+│   │   ├── transfer_learning
+│   │   │   ├── 2018M01to12-64_64-50.00N11.50E-T_T_T
+│   │   │   │   ├── ours_savp
+```
+
diff --git a/env_setup/create_env.sh b/env_setup/create_env.sh
new file mode 100755
index 0000000000000000000000000000000000000000..7d0f0a10bd8586e59fe129198a5a6f7c21121502
--- /dev/null
+++ b/env_setup/create_env.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+
+if [[ ! -n "$1" ]]; then
+  echo "Provide the user name, which will be taken as folder name"
+  exit 1
+fi
+
+if [[ ! -n "$2" ]]; then
+  echo "Provide the env name, which will be taken as folder name"
+  exit 1
+fi
+
+ENV_NAME=$2
+FOLDER_NAME=$1
+WORKING_DIR=/p/project/deepacf/deeprain/${FOLDER_NAME}/video_prediction_savp
+ENV_SETUP_DIR=${WORKING_DIR}/env_setup
+ENV_DIR=${WORKING_DIR}/${ENV_NAME}
+
+source ${ENV_SETUP_DIR}/modules.sh
+# Install additional Python packages.
+python3 -m venv $ENV_DIR
+source ${ENV_DIR}/bin/activate
+pip3 install -r ${ENV_SETUP_DIR}/requirements.txt
+#pip3 install --user netCDF4
+#pip3 install --user numpy
+
+#Copy the hickle package from bing's account
+cp  -r /p/project/deepacf/deeprain/bing/hickle ${WORKING_DIR}
+
+source ${ENV_SETUP_DIR}/modules.sh
+source ${ENV_DIR}/bin/activate
+
+export PYTHONPATH=${WORKING_DIR}/hickle/lib/python3.6/site-packages:$PYTHONPATH
+export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH
+export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:$PYTHONPATH
+#export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:$PYTHONPATH
+export PYTHONPATH=${WORKING_DIR}/lpips-tensorflow:$PYTHONPATH
+
+
diff --git a/env_setup/create_env_zam347.sh b/env_setup/create_env_zam347.sh
new file mode 100755
index 0000000000000000000000000000000000000000..5e0c43c5cc826c579b096c7b0aad0236e6b2002f
--- /dev/null
+++ b/env_setup/create_env_zam347.sh
@@ -0,0 +1,40 @@
+#!/usr/bin/env bash
+
+
+if [[ -z "$1" ]]; then
+  echo "Provide an environment name; it will be used as the name of the virtual-environment folder."
+  exit 1
+fi
+
+ENV_NAME=$1
+WORKING_DIR=/home/$USER/video_prediction_savp
+ENV_SETUP_DIR=${WORKING_DIR}/env_setup
+ENV_DIR=${WORKING_DIR}/${ENV_NAME}
+unset PYTHONPATH
+#source ${ENV_SETUP_DIR}/modules.sh
+# Install additional Python packages.
+python3 -m venv $ENV_DIR
+source ${ENV_DIR}/bin/activate
+pip3 install --upgrade pip
+pip3 install -r ${ENV_SETUP_DIR}/requirements.txt
+#conda install mpi4py
+pip3 install  mpi4py 
+pip3 install netCDF4
+pip3 install  numpy
+pip3 install h5py
+pip3 install tensorflow-gpu==1.13.1
+#Copy the hickle package from bing's account
+#cp  -r /p/project/deepacf/deeprain/bing/hickle ${WORKING_DIR}
+cp -r /home/b.gong/video_prediction_savp/hickle ${WORKING_DIR}
+
+#source ${ENV_SETUP_DIR}/modules.sh
+#source ${ENV_DIR}/bin/activate
+
+#export PYTHONPATH=/home/$USER/miniconda3/pkgs:$PYTHONPATH
+export PYTHONPATH=${WORKING_DIR}/hickle/lib/python3.6/site-packages:$PYTHONPATH
+export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH
+#export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:$PYTHONPATH
+#export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:$PYTHONPATH
+export PYTHONPATH=${WORKING_DIR}/lpips-tensorflow:$PYTHONPATH
+
+
diff --git a/env_setup/modules.sh b/env_setup/modules.sh
new file mode 100755
index 0000000000000000000000000000000000000000..e6793787ad59988cfc6646dc8dd789d1573c6b23
--- /dev/null
+++ b/env_setup/modules.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+module purge
+module use $OTHERSTAGES
+module load Stages/2019a
+module load GCC/8.3.0
+module load MVAPICH2/.2.3.1-GDR
+module load GCCcore/.8.3.0
+module load mpi4py/3.0.1-Python-3.6.8
+module load h5py/2.9.0-serial-Python-3.6.8
+module load TensorFlow/1.13.1-GPU-Python-3.6.8
+module load cuDNN/7.5.1.10-CUDA-10.1.105
+
diff --git a/env_setup/requirements.txt b/env_setup/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..76dd1f57d64577cc565968bb7106656e53687261
--- /dev/null
+++ b/env_setup/requirements.txt
@@ -0,0 +1,4 @@
+opencv-python
+scipy
+scikit-image
+pandas
diff --git a/geo_info.json b/geo_info.json
new file mode 100644
index 0000000000000000000000000000000000000000..911a7c3b1333c4e815db705197cf77cb107de8a8
--- /dev/null
+++ b/geo_info.json
@@ -0,0 +1 @@
+{"lat": [58.19999694824219, 57.89999771118164, 57.599998474121094, 57.29999923706055, 57.0, 56.69999694824219, 56.39999771118164, 56.099998474121094, 55.79999923706055, 55.5, 55.19999694824219, 54.89999771118164, 54.599998474121094, 54.29999923706055, 54.0, 53.69999694824219, 53.39999771118164, 53.099998474121094, 52.79999923706055, 52.5, 52.19999694824219, 51.89999771118164, 51.599998474121094, 51.29999923706055, 51.0, 50.69999694824219, 50.39999771118164, 50.099998474121094, 49.79999923706055, 49.5, 49.19999694824219, 48.89999771118164, 48.599998474121094, 48.29999923706055, 48.0, 47.69999694824219, 47.39999771118164, 47.099998474121094, 46.79999923706055, 46.5, 46.19999694824219, 45.89999771118164, 45.599998474121094, 45.29999923706055, 45.0, 44.69999694824219, 44.39999771118164, 44.099998474121094, 43.79999923706055, 43.5, 43.19999694824219, 42.89999771118164, 42.599998474121094, 42.29999923706055, 42.0, 41.69999694824219, 41.39999771118164, 41.099998474121094, 40.79999923706055, 40.499996185302734, 40.19999694824219, 39.89999771118164, 39.599998474121094, 39.29999923706055], "lon": [-0.5999755859375, -0.29998779296875, 0.0, 0.30000001192092896, 0.6000000238418579, 0.9000000357627869, 1.2000000476837158, 1.5, 1.8000000715255737, 2.1000001430511475, 2.4000000953674316, 2.700000047683716, 3.0, 3.3000001907348633, 3.6000001430511475, 3.9000000953674316, 4.200000286102295, 4.5, 4.800000190734863, 5.100000381469727, 5.400000095367432, 5.700000286102295, 6.0, 6.300000190734863, 6.600000381469727, 6.900000095367432, 7.200000286102295, 7.500000476837158, 7.800000190734863, 8.100000381469727, 8.40000057220459, 8.700000762939453, 9.0, 9.300000190734863, 9.600000381469727, 9.90000057220459, 10.200000762939453, 10.5, 10.800000190734863, 11.100000381469727, 11.40000057220459, 11.700000762939453, 12.0, 12.300000190734863, 12.600000381469727, 12.90000057220459, 13.200000762939453, 13.500000953674316, 13.800000190734863, 14.100000381469727, 14.40000057220459, 14.700000762939453, 15.000000953674316, 15.300000190734863, 15.600000381469727, 15.90000057220459, 16.200000762939453, 16.5, 16.80000114440918, 17.100000381469727, 17.400001525878906, 17.700000762939453, 18.0, 18.30000114440918]}
\ No newline at end of file
diff --git a/helper/helper.py b/helper/helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..e85a4df8a8ad9811b2634983db5eac8bf0204b47
--- /dev/null
+++ b/helper/helper.py
@@ -0,0 +1,53 @@
+import logging
+import time
+from functools import wraps
+
+def logDecorator(fn,verbose=False):
+    """Decorator which logs each call of fn and its runtime to 'log.log' and to the console."""
+    @wraps(fn)
+    def wrapper(*args,**kwargs):
+        logger = logging.getLogger(fn.__name__)
+        logger.setLevel(logging.DEBUG if verbose else logging.INFO)
+        # add the handlers only once to avoid duplicated log entries on repeated calls
+        if not logger.handlers:
+            # create a common logging format
+            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+            # create a file handler
+            handler = logging.FileHandler("log.log")
+            handler.setFormatter(formatter)
+            logger.addHandler(handler)
+            # create a console handler
+            ch = logging.StreamHandler()
+            ch.setFormatter(formatter)
+            logger.addHandler(ch)
+        logger.info("Calling '{}'".format(fn.__name__))
+        start = time.time()
+        results = fn(*args,**kwargs)
+        end = time.time()
+        logger.info("{} ran in {}s".format(fn.__name__, round(end - start, 2)))
+        return results
+    return wrapper
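+
+# Example usage (illustrative sketch only; 'preprocess_step' is a hypothetical function):
+#
+#   @logDecorator
+#   def preprocess_step(n):
+#       return n * 2
+#
+#   preprocess_step(21)   # logs the call and its runtime to log.log and to the console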
+
+
+#logger = logging.getLogger(__name__)
+# def set_logger(verbose=False):
+#     # Remove all handlers associated with the root logger object.
+#     for handler in logging.root.handlers[:]:
+#         logging.root.removeHandler(handler)
+#     logger = logging.getLogger(__name__)
+#     logger.propagate = False
+#
+#
+#     if not logger.handlers:
+#         logger.setLevel(logging.DEBUG if verbose else logging.INFO)
+#         formatter = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+#
+#
+#
+#         # create another handler for console output
+#         console_handler = logging.StreamHandler()
+#         console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
+#         console_handler.setFormatter(formatter)
+#         logger.handlers = []
+#         logger.addHandler(console_handler)
+#
+#     return logger
\ No newline at end of file
diff --git a/hparams/bair_action_free/ours_deterministic_l1/model_hparams.json b/hparams/bair_action_free/ours_deterministic_l1/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..4a1b23edcb68b57dadee82b1c13366afac50a52a
--- /dev/null
+++ b/hparams/bair_action_free/ours_deterministic_l1/model_hparams.json
@@ -0,0 +1,13 @@
+{
+    "batch_size": 32,
+    "lr": 0.001,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "l1_weight": 1.0,
+    "l2_weight": 0.0,
+    "kl_weight": 0.0,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.0,
+    "state_weight": 0.0,
+    "nz": 0
+}
\ No newline at end of file
diff --git a/hparams/bair_action_free/ours_deterministic_l2/model_hparams.json b/hparams/bair_action_free/ours_deterministic_l2/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..31e7152ae15df5ee33b264f11c88c76c50592185
--- /dev/null
+++ b/hparams/bair_action_free/ours_deterministic_l2/model_hparams.json
@@ -0,0 +1,13 @@
+{
+    "batch_size": 32,
+    "lr": 0.001,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "l1_weight": 0.0,
+    "l2_weight": 1.0,
+    "kl_weight": 0.0,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.0,
+    "state_weight": 0.0,
+    "nz": 0
+}
\ No newline at end of file
diff --git a/hparams/bair_action_free/ours_gan/model_hparams.json b/hparams/bair_action_free/ours_gan/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..38837822c90f38c6209dfa27019a90ccdf8ea43a
--- /dev/null
+++ b/hparams/bair_action_free/ours_gan/model_hparams.json
@@ -0,0 +1,14 @@
+{
+    "batch_size": 16,
+    "lr": 0.0002,
+    "beta1": 0.5,
+    "beta2": 0.999,
+    "l1_weight": 100.0,
+    "l2_weight": 0.0,
+    "kl_weight": 0.0,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.1,
+    "vae_gan_feature_cdist_weight": 0.0,
+    "gan_feature_cdist_weight": 10.0,
+    "state_weight": 0.0
+}
\ No newline at end of file
diff --git a/hparams/bair_action_free/ours_savp/model_hparams.json b/hparams/bair_action_free/ours_savp/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..a6eea83a19505d374e9d614f48ef8bc72443c0f2
--- /dev/null
+++ b/hparams/bair_action_free/ours_savp/model_hparams.json
@@ -0,0 +1,14 @@
+{
+    "batch_size": 16,
+    "lr": 0.0002,
+    "beta1": 0.5,
+    "beta2": 0.999,
+    "l1_weight": 100.0,
+    "l2_weight": 0.0,
+    "kl_weight": 1.0,
+    "video_sn_vae_gan_weight": 0.1,
+    "video_sn_gan_weight": 0.1,
+    "vae_gan_feature_cdist_weight": 10.0,
+    "gan_feature_cdist_weight": 0.0,
+    "state_weight": 0.0
+}
\ No newline at end of file
diff --git a/hparams/bair_action_free/ours_vae_l1/model_hparams.json b/hparams/bair_action_free/ours_vae_l1/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..827757e11b75e720d236417be449b7a301a005ec
--- /dev/null
+++ b/hparams/bair_action_free/ours_vae_l1/model_hparams.json
@@ -0,0 +1,12 @@
+{
+    "batch_size": 32,
+    "lr": 0.001,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "l1_weight": 1.0,
+    "l2_weight": 0.0,
+    "kl_weight": 0.001,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.0,
+    "state_weight": 0.0
+}
\ No newline at end of file
diff --git a/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json b/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..4fddf0eef1d45dbfa16f098e5e42b12f594132e3
--- /dev/null
+++ b/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json
@@ -0,0 +1,12 @@
+{
+    "batch_size": 32,
+    "lr": 0.001,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "l1_weight": 0.0,
+    "l2_weight": 1.0,
+    "kl_weight": 0.001,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.0,
+    "state_weight": 0.0
+}
\ No newline at end of file
diff --git a/hparams/era5/model_hparams.json b/hparams/era5/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..b121ee2f005b6db753b2536deb804204dd41b78d
--- /dev/null
+++ b/hparams/era5/model_hparams.json
@@ -0,0 +1,11 @@
+{
+    "batch_size": 8,
+    "lr": 0.001,
+    "nz": 16,
+    "max_steps":500,
+    "context_frames":10,
+    "sequence_length":20
+
+}
+
+
diff --git a/hparams/kth/ours_deterministic_l1/model_hparams.json b/hparams/kth/ours_deterministic_l1/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..4a1b23edcb68b57dadee82b1c13366afac50a52a
--- /dev/null
+++ b/hparams/kth/ours_deterministic_l1/model_hparams.json
@@ -0,0 +1,13 @@
+{
+    "batch_size": 32,
+    "lr": 0.001,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "l1_weight": 1.0,
+    "l2_weight": 0.0,
+    "kl_weight": 0.0,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.0,
+    "state_weight": 0.0,
+    "nz": 0
+}
\ No newline at end of file
diff --git a/hparams/kth/ours_deterministic_l2/model_hparams.json b/hparams/kth/ours_deterministic_l2/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..31e7152ae15df5ee33b264f11c88c76c50592185
--- /dev/null
+++ b/hparams/kth/ours_deterministic_l2/model_hparams.json
@@ -0,0 +1,13 @@
+{
+    "batch_size": 32,
+    "lr": 0.001,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "l1_weight": 0.0,
+    "l2_weight": 1.0,
+    "kl_weight": 0.0,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.0,
+    "state_weight": 0.0,
+    "nz": 0
+}
\ No newline at end of file
diff --git a/hparams/kth/ours_gan/model_hparams.json b/hparams/kth/ours_gan/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..3d14b63edbf14efca2cefe4703453d899b3fb0fd
--- /dev/null
+++ b/hparams/kth/ours_gan/model_hparams.json
@@ -0,0 +1,15 @@
+{
+    "batch_size": 16,
+    "lr": 0.0002,
+    "beta1": 0.5,
+    "beta2": 0.999,
+    "l1_weight": 100.0,
+    "l2_weight": 0.0,
+    "kl_weight": 0.0,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.1,
+    "vae_gan_feature_cdist_weight": 0.0,
+    "gan_feature_cdist_weight": 10.0,
+    "state_weight": 0.0,
+    "nz": 32
+}
\ No newline at end of file
diff --git a/hparams/kth/ours_savp/model_hparams.json b/hparams/kth/ours_savp/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..66b41f87e3c0f417b492314060121a0bfd01c8f9
--- /dev/null
+++ b/hparams/kth/ours_savp/model_hparams.json
@@ -0,0 +1,18 @@
+{
+    "batch_size": 8,
+    "lr": 0.0002,
+    "beta1": 0.5,
+    "beta2": 0.999,
+    "l1_weight": 100.0,
+    "l2_weight": 0.0,
+    "kl_weight": 0.01,
+    "video_sn_vae_gan_weight": 0.1,
+    "video_sn_gan_weight": 0.1,
+    "vae_gan_feature_cdist_weight": 10.0,
+    "gan_feature_cdist_weight": 0.0,
+    "state_weight": 0.0,
+    "nz": 32,
+    "max_steps":20
+}
+
+
diff --git a/hparams/kth/ours_vae_l1/model_hparams.json b/hparams/kth/ours_vae_l1/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..dee3ce9f8e431d7f7cb46042936cfae3dcfbc6e4
--- /dev/null
+++ b/hparams/kth/ours_vae_l1/model_hparams.json
@@ -0,0 +1,13 @@
+{
+    "batch_size": 32,
+    "lr": 0.001,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "l1_weight": 1.0,
+    "l2_weight": 0.0,
+    "kl_weight": 1e-05,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.0,
+    "state_weight": 0.0,
+    "nz": 32
+}
\ No newline at end of file
diff --git a/lpips-tensorflow/.gitignore b/lpips-tensorflow/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..894a44cc066a027465cd26d634948d56d13af9af
--- /dev/null
+++ b/lpips-tensorflow/.gitignore
@@ -0,0 +1,104 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
diff --git a/lpips-tensorflow/.gitmodules b/lpips-tensorflow/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..085c5852ff85afe688333807a8a26392d20e8ed3
--- /dev/null
+++ b/lpips-tensorflow/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "PerceptualSimilarity"]
+	path = PerceptualSimilarity
+	url = https://github.com/alexlee-gk/PerceptualSimilarity.git
diff --git a/lpips-tensorflow/LICENSE b/lpips-tensorflow/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..afbffa507fd832d614440dd75f81c8ed731ded66
--- /dev/null
+++ b/lpips-tensorflow/LICENSE
@@ -0,0 +1,25 @@
+BSD 2-Clause License
+
+Copyright (c) 2018, alexlee-gk
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/lpips-tensorflow/README.md b/lpips-tensorflow/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..760f5a028e2aae7187135c267edca89db464db5f
--- /dev/null
+++ b/lpips-tensorflow/README.md
@@ -0,0 +1,57 @@
+# lpips-tensorflow
+Tensorflow port for the [PyTorch](https://github.com/richzhang/PerceptualSimilarity) implementation of the [Learned Perceptual Image Patch Similarity (LPIPS)](http://richzhang.github.io/PerceptualSimilarity/) metric.
+This is done by exporting the model from PyTorch to ONNX and then to TensorFlow.
+
+## Getting started
+### Installation
+- Clone this repo.
+```bash
+git clone https://github.com/alexlee-gk/lpips-tensorflow.git
+cd lpips-tensorflow
+```
+- Install TensorFlow and dependencies from http://tensorflow.org/
+- Install other dependencies.
+```bash
+pip install -r requirements.txt
+```
+
+### Using the LPIPS metric
+The `lpips` TensorFlow function works with individual images or batches of images.
+It also works with images of any spatial dimensions (but the dimensions should be at least the size of the network's receptive field).
+This example computes the LPIPS distance between batches of images.
+```python
+import numpy as np
+import tensorflow as tf
+import lpips_tf
+
+batch_size = 32
+image_shape = (batch_size, 64, 64, 3)
+image0 = np.random.random(image_shape)
+image1 = np.random.random(image_shape)
+image0_ph = tf.placeholder(tf.float32)
+image1_ph = tf.placeholder(tf.float32)
+
+distance_t = lpips_tf.lpips(image0_ph, image1_ph, model='net-lin', net='alex')
+
+with tf.Session() as session:
+    distance = session.run(distance_t, feed_dict={image0_ph: image0, image1_ph: image1})
+```
+
+## Exporting additional models
+### Export PyTorch model to TensorFlow through ONNX
+- Clone the PerceptualSimilarity submodule and add it to the PYTHONPATH.
+```bash
+git submodule update --init --recursive
+export PYTHONPATH=PerceptualSimilarity:$PYTHONPATH
+```
+- Install more dependencies.
+```bash
+pip install -r requirements-dev.txt
+```
+- Export the model to ONNX *.onnx and TensorFlow *.pb files in the `models` directory.
+```bash
+python export_to_tensorflow.py --model net-lin --net alex
+```
+
+### Known issues
+- The SqueezeNet model cannot be exported since ONNX cannot export one of the operators.
diff --git a/lpips-tensorflow/export_to_tensorflow.py b/lpips-tensorflow/export_to_tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..32681117385903e8e7bfd19d4a7154a99a7ce78f
--- /dev/null
+++ b/lpips-tensorflow/export_to_tensorflow.py
@@ -0,0 +1,58 @@
+import argparse
+import os
+
+import onnx
+import torch
+import torch.onnx
+
+from models import dist_model as dm
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', choices=['net-lin', 'net'], default='net-lin', help='net-lin or net')
+    parser.add_argument('--net', choices=['squeeze', 'alex', 'vgg'], default='alex', help='squeeze, alex, or vgg')
+    parser.add_argument('--version', type=str, default='0.1')
+    parser.add_argument('--image_height', type=int, default=64)
+    parser.add_argument('--image_width', type=int, default=64)
+    args = parser.parse_args()
+
+    model = dm.DistModel()
+    model.initialize(model=args.model, net=args.net, use_gpu=False, version=args.version)
+    print('Model [%s] initialized' % model.name())
+
+    dummy_im0 = torch.Tensor(1, 3, args.image_height, args.image_width)  # image should be RGB, normalized to [-1, 1]
+    dummy_im1 = torch.Tensor(1, 3, args.image_height, args.image_width)
+
+    cache_dir = os.path.expanduser('~/.lpips')
+    os.makedirs(cache_dir, exist_ok=True)
+    onnx_fname = os.path.join(cache_dir, '%s_%s_v%s.onnx' % (args.model, args.net, args.version))
+
+    # export model to onnx format
+    torch.onnx.export(model.net, (dummy_im0, dummy_im1), onnx_fname, verbose=True)
+
+    # load and change dimensions to be dynamic
+    model = onnx.load(onnx_fname)
+    for dim in (0, 2, 3):
+        model.graph.input[0].type.tensor_type.shape.dim[dim].dim_param = '?'
+        model.graph.input[1].type.tensor_type.shape.dim[dim].dim_param = '?'
+
+    # needs to be imported after all the pytorch stuff, otherwise this causes a segfault
+    from onnx_tf.backend import prepare
+    tf_rep = prepare(model, device='CPU')
+    producer_version = tf_rep.graph.graph_def_versions.producer
+    pb_fname = os.path.join(cache_dir, '%s_%s_v%s_%d.pb' % (args.model, args.net, args.version, producer_version))
+    tf_rep.export_graph(pb_fname)
+    input0_name, input1_name = [tf_rep.tensor_dict[input_name].name for input_name in tf_rep.inputs]
+    (output_name,) = [tf_rep.tensor_dict[output_name].name for output_name in tf_rep.outputs]
+
+    # ensure these are the names of the 2 inputs, since that will be assumed when loading the pb file
+    assert input0_name == '0:0'
+    assert input1_name == '1:0'
+    # ensure that the only output is the output of the last op in the graph, since that will be assumed later
+    (last_output_name,) = [output.name for output in tf_rep.graph.get_operations()[-1].outputs]
+    assert output_name == last_output_name
+
+
+if __name__ == '__main__':
+    main()
diff --git a/lpips-tensorflow/lpips_tf.py b/lpips-tensorflow/lpips_tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c47f90b68f3fa05b8b2f29cf6100b6d63e584aa
--- /dev/null
+++ b/lpips-tensorflow/lpips_tf.py
@@ -0,0 +1,90 @@
+import os
+import sys
+
+import tensorflow as tf
+from six.moves import urllib
+
+_URL = 'http://rail.eecs.berkeley.edu/models/lpips'
+
+
+def _download(url, output_dir):
+    """Downloads the `url` file into `output_dir`.
+
+    Modified from https://github.com/tensorflow/models/blob/master/research/slim/datasets/dataset_utils.py
+    """
+    filename = url.split('/')[-1]
+    filepath = os.path.join(output_dir, filename)
+
+    def _progress(count, block_size, total_size):
+        sys.stdout.write('\r>> Downloading %s %.1f%%' % (
+            filename, float(count * block_size) / float(total_size) * 100.0))
+        sys.stdout.flush()
+
+    filepath, _ = urllib.request.urlretrieve(url, filepath, _progress)
+    print()
+    statinfo = os.stat(filepath)
+    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+
+
+def lpips(input0, input1, model='net-lin', net='alex', version=0.1):
+    """
+    Learned Perceptual Image Patch Similarity (LPIPS) metric.
+
+    Args:
+        input0: An image tensor of shape `[..., height, width, channels]`,
+            with values in [0, 1].
+        input1: An image tensor of shape `[..., height, width, channels]`,
+            with values in [0, 1].
+
+    Returns:
+        The Learned Perceptual Image Patch Similarity (LPIPS) distance.
+
+    Reference:
+        Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang.
+        The Unreasonable Effectiveness of Deep Features as a Perceptual Metric.
+        In CVPR, 2018.
+    """
+    # flatten the leading dimensions
+    batch_shape = tf.shape(input0)[:-3]
+    input0 = tf.reshape(input0, tf.concat([[-1], tf.shape(input0)[-3:]], axis=0))
+    input1 = tf.reshape(input1, tf.concat([[-1], tf.shape(input1)[-3:]], axis=0))
+    # NHWC to NCHW
+    input0 = tf.transpose(input0, [0, 3, 1, 2])
+    input1 = tf.transpose(input1, [0, 3, 1, 2])
+    # normalize to [-1, 1]
+    input0 = input0 * 2.0 - 1.0
+    input1 = input1 * 2.0 - 1.0
+
+    input0_name, input1_name = '0:0', '1:0'
+
+    default_graph = tf.get_default_graph()
+    producer_version = default_graph.graph_def_versions.producer
+
+    cache_dir = os.path.expanduser('~/.lpips')
+    os.makedirs(cache_dir, exist_ok=True)
+    # files to try. try a specific producer version, but fallback to the version-less version (latest).
+    pb_fnames = [
+        '%s_%s_v%s_%d.pb' % (model, net, version, producer_version),
+        '%s_%s_v%s.pb' % (model, net, version),
+    ]
+    for pb_fname in pb_fnames:
+        if not os.path.isfile(os.path.join(cache_dir, pb_fname)):
+            try:
+                _download(os.path.join(_URL, pb_fname), cache_dir)
+            except urllib.error.HTTPError:
+                pass
+        if os.path.isfile(os.path.join(cache_dir, pb_fname)):
+            break
+
+    with open(os.path.join(cache_dir, pb_fname), 'rb') as f:
+        graph_def = tf.GraphDef()
+        graph_def.ParseFromString(f.read())
+        _ = tf.import_graph_def(graph_def,
+                                input_map={input0_name: input0, input1_name: input1})
+        distance, = default_graph.get_operations()[-1].outputs
+
+    if distance.shape.ndims == 4:
+        distance = tf.squeeze(distance, axis=[-3, -2, -1])
+    # reshape the leading dimensions
+    distance = tf.reshape(distance, batch_shape)
+    return distance
diff --git a/lpips-tensorflow/requirements-dev.txt b/lpips-tensorflow/requirements-dev.txt
new file mode 100644
index 0000000000000000000000000000000000000000..df36766f1d4cc3b378d7aba0164f3fdfab3be1d2
--- /dev/null
+++ b/lpips-tensorflow/requirements-dev.txt
@@ -0,0 +1,4 @@
+torch>=0.4.0
+torchvision>=0.2.1
+onnx
+onnx-tf
diff --git a/lpips-tensorflow/requirements.txt b/lpips-tensorflow/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bc2cbbd1ca22aca38c3fd05d94d07fbf60ed4f6d
--- /dev/null
+++ b/lpips-tensorflow/requirements.txt
@@ -0,0 +1,2 @@
+numpy
+six
diff --git a/lpips-tensorflow/setup.py b/lpips-tensorflow/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fc8d1d35b0b9249c7c0e24dd14efc707b873d92
--- /dev/null
+++ b/lpips-tensorflow/setup.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python
+
+from distutils.core import setup
+
+setup(
+      name='lpips-tf',
+      description='Tensorflow port for the Learned Perceptual Image Patch Similarity (LPIPS) metric',
+      author='Alex Lee',
+      url='https://github.com/alexlee-gk/lpips-tensorflow/',
+      py_modules=['lpips_tf']
+)
diff --git a/lpips-tensorflow/test_network.py b/lpips-tensorflow/test_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..c222ab931807b207a5ea26d2a7297429dcb7a02b
--- /dev/null
+++ b/lpips-tensorflow/test_network.py
@@ -0,0 +1,42 @@
+import argparse
+
+import cv2
+import numpy as np
+import tensorflow as tf
+
+import lpips_tf
+
+
+def load_image(fname):
+    image = cv2.imread(fname)
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    return image.astype(np.float32) / 255.0
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', choices=['net-lin', 'net'], default='net-lin', help='net-lin or net')
+    parser.add_argument('--net', choices=['squeeze', 'alex', 'vgg'], default='alex', help='squeeze, alex, or vgg')
+    parser.add_argument('--version', type=str, default='0.1')
+    args = parser.parse_args()
+
+    ex_ref = load_image('./PerceptualSimilarity/imgs/ex_ref.png')
+    ex_p0 = load_image('./PerceptualSimilarity/imgs/ex_p0.png')
+    ex_p1 = load_image('./PerceptualSimilarity/imgs/ex_p1.png')
+
+    session = tf.Session()
+
+    image0_ph = tf.placeholder(tf.float32)
+    image1_ph = tf.placeholder(tf.float32)
+    lpips_fn = session.make_callable(
+        lpips_tf.lpips(image0_ph, image1_ph, model=args.model, net=args.net, version=args.version),
+        [image0_ph, image1_ph])
+
+    ex_d0 = lpips_fn(ex_ref, ex_p0)
+    ex_d1 = lpips_fn(ex_ref, ex_p1)
+
+    print('Distances: (%.3f, %.3f)' % (ex_d0, ex_d1))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/metadata.py b/metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..b65db76d71a549b3405f9f10b26dcccb09b30230
--- /dev/null
+++ b/metadata.py
@@ -0,0 +1,330 @@
+""" 
+Classes and routines to retrieve and handle meta-data
+"""
+
+import os
+import sys
+import numpy as np
+import json
+from netCDF4 import Dataset
+
+class MetaData:
+    """
+     Class for handling, storing and retrieving meta-data
+    """
+    
+    def __init__(self,json_file=None,suffix_indir=None,data_filename=None,slices=None,variables=None):
+        
+        """
+         Initializes a MetaData instance by reading a corresponding json-file or by handling the arguments of the preprocessing step
+         (i.e. an exemplary input file, slices defining the region of interest and the input variables).
+        """
+        
+        method_name = MetaData.__init__.__name__+" of Class "+MetaData.__name__
+        
+        if json_file is not None:
+            MetaData.get_metadata_from_file(self,json_file)
+            
+        else:
+            # No dictionary from json-file available, so all other arguments have to be set
+            if not suffix_indir:
+                raise TypeError(method_name+": 'suffix_indir'-argument is required if 'json_file' is not passed.")
+            else:
+                if not isinstance(suffix_indir,str):
+                    raise TypeError(method_name+": 'suffix_indir'-argument must be a string.")
+            
+            if not data_filename:
+                raise TypeError(method_name+": 'data_filename'-argument is required if 'json_file' is not passed.")
+            else:
+                if not isinstance(data_filename,str):
+                    raise TypeError(method_name+": 'data_filename'-argument must be a string.")
+                
+            if not slices:
+                raise TypeError(method_name+": 'slices'-argument is required if 'json_file' is not passed.")
+            else:
+                if not isinstance(slices,dict):
+                    raise TypeError(method_name+": 'slices'-argument must be a dictionary.")
+            
+            if not variables:
+                raise TypeError(method_name+": 'variables'-argument is required if 'json_file' is not passed.")
+            else:
+                if not isinstance(variables,list):
+                    raise TypeError(method_name+": 'variables'-argument must be a list.")       
+            
+            MetaData.get_and_set_metadata_from_file(self,suffix_indir,data_filename,slices,variables)
+            
+            MetaData.write_metadata_to_file(self)
+            
+
+    def get_and_set_metadata_from_file(self,suffix_indir,datafile_name,slices,variables):
+        """
+         Retrieves several meta data from netCDF-file and sets corresponding class instance attributes.
+         Besides, the name of the experiment directory is constructed following the naming convention (see below)
+        
+         Naming convention:
+         [model_base]_Y[yyyy]to[yyyy]M[mm]to[mm]-[nx]x[ny]-[nnnn]N[eeee]E-[var1]_[var2]_(...)_[varN]
+         ---------------- Given ----------------|---------------- Created dynamically --------------
+        
+         Note that the model-base as well as the date-identifiers must already be included in suffix_indir.
+        """
+        
+        method_name = MetaData.get_and_set_metadata_from_file.__name__+" of Class "+MetaData.__name__
+        
+        if not suffix_indir: raise ValueError(method_name+": suffix_indir must be a non-empty path.")
+    
+        # retrieve required information from file 
+        flag_coords = ["N", "E"]
+ 
+        print("Retrieve metadata based on file: '"+datafile_name+"'")
+        try:
+            datafile = Dataset(datafile_name,'r')
+        except Exception as err:
+            print(method_name + ": Error when opening data file: '"+datafile_name+"': "+str(err))
+            sys.exit(1)
+        
+        # Check if all requested variables can be obtained from datafile
+        MetaData.check_datafile(datafile,variables)
+        self.varnames    = variables
+        
+        
+        self.nx, self.ny = np.abs(slices['lon_e'] - slices['lon_s']), np.abs(slices['lat_e'] - slices['lat_s'])    
+        sw_c             = [float(datafile.variables['lat'][slices['lat_e']-1]),float(datafile.variables['lon'][slices['lon_s']])]                # meridional axis lat is oriented from north to south (i.e. monotonically decreasing)
+        self.sw_c        = sw_c
+        
+        # Now start constructing exp_dir-string
+        # switch sign and coordinate-flags to avoid negative values appearing in exp_dir-name
+        if sw_c[0] < 0.:
+            sw_c[0] = np.abs(sw_c[0])
+            flag_coords[0] = "S"
+        if sw_c[1] < 0.:
+            sw_c[1] = np.abs(sw_c[1])
+            flag_coords[1] = "W"
+        nvar     = len(variables)
+        
+        # splitting has to be done in order to retrieve the expname-suffix (and the year if required)
+        path_parts = os.path.split(suffix_indir.rstrip("/"))
+        
+        if (is_integer(path_parts[1])):
+            year = path_parts[1]
+            path_parts = os.path.split(path_parts[0].rstrip("/"))
+        else:
+            year = ""
+        
+        expdir, expname = path_parts[0], path_parts[1] 
+
+        # extend expname successively (split up for better readability)
+        expname += "-"+str(self.nx) + "x" + str(self.ny)
+        expname += "-"+(("{0: 05.2f}"+flag_coords[0]+"{1:05.2f}"+flag_coords[1]).format(*sw_c)).strip().replace(".","")+"-"  
+        
+        # reduced for-loop length as last variable-name is not followed by an underscore (see above)
+        for i in range(nvar-1):
+            expname += variables[i]+"_"
+        expname += variables[nvar-1]
+        
+        self.expname = expname
+        self.expdir  = expdir
+        self.status  = ""                   # uninitialized (is set when metadata is written/compared to/with json-file, see write_metadata_to_file-method)
+
+    # ML 2020/04/24 E         
+    
+    def write_metadata_to_file(self,dest_dir = None):
+        
+        """
+         Write meta data attributes of class instance to json-file.
+        """
+        
+        method_name = MetaData.write_metadata_to_file.__name__+" of Class "+MetaData.__name__
+        # actual work:
+        meta_dict = {"expname": self.expname,
+                     "expdir" : self.expdir}
+        
+        meta_dict["sw_corner_frame"] = {
+            "lat" : self.sw_c[0],
+            "lon" : self.sw_c[1]
+            }
+        
+        meta_dict["frame_size"] = {
+            "nx" : int(self.nx),
+            "ny" : int(self.ny)
+            }
+        
+        meta_dict["variables"] = []
+        for i in range(len(self.varnames)):
+            print(self.varnames[i])
+            meta_dict["variables"].append( 
+                    {"var"+str(i+1) : self.varnames[i]})
+        
+        # create directory if required
+        if dest_dir is None: 
+            dest_dir = os.path.join(self.expdir,self.expname)
+        if not os.path.exists(dest_dir):
+            print("Created experiment directory: '"+self.expdir+"'")
+            os.makedirs(dest_dir,exist_ok=True)            
+            
+        meta_fname = os.path.join(dest_dir,"metadata.json")
+
+        if os.path.exists(meta_fname):                      # check if a metadata-file already exists and check its content 
+            self.status = "old"                             # set status to old in order to prevent repeated modification of shell-/Batch-scripts
+            with open(meta_fname,'r') as js_file:
+                dict_dupl = json.load(js_file)
+                
+                if dict_dupl != meta_dict:
+                    print(method_name+": Already existing metadata (see '"+meta_fname+") do not fit data being processed right now. Ensure a common data base.")
+                    sys.exit(1)
+                else: #do not need to do anything
+                    pass
+        else:
+            # write dictionary to file
+            print(method_name+": Write dictionary to json-file: '"+meta_fname+"'")
+            with open(meta_fname,'w') as js_file:
+                json.dump(meta_dict,js_file)
+            self.status = "new"                             # set status to new in order to trigger modification of shell-/Batch-scripts
+        
+    def get_metadata_from_file(self,js_file):
+        
+        """
+         Retrieves meta data attributes from json-file
+        """
+        
+        with open(js_file) as js_fobj:
+            dict_in = json.load(js_fobj)
+
+            self.expname    = dict_in["expname"]
+            self.expdir     = dict_in["expdir"]
+
+            self.sw_c       = [dict_in["sw_corner_frame"]["lat"],dict_in["sw_corner_frame"]["lon"]]
+
+            self.nx         = dict_in["frame_size"]["nx"]
+            self.ny         = dict_in["frame_size"]["ny"]
+
+            # "variables" is stored as a list of single-entry dictionaries (see write_metadata_to_file)
+            self.variables  = [list(var_dict.values())[0] for var_dict in dict_in["variables"]]
+            
+            
+    def write_dirs_to_batch_scripts(self,batch_script):
+        
+        """
+         Expands ('known') directory-variables in batch_script by the expname-attribute of the class instance
+        """
+        
+        paths_to_mod = ["source_dir=","destination_dir=","checkpoint_dir=","results_dir="]      # known directory-variables in batch-scripts
+        
+        with open(batch_script,'r') as file:
+            data = file.readlines()
+            
+        nlines = len(data)
+        matched_lines = [iline for iline in range(nlines) if any(str_id in data[iline] for str_id in paths_to_mod)]   # list of line-number indices to be modified 
+
+        for i in matched_lines:
+            data[i] = add_str_to_path(data[i],self.expname)
+
+        
+        with open(batch_script,'w') as file:
+            file.writelines(data)
+    
+    @staticmethod
+    def write_destdir_jsontmp(dest_dir, tmp_dir = None):        
+        """
+          Writes dest_dir to temporary json-file (temp.json) stored in the current working directory.
+        """
+        
+        if not tmp_dir: tmp_dir = os.getcwd()
+        
+        file_tmp = os.path.join(tmp_dir,"temp.json")
+        dict_tmp = {"destination_dir": dest_dir}
+        
+        with open(file_tmp,"w") as js_file:
+            print("Save destination_dir-variable in temporary json-file: '"+file_tmp+"'")
+            json.dump(dict_tmp,js_file)
+            
+    @staticmethod
+    def get_destdir_jsontmp(tmp_dir = None):
+        """
+          Retrieves dest_dir from temporary json-file which is expected to exist in the current working directory and returns it.
+        """
+        
+        method_name = MetaData.get_destdir_jsontmp.__name__+" of Class "+MetaData.__name__
+
+        if not tmp_dir: tmp_dir = os.getcwd()
+        
+        file_tmp = os.path.join(tmp_dir,"temp.json")
+        
+        try:
+            with open(file_tmp,"r") as js_file:
+                dict_tmp = json.load(js_file)
+        except Exception:
+            print(method_name+": Could not open requested json-file '"+file_tmp+"'")
+            sys.exit(1)
+            
+        if not "destination_dir" in dict_tmp.keys():
+            raise Exception(method_name+": Could not find 'destination_dir' in dictionary obtained from "+file_tmp)
+        else:
+            return(dict_tmp.get("destination_dir"))
+    
+    
+    @staticmethod
+    def issubset(a,b):
+        """
+        Checks if all elements of a exist in b or vice versa (depends on the length of the corresponding lists/sets)
+        """  
+        
+        if len(a) > len(b):
+            return(set(b).issubset(set(a)))
+        elif len(b) >= len(a):
+            return(set(a).issubset(set(b)))
+    
+    @staticmethod
+    def check_datafile(datafile,varnames):
+        """
+          Checks if all varnames can be found in datafile
+        """
+        
+        if not MetaData.issubset(varnames,datafile.variables.keys()):
+            # report all missing variables before raising
+            for varname in varnames:
+                if not varname in datafile.variables.keys():
+                    print("Variable '"+varname+"' not found in datafile.")
+            raise ValueError("Could not find the above mentioned variables.")
+        else:
+            pass
+
+
+        
+# ----------------------------------- end of class MetaData -----------------------------------
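+
+# Example usage (illustrative sketch only; the paths, slice indices and variable names are hypothetical):
+#
+#   slices = {"lat_s": 74, "lat_e": 138, "lon_s": 550, "lon_e": 614}
+#   md = MetaData(suffix_indir="/path/to/preprocessedData/era5-Y2016M01to12/2016",
+#                 data_filename="/path/to/extractedData/2016/01/era5_2016_01.nc",
+#                 slices=slices, variables=["T", "gph500", "msl"])
+#   # -> writes <expdir>/<expname>/metadata.json following the naming convention described above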
+
+# some auxiliary functions which are not bound to the MetaData-class
+
+def add_str_to_path(path_in,add_str):
+    
+    """
+        Adds add_str to path_in if path_in does not already end with add_str.
+        The function can also handle trailing newline characters of input strings obtained by reading a file.
+    """
+    
+    l_linebreak = path_in.endswith("\n")   # flag for carriage return at the end of input string
+    line_str    = path_in.rstrip("\n")
+    
+    # append add_str only if the path does not already end with it (with or without a trailing "/")
+    if (not line_str.endswith(add_str)) and \
+       (not line_str.endswith(add_str.rstrip("/"))):
+        
+        line_str = line_str + add_str + "/"
+    else:
+        print(add_str+" is already part of "+line_str+". No change is performed.")
+    
+    if l_linebreak:                     # re-add carriage return to string if required
+        return(line_str+"\n")
+    else:
+        return(line_str)
+                       
+                       
+def is_integer(n):
+    '''
+    Checks if input string is numeric and of type integer.
+    '''
+    try:
+        float(n)
+    except ValueError:
+        return False
+    else:
+        return float(n).is_integer()                       
+                       
+    
+        
+        
diff --git a/pretrained_models/download_model.sh b/pretrained_models/download_model.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fdffd762b334f709edc9369b1d7c69268b2c43b4
--- /dev/null
+++ b/pretrained_models/download_model.sh
@@ -0,0 +1,73 @@
+#!/usr/bin/env bash
+
+# exit if any command fails
+set -e
+
+if [ "$#" -ne 2 ]; then
+  echo "Usage: $0 DATASET_NAME MODEL_NAME" >&2
+  exit 1
+fi
+DATASET_NAME=$1
+MODEL_NAME=$2
+
+declare -A model_name_to_fname
+if [ ${DATASET_NAME} = "bair_action_free" ]; then
+  model_name_to_fname=(
+    [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2
+    [ours_gan]=${DATASET_NAME}_ours_gan
+    [ours_savp]=${DATASET_NAME}_ours_savp
+    [ours_vae]=${DATASET_NAME}_ours_vae_l1
+    [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1
+    [ours_vae_l2]=${DATASET_NAME}_ours_vae_l2
+    [sv2p_time_invariant]=${DATASET_NAME}_sv2p_time_invariant
+  )
+elif [ ${DATASET_NAME} = "kth" ]; then
+  model_name_to_fname=(
+    [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2
+    [ours_gan]=${DATASET_NAME}_ours_gan
+    [ours_savp]=${DATASET_NAME}_ours_savp
+    [ours_vae]=${DATASET_NAME}_ours_vae_l1
+    [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1
+    [sv2p_time_invariant]=${DATASET_NAME}_sv2p_time_invariant
+    [sv2p_time_variant]=${DATASET_NAME}_sv2p_time_variant
+  )
+elif [ ${DATASET_NAME} = "bair" ]; then
+  model_name_to_fname=(
+    [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2
+    [ours_gan]=${DATASET_NAME}_ours_gan
+    [ours_savp]=${DATASET_NAME}_ours_savp
+    [ours_vae]=${DATASET_NAME}_ours_vae_l1
+    [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1
+    [ours_vae_l2]=${DATASET_NAME}_ours_vae_l2
+    [sna_l1]=${DATASET_NAME}_sna_l1
+    [sna_l2]=${DATASET_NAME}_sna_l2
+    [sv2p_time_variant]=${DATASET_NAME}_sv2p_time_variant
+  )
+else
+  echo "Invalid dataset name: '${DATASET_NAME}' (choose from 'bair_action_free', 'kth', 'bair)" >&2
+  exit 1
+fi
+
+if ! [[ ${model_name_to_fname[${MODEL_NAME}]} ]]; then
+  echo "Invalid model name '${MODEL_NAME}' when dataset name is '${DATASET_NAME}'. Valid mode names are:" >&2
+  for model_name in "${!model_name_to_fname[@]}"; do
+    echo "'${model_name}'" >&2
+  done
+  exit 1
+fi
+TARGET_DIR=./pretrained_models/${DATASET_NAME}/${MODEL_NAME}
+mkdir -p ${TARGET_DIR}
+TAR_FNAME=${model_name_to_fname[${MODEL_NAME}]}.tar.gz
+URL=http://rail.eecs.berkeley.edu/models/savp/pretrained_models/${TAR_FNAME}
+echo "Downloading '${TAR_FNAME}'"
+wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME}
+tar -xvf ${TARGET_DIR}/${TAR_FNAME} -C ${TARGET_DIR}
+rm ${TARGET_DIR}/${TAR_FNAME}
+
+echo "Succesfully finished downloading pretrained model '${MODEL_NAME}' on dataset '${DATASET_NAME}' into directory ${TARGET_DIR}"
diff --git a/scripts/Analysis_all.py b/scripts/Analysis_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ed4b6b666634abafbd18af8517cf5b81b75bd61
--- /dev/null
+++ b/scripts/Analysis_all.py
@@ -0,0 +1,105 @@
+import pickle
+import os
+import matplotlib.pyplot as plt
+
+# results_path = ["results_test_samples/era5_size_64_64_3_norm_dup_pretrained/ours_savp","results_test_samples/era5_size_64_64_3_norm_dup_pretrained_finetune/ours_savp",
+#                "results_test_samples/era5_size_64_64_3_norm_dup_pretrained_gan/kth_ours_gan","results_test_samples/era5_size_64_64_3_norm_dup_pretrained_vae_l1/kth_ours_vae_l1"]
+#
+# model_names = ["SAVP","SAVP_Finetune","GAN","VAE"]
+
+
+# results_path = ["results_test_samples/era5_size_64_64_3_norm_dup_pretrained/ours_savp","results_test_samples/era5_size_64_64_3_norm_msl_gph_pretrained_savp/ours_savp",
+#                "results_test_samples/era5_size_64_64_3_norm_dup_pretrained_gan/kth_ours_gan","results_test_samples/era5_size_64_64_3_norm_msl_gph_pretrained_gan/kth_ours_gan"]
+#
+# model_names = ["SAVP_3T","SAVP_T-MSL-GPH","GAN_3T","GAN_T-MSL_GPH"]
+#
+# results_path = ["results_test_samples/era5_size_64_64_3_norm_dup_pretrained/ours_savp","results_test_samples/era5_size_64_64_3_norm_dup/ours_savp",
+#                 "results_test_samples/era5_size_64_64_3_norm_dup_pretrained/kth_ours_gan","results_test_samples/era5_size_64_64_3_norm_dup/ours_gan",
+#                 "results_test_samples/era5_size_64_64_3_norm_dup_pretrained/kth_ours_vae_l1","results_test_samples/era5_size_64_64_3_norm_dup/ours_vae_l1"]
+# model_names = ["TF-SAVP(KTH)","SAVP (3T)","TF-GAN(KTH)","GAN (3T)","TF-VAE (KTH)","VAE (3T)"]
+
+##
+##results_path = ["results_test_samples/era5_size_64_64_3_norm_t_msl_gph/ours_savp", "results_test_samples/era5_size_64_64_3_norm_dup/ours_savp",
+##                "results_test_samples/era5_size_64_64_3_norm_t_msl_gph/ours_gan","results_test_samples/era5_size_64_64_3_norm_dup/ours_gan"]
+##model_names = ["SAVP(T-MSL-GPH)", "SAVP (3T)", "GAN (T-MSL-GPH)","GAN (3T)"]
+
+##results_path = ["results_test_samples/era5_size_64_64_3_norm_t_msl_gph/ours_savp", "results_test_samples/era5_size_64_64_3_norm_dup/ours_savp",
+##                "results_test_samples/era5_size_64_64_3_norm_t_msl_gph/ours_gan","results_test_samples/era5_size_64_64_3_norm_dup/ours_gan"]
+##model_names = ["SAVP(T-MSL-GPH)", "SAVP (3T)", "GAN (T-MSL-GPH)","GAN (3T)"]
+##
+##mse_all = []
+##psnr_all = []
+##ssim_all = []
+##for path in results_path:
+##    p = os.path.join(path,"results.pkl")
+##    result = pickle.load(open(p,"rb"))
+##    mse = result["mse"]
+##    psnr = result["psnr"]
+##    ssim = result["ssim"]
+##    mse_all.append(mse)
+##    psnr_all.append(psnr)
+##    ssim_all.append(ssim)
+##
+##
+##def get_metric(metrtic):
+##    if metric == "mse":
+##        return mse_all
+##    elif metric == "psnr":
+##        return psnr_all
+##    elif metric == "ssim":
+##        return ssim_all
+##    else:
+##        raise("Metric error")
+##
+##for metric in ["mse","psnr","ssim"]:
+##    evals = get_metric(metric)
+##    timestamp = list(range(1,11))
+##    fig = plt.figure()
+##    plt.plot(timestamp, evals[0],'-.',label=model_names[0])
+##    plt.plot(timestamp, evals[1],'--',label=model_names[1])
+##    plt.plot(timestamp, evals[2],'-',label=model_names[2])
+##    plt.plot(timestamp, evals[3],'--.',label=model_names[3])
+##    # plt.plot(timestamp, evals[4],'*-.',label=model_names[4])
+##    # plt.plot(timestamp, evals[5],'--*',label=model_names[5])
+##    if metric == "mse":
+##        plt.legend(loc="upper left")
+##    else:
+##        plt.legend(loc = "upper right")
+##    plt.xlabel("Timestamps")
+##    plt.ylabel(metric)
+##    plt.title(metric,fontsize=15)
+##    plt.savefig(metric + "2.png")
+##    plt.clf()
+
+
+
+# persistence (baseline) analysis
+persistent_mse_all = []
+persistent_psnr_all = []
+persistent_ssim_all = []
+mse_all = []
+psnr_all = []
+ssim_all = []
+results_root_path = "/p/scratch/deepacf/video_prediction_shared_folder/results/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/ours_gan"
+p1 = os.path.join(results_root_path,"results.pkl")
+result1 = pickle.load(open(p1,"rb"))
+p2 = os.path.join(results_root_path,"persistent_results.pkl")
+result2 = pickle.load(open(p2,"rb"))
+mse = result1["mse"]
+psnr = result1["psnr"]
+ssim = result1["ssim"]
+mse_all.append(mse)
+psnr_all.append(psnr)
+ssim_all.append(ssim)
+
+persistent_mse = result2["mse"]
+persistent_psnr = result2["psnr"]
+persistent_ssim = result2["ssim"]
+persistent_mse_all.append(persistent_mse)
+persistent_psnr_all.append(persistent_psnr)
+persistent_ssim_all.append(persistent_ssim)
+
+
+
+print("persistent_mse",persistent_mse_all)
+print("mse",mse_all)
diff --git a/scripts/combine_results.py b/scripts/combine_results.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b684ac1212350c8d0c630884baeb34154dd1fa3
--- /dev/null
+++ b/scripts/combine_results.py
@@ -0,0 +1,258 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import glob
+import itertools
+import os
+
+import cv2
+import numpy as np
+
+from video_prediction.utils import html
+from video_prediction.utils.ffmpeg_gif import save_gif as ffmpeg_save_gif
+
+
+def load_metrics(prefix_fname):
+    import csv
+    with open('%s.csv' % prefix_fname, newline='') as csvfile:
+        reader = csv.reader(csvfile, delimiter='\t', quotechar='|')
+        rows = list(reader)
+        # skip header (first row), indices (first column), and means (last column)
+        metrics = np.array(rows)[1:, 1:-1].astype(np.float32)
+    return metrics
+
+
+def load_images(image_fnames):
+    images = []
+    for image_fname in image_fnames:
+        image = cv2.imread(image_fname)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        images.append(image)
+    return images
+
+
+def save_images(image_fnames, images):
+    head, tail = os.path.split(image_fnames[0])
+    if head and not os.path.exists(head):
+        os.makedirs(head)
+    for image_fname, image in zip(image_fnames, images):
+        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+        cv2.imwrite(image_fname, image)
+
+
+def save_gif(gif_fname, images, fps=4):
+    import moviepy.editor as mpy
+    head, tail = os.path.split(gif_fname)
+    if head and not os.path.exists(head):
+        os.makedirs(head)
+    clip = mpy.ImageSequenceClip(list(images), fps=fps)
+    clip.write_gif(gif_fname)
+
+
+def concat_images(all_images):
+    """
+    all_images is a list of lists of images
+    """
+    min_height, min_width = None, None
+    for all_image in all_images:
+        for image in all_image:
+            if min_height is None or min_width is None:
+                min_height, min_width = image.shape[:2]
+            else:
+                min_height = min(min_height, image.shape[0])
+                min_width = min(min_width, image.shape[1])
+
+    def maybe_resize(image):
+        if image.shape[:2] != (min_height, min_width):
+            # note: cv2.resize expects the target size as (width, height)
+            image = cv2.resize(image, (min_width, min_height))
+        return image
+
+    resized_all_images = []
+    for all_image in all_images:
+        resized_all_image = [maybe_resize(image) for image in all_image]
+        resized_all_images.append(resized_all_image)
+    all_images = resized_all_images
+    all_images = [np.concatenate(all_image, axis=1) for all_image in zip(*all_images)]
+    return all_images
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("results_dir", type=str)
+    parser.add_argument("--method_dirs", type=str, nargs='+', help='directories in results_dir (all of them by default)')
+    parser.add_argument("--method_names", type=str, nargs='+', help='method names for the header')
+    parser.add_argument("--web_dir", type=str, help='default is results_dir/web')
+    parser.add_argument("--sort_by", type=str, nargs=2, help='task and metric name to sort by, e.g. prediction mse')
+    parser.add_argument("--no_ffmpeg", action='store_true')
+    parser.add_argument("--batch_size", type=int, default=1, help="number of samples in batch")
+    parser.add_argument("--num_samples", type=int, help="number of samples for the table of sequence (all of them by default)")
+    parser.add_argument("--show_se", action='store_true', help="show standard error in the table metrics")
+    parser.add_argument("--only_metrics", action='store_true')
+    args = parser.parse_args()
+
+    if args.web_dir is None:
+        args.web_dir = os.path.join(args.results_dir, 'web')
+    webpage = html.HTML(args.web_dir, 'Experiment name = %s' % os.path.normpath(args.results_dir), reflesh=1)
+    webpage.add_header1(os.path.normpath(args.results_dir))
+
+    if args.method_dirs is None:
+        unsorted_method_dirs = os.listdir(args.results_dir)
+        # exclude web_dir and all directories that start with 'web'
+        if args.web_dir in unsorted_method_dirs:
+            unsorted_method_dirs.remove(args.web_dir)
+        unsorted_method_dirs = [method_dir for method_dir in unsorted_method_dirs if not os.path.basename(method_dir).startswith('web')]
+        # put ground_truth and repeat in the front (if any)
+        method_dirs = []
+        for first_method_dir in ['ground_truth', 'repeat']:
+            if first_method_dir in unsorted_method_dirs:
+                unsorted_method_dirs.remove(first_method_dir)
+                method_dirs.append(first_method_dir)
+        method_dirs.extend(sorted(unsorted_method_dirs))
+    else:
+        method_dirs = list(args.method_dirs)
+    if args.method_names is None:
+        method_names = list(method_dirs)
+    else:
+        method_names = list(args.method_names)
+    method_dirs = [os.path.join(args.results_dir, method_dir) for method_dir in method_dirs]
+
+    if args.sort_by:
+        task_name, metric_name = args.sort_by
+        sort_criterion = []
+        for method_id, (method_name, method_dir) in enumerate(zip(method_names, method_dirs)):
+            metric = load_metrics(os.path.join(method_dir, task_name, 'metrics', metric_name))
+            sort_criterion.append(np.mean(metric))
+        sort_criterion, method_ids, method_names, method_dirs = \
+            zip(*sorted(zip(sort_criterion, range(len(method_names)), method_names, method_dirs)))
+        webpage.add_header3('sorted by %s, %s' % tuple(args.sort_by))
+    else:
+        method_ids = range(len(method_names))
+
+    # infer task and metric names from first method
+    metric_fnames = sorted(glob.glob('%s/*/metrics/*.csv' % glob.escape(method_dirs[0])))
+    task_names = []
+    metric_names = []
+    for metric_fname in metric_fnames:
+        head, tail = os.path.split(metric_fname)
+        task_name = head.split('/')[-2]
+        metric_name, _ = os.path.splitext(tail)
+        task_names.append(task_name)
+        metric_names.append(metric_name)
+
+    # save metrics
+    webpage.add_table()
+    header_txts = ['']
+    header_colspans = [2]
+    for task_name in task_names:
+        if task_name != header_txts[-1]:
+            header_txts.append(task_name)
+            header_colspans.append(2 if args.show_se else 1)  # mean and standard error for each task
+        else:
+            # group consecutive task names that are the same
+            header_colspans[-1] += 2 if args.show_se else 1
+    webpage.add_row(header_txts, header_colspans)
+    subheader_txts = ['id', 'method']
+    for task_name, metric_name in zip(task_names, metric_names):
+        subheader_txts.append('%s (mean)' % metric_name)
+        if args.show_se:
+            subheader_txts.append('%s (se)' % metric_name)
+    webpage.add_row(subheader_txts)
+    all_metric_means = []
+    for method_id, method_name, method_dir in zip(method_ids, method_names, method_dirs):
+        metric_txts = [method_id, method_name]
+        metric_means = []
+        for task_name, metric_name in zip(task_names, metric_names):
+            metric = load_metrics(os.path.join(method_dir, task_name, 'metrics', metric_name))
+            metric_mean = np.mean(metric)
+            num_samples = len(metric)
+            metric_se = np.std(metric) / np.sqrt(num_samples)
+            metric_txts.append('%.4f' % metric_mean)
+            if args.show_se:
+                metric_txts.append('%.4f' % metric_se)
+            metric_means.append(metric_mean)
+        webpage.add_row(metric_txts)
+        all_metric_means.append(metric_means)
+    webpage.save()
+
+    if args.only_metrics:
+        return
+
+    # infer task names from first method
+    outputs_dirs = sorted(glob.glob('%s/*/outputs' % glob.escape(method_dirs[0])))
+    task_names = [outputs_dir.split('/')[-2] for outputs_dir in outputs_dirs]
+
+    # save image sequences
+    image_dir = os.path.join(args.web_dir, 'images')
+    webpage.add_table()
+    header_txts = ['']
+    subheader_txts = ['id']
+    methods_subheader_txts = ['']
+    header_colspans = [1]
+    subheader_colspans = [1]
+    methods_subheader_colspans = [1]
+    # args.num_samples overrides num_samples, which otherwise keeps the sample
+    # count inferred while reading the metrics above
+    num_samples = args.num_samples or num_samples
+    for sample_ind in range(num_samples):
+        if sample_ind % args.batch_size == 0:
+            print("saving samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
+        ims = [None]
+        txts = [sample_ind]
+        links = [None]
+        colspans = [1]
+        for task_name in task_names:
+            # load input images from first method
+            input_fnames = sorted(glob.glob('%s/inputs/*_%05d_??.png' %
+                                            (glob.escape(os.path.join(method_dirs[0], task_name)), sample_ind)))
+            input_images = load_images(input_fnames)
+            # save input images as image sequence
+            input_fnames = [os.path.join(task_name, 'inputs', os.path.basename(input_fname)) for input_fname in input_fnames]
+            save_images([os.path.join(image_dir, input_fname) for input_fname in input_fnames], input_images)
+            # infer output names from first method
+            output_fnames = sorted(glob.glob('%s/outputs/*_%05d_??.png' %
+                                             (glob.escape(os.path.join(method_dirs[0], task_name)), sample_ind)))
+            output_names = sorted(set(os.path.splitext(os.path.basename(output_fname))[0][:-9]
+                                      for output_fname in output_fnames))  # strip the trailing _?????_?? (sample and time step indices)
+            # load output images
+            all_output_images = []
+            for output_name in output_names:
+                for method_name, method_dir in zip(method_names, method_dirs):
+                    output_fnames = sorted(glob.glob('%s/outputs/%s_%05d_??.png' %
+                                                     (glob.escape(os.path.join(method_dir, task_name)),
+                                                      output_name, sample_ind)))
+                    output_images = load_images(output_fnames)
+                    all_output_images.append(output_images)
+            # concatenate output images of all the methods
+            all_output_images = concat_images(all_output_images)
+            # save output images as image sequence or as gif clip
+            output_fname = os.path.join(task_name, 'outputs', '%s_%05d.gif' % ('_'.join(output_names), sample_ind))
+            if args.no_ffmpeg:
+                save_gif(os.path.join(image_dir, output_fname), all_output_images, fps=4)
+            else:
+                ffmpeg_save_gif(os.path.join(image_dir, output_fname), all_output_images, fps=4)
+
+            if sample_ind == 0:
+                header_txts.append(task_name)
+                subheader_txts.extend(['inputs', 'outputs'])
+                header_colspans.append(len(input_fnames) + len(method_ids) * len(output_names))
+                subheader_colspans.extend([len(input_fnames), len(method_ids) * len(output_names)])
+                method_id_strs = ['%02d' % method_id for method_id in method_ids]
+                methods_subheader_txts.extend([''] + list(itertools.chain(*[method_id_strs] * len(output_names))))
+                methods_subheader_colspans.extend([len(input_fnames)] + [1] * (len(method_ids) * len(output_names)))
+            ims.extend(input_fnames + [output_fname])
+            txts.extend([None] * (len(input_fnames) + 1))
+            links.extend(input_fnames + [output_fname])
+            colspans.extend([1] * len(input_fnames) + [len(method_ids) * len(output_names)])
+
+        if sample_ind == 0:
+            webpage.add_row(header_txts, header_colspans)
+            webpage.add_row(subheader_txts, subheader_colspans)
+            webpage.add_row(methods_subheader_txts, methods_subheader_colspans)
+        webpage.add_images(ims, txts, links, colspans, height=64, width=None)
+        if (sample_ind + 1) % args.batch_size == 0:
+            webpage.save()
+    webpage.save()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/evaluate.py b/scripts/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f4062f7e3abe4574fa434cfbb2bc66a1b63eed3
--- /dev/null
+++ b/scripts/evaluate.py
@@ -0,0 +1,318 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+import argparse
+import csv
+import errno
+import json
+import os
+import random
+
+import numpy as np
+import tensorflow as tf
+
+from video_prediction import datasets, models
+
+
+def save_image_sequence(prefix_fname, images, time_start_ind=0):
+    import cv2
+    head, tail = os.path.split(prefix_fname)
+    if head and not os.path.exists(head):
+        os.makedirs(head)
+    for t, image in enumerate(images):
+        image_fname = '%s_%02d.png' % (prefix_fname, time_start_ind + t)
+        image = (image * 255.0).astype(np.uint8)
+        if image.shape[-1] == 1:
+            image = np.tile(image, (1, 1, 3))
+        else:
+            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+        cv2.imwrite(image_fname, image)
+
+
+def save_image_sequences(prefix_fname, images, sample_start_ind=0, time_start_ind=0):
+    head, tail = os.path.split(prefix_fname)
+    if head and not os.path.exists(head):
+        os.makedirs(head)
+    for i, images_ in enumerate(images):
+        images_fname = '%s_%05d' % (prefix_fname, sample_start_ind + i)
+        save_image_sequence(images_fname, images_, time_start_ind=time_start_ind)
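+    # Naming sketch (illustrative values): save_image_sequences together with
+    # save_image_sequence writes one PNG per frame named
+    # '<prefix>_<sample:05d>_<time:02d>.png', e.g. 'context_image_00003_01.png'
+    # for sample index 3 at time step 1.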
+
+
+def save_metrics(prefix_fname, metrics, sample_start_ind=0):
+    head, tail = os.path.split(prefix_fname)
+    if head and not os.path.exists(head):
+        os.makedirs(head)
+    assert metrics.ndim == 2
+    file_mode = 'w' if sample_start_ind == 0 else 'a'
+    with open('%s.csv' % prefix_fname, file_mode, newline='') as csvfile:
+        writer = csv.writer(csvfile, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
+        if sample_start_ind == 0:
+            writer.writerow(map(str, ['sample_ind'] + list(range(metrics.shape[1])) + ['mean']))
+        for i, metrics_row in enumerate(metrics):
+            writer.writerow(map(str, [sample_start_ind + i] + list(metrics_row) + [np.mean(metrics_row)]))
+
+
+def load_metrics(prefix_fname):
+    with open('%s.csv' % prefix_fname, newline='') as csvfile:
+        reader = csv.reader(csvfile, delimiter='\t', quotechar='|')
+        rows = list(reader)
+        # skip header (first row), indices (first column), and means (last column)
+        metrics = np.array(rows)[1:, 1:-1].astype(np.float32)
+    return metrics
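+
+# CSV layout sketch (numbers illustrative): save_metrics writes a tab-delimited
+# file with one row per sample,
+#
+#   sample_ind  0       1       ...  mean
+#   0           0.0123  0.0150  ...  0.0141
+#   1           0.0098  0.0132  ...  0.0117
+#
+# and load_metrics drops the header row, the sample_ind column and the mean
+# column, returning an array of shape (num_samples, num_time_steps).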
+
+
+def merge_hparams(hparams0, hparams1):
+    hparams0 = hparams0 or []
+    hparams1 = hparams1 or []
+    if not isinstance(hparams0, (list, tuple)):
+        hparams0 = [hparams0]
+    if not isinstance(hparams1, (list, tuple)):
+        hparams1 = [hparams1]
+    hparams = list(hparams0) + list(hparams1)
+    # unwrap a single-element list into its sole element
+    if len(hparams) == 1:
+        hparams, = hparams
+    return hparams
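+
+# Behavior sketch: merge_hparams('a=1', 'b=2') returns ['a=1', 'b=2'], whereas
+# merge_hparams(None, 'a=1') returns just 'a=1' (a single element is unwrapped).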
+
+
+def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False, subtasks=None):
+    sequence_length = model_hparams.sequence_length
+    context_frames = model_hparams.context_frames
+    future_length = sequence_length - context_frames
+
+    context_images = results['images'][:, :context_frames]
+
+    if 'eval_diversity' in results:
+        metric = results['eval_diversity']
+        metric_name = 'diversity'
+        subtask_dir = task_dir + '_%s' % metric_name
+        save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
+                     metric, sample_start_ind=sample_start_ind)
+
+    subtasks = subtasks or ['max']
+    for subtask in subtasks:
+        metric_names = []
+        for k in results.keys():
+            m = re.match(r'eval_(\w+)/%s' % subtask, k)
+            if m and not re.match(r'eval_gen_images_(\w+)/%s' % subtask, k):
+                metric_names.append(m.group(1))
+        for metric_name in metric_names:
+            subtask_dir = task_dir + '_%s_%s' % (metric_name, subtask)
+            gen_images = results.get('eval_gen_images_%s/%s' % (metric_name, subtask), results.get('eval_gen_images'))
+            # only keep the future frames
+            gen_images = gen_images[:, -future_length:]
+            metric = results['eval_%s/%s' % (metric_name, subtask)]
+            save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
+                         metric, sample_start_ind=sample_start_ind)
+            if only_metrics:
+                continue
+
+            save_image_sequences(os.path.join(subtask_dir, 'inputs', 'context_image'),
+                                 context_images, sample_start_ind=sample_start_ind)
+            save_image_sequences(os.path.join(subtask_dir, 'outputs', 'gen_image'),
+                                 gen_images, sample_start_ind=sample_start_ind)
+
+
+def main():
+    """
+    results_dir
+    ├── output_dir                              # condition / method
+    │   ├── prediction_eval_lpips_max           # task: best sample in terms of LPIPS similarity
+    │   │   ├── inputs
+    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
+    │   │   │   └── ...
+    │   │   ├── outputs
+    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the future ones)
+    │   │   │   └── ...
+    │   │   └── metrics
+    │   │       └── lpips.csv
+    │   ├── prediction_eval_ssim_max            # task: best sample in terms of SSIM
+    │   │   ├── inputs
+    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
+    │   │   │   └── ...
+    │   │   ├── outputs
+    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the future ones)
+    │   │   │   └── ...
+    │   │   └── metrics
+    │   │       └── ssim.csv
+    │   └── ...
+    └── ...
+    """
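+    # For example (paths assumed, matching the layout above), the per-frame LPIPS
+    # values of one method could be read back with
+    #   load_metrics(os.path.join('results_dir', 'output_dir',
+    #                             'prediction_eval_lpips_max', 'metrics', 'lpips'))
+    # which returns an array of shape (num_samples, future_length).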
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_dir is specified")
+    parser.add_argument("--output_dir", help="output directory where results are saved. default is results_dir/model_fname, "
+                                             "where model_fname is the directory name of checkpoint")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+
+    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')
+
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
+
+    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
+    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
+    parser.add_argument("--num_epochs", type=int, default=1)
+
+    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'avg', 'min'], help='subtasks to evaluate (e.g. max, avg, min)')
+    parser.add_argument("--only_metrics", action='store_true')
+    parser.add_argument("--num_stochastic_samples", type=int, default=100)
+
+    parser.add_argument("--gt_inputs_dir", type=str, help="directory containing input ground truth images for the simple dataset")
+    parser.add_argument("--gt_outputs_dir", type=str, help="directory containing output ground truth images for the simple dataset")
+
+    parser.add_argument("--eval_parallel_iterations", type=int, default=10)
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", type=int, default=7)
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        tf.set_random_seed(args.seed)
+        np.random.seed(args.seed)
+        random.seed(args.seed)
+
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+    if args.checkpoint:
+        checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.isdir(args.checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            args.dataset = args.dataset or options['dataset']
+            args.model = args.model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+        args.output_dir = args.output_dir or os.path.join(args.results_dir, os.path.split(checkpoint_dir)[1])
+    else:
+        if not args.dataset:
+            raise ValueError('dataset is required when checkpoint is not specified')
+        if not args.model:
+            raise ValueError('model is required when checkpoint is not specified')
+        args.output_dir = args.output_dir or os.path.join(args.results_dir, 'model.%s' % args.model)
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    VideoDataset = datasets.get_dataset_class(args.dataset)
+    dataset = VideoDataset(
+        args.input_dir,
+        mode=args.mode,
+        num_epochs=args.num_epochs,
+        seed=args.seed,
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
+
+    VideoPredictionModel = models.get_model_class(args.model)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        mode=args.mode,
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams,
+        eval_num_samples=args.num_stochastic_samples,
+        eval_parallel_iterations=args.eval_parallel_iterations)
+
+    if args.num_samples:
+        if args.num_samples > dataset.num_examples_per_epoch():
+            raise ValueError('num_samples cannot be larger than the dataset')
+        num_examples_per_epoch = args.num_samples
+    else:
+        num_examples_per_epoch = dataset.num_examples_per_epoch()
+    if num_examples_per_epoch % args.batch_size != 0:
+        # Bing: the original check below is disabled so that the dataset size does
+        # not need to be an exact multiple of batch_size
+        #raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+        pass
+    # Bing: for ERA5 data we use dataset.make_batch_v2 instead of dataset.make_batch
+    inputs = dataset.make_batch_v2(args.batch_size)
+    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
+    with tf.variable_scope(''):
+        model.build_graph(input_phs)
+
+    output_dir = args.output_dir
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
+    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
+    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+    sess = tf.Session(config=config)
+    sess.graph.as_default()
+
+    model.restore(sess, args.checkpoint)
+
+    sample_ind = 0
+    while True:
+        if args.num_samples and sample_ind >= args.num_samples:
+            break
+        try:
+            input_results = sess.run(inputs)
+        except tf.errors.OutOfRangeError:
+            break
+        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
+
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        # compute "best" metrics using the computation graph
+        fetches = {'images': model.inputs['images']}
+        fetches.update(model.eval_outputs.items())
+        fetches.update(model.eval_metrics.items())
+        results = sess.run(fetches, feed_dict=feed_dict)
+        save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'),
+                                     results, model.hparams, sample_ind, args.only_metrics, args.eval_substasks)
+        sample_ind += args.batch_size
+
+    metric_fnames = []
+    metric_names = ['psnr', 'ssim', 'lpips']
+    subtasks = ['max']
+    for metric_name in metric_names:
+        for subtask in subtasks:
+            metric_fnames.append(
+                os.path.join(output_dir, 'prediction_eval_%s_%s' % (metric_name, subtask), 'metrics', metric_name))
+
+    for metric_fname in metric_fnames:
+        task_name, _, metric_name = metric_fname.split('/')[-3:]
+        metric = load_metrics(metric_fname)
+        print('=' * 31)
+        print(task_name, metric_name)
+        print('-' * 31)
+        metric_header_format = '{:>10} {:>20}'
+        metric_row_format = '{:>10} {:>10.4f} ({:>7.4f})'
+        print(metric_header_format.format('time step', os.path.split(metric_fname)[1]))
+        for t, (metric_mean, metric_std) in enumerate(zip(metric.mean(axis=0), metric.std(axis=0))):
+            print(metric_row_format.format(t, metric_mean, metric_std))
+        print(metric_row_format.format('mean (std)', metric.mean(), metric.std()))
+        print('=' * 31)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/evaluate_all.sh b/scripts/evaluate_all.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c57f5c895da22f167a0a1bb4204a225965c8a48c
--- /dev/null
+++ b/scripts/evaluate_all.sh
@@ -0,0 +1,44 @@
+# BAIR action-free robot pushing dataset
+dataset=bair_action_free
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sv2p_time_invariant \
+; do
+   CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/bair --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 8
+done
+
+# KTH human actions dataset
+# use batch_size=1 to ensure reproducibility when sampling subclips within a sequence
+dataset=kth
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sv2p_time_variant \
+    sv2p_time_invariant \
+; do
+    CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/kth --dataset_hparams sequence_length=40 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 1
+done
+
+# BAIR action-conditioned robot pushing dataset
+dataset=bair
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sna_l1 \
+    sna_l2 \
+    sv2p_time_variant \
+; do
+    CUDA_VISIBLE_DEVICES=1 python scripts/evaluate.py --input_dir data/bair --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 8
+done
diff --git a/scripts/evaluate_svg.sh b/scripts/evaluate_svg.sh
new file mode 100644
index 0000000000000000000000000000000000000000..212c4ba239ecdbd15c70c05b9336c32175dc8c5c
--- /dev/null
+++ b/scripts/evaluate_svg.sh
@@ -0,0 +1 @@
+#!/usr/bin/env bash
\ No newline at end of file
diff --git a/scripts/generate.py b/scripts/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1893194fe600981350a52e9880df9f6a034701d
--- /dev/null
+++ b/scripts/generate.py
@@ -0,0 +1,537 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import math
+import random
+import cv2
+import numpy as np
+import tensorflow as tf
+import seaborn as sns
+import pickle
+from random import seed
+#from six.moves import cPickle
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import matplotlib.animation as animation
+import pandas as pd
+from video_prediction import datasets, models
+from matplotlib.colors import LinearSegmentedColormap
+from matplotlib.ticker import MaxNLocator
+from video_prediction.utils.ffmpeg_gif import save_gif
+
+with open("./splits_size_64_64_1/geo_info.json","r") as json_file:
+    geo = json.load(json_file)
+    lat = [round(i,2) for i in geo["lat"]]
+    lon = [round(i,2) for i in geo["lon"]]
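+# Note: geo_info.json is assumed to hold the grid coordinates as two flat lists,
+# e.g. (illustrative) {"lat": [45.0, 45.1, ...], "lon": [0.0, 0.1, ...]}.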
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified")
+    parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified")
+    parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified")
+    parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is "
+                                                 "results_gif_dir/model_fname")
+    parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is "
+                                                 "results_png_dir/model_fname")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+
+    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')
+
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
+
+    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
+    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
+    parser.add_argument("--num_epochs", type=int, default=1)
+
+    parser.add_argument("--num_stochastic_samples", type=int, default=1)  # Bing: original default was 5, changed to 1
+    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
+    parser.add_argument("--fps", type=int, default=4)
+
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", type=int, default=7)
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        tf.set_random_seed(args.seed)
+        np.random.seed(args.seed)
+        random.seed(args.seed)
+
+    args.results_gif_dir = args.results_gif_dir or args.results_dir
+    args.results_png_dir = args.results_png_dir or args.results_dir
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+    if args.checkpoint:
+        checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.isdir(args.checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            args.dataset = args.dataset or options['dataset']
+            args.model = args.model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1])
+        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1])
+    else:
+        if not args.dataset:
+            raise ValueError('dataset is required when checkpoint is not specified')
+        if not args.model:
+            raise ValueError('model is required when checkpoint is not specified')
+        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
+        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    VideoDataset = datasets.get_dataset_class(args.dataset)
+    dataset = VideoDataset(
+        args.input_dir,
+        mode=args.mode,
+        num_epochs=args.num_epochs,
+        seed=args.seed,
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
+    VideoPredictionModel = models.get_model_class(args.model)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        mode=args.mode,
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams)
+
+    sequence_length = model.hparams.sequence_length
+    context_frames = model.hparams.context_frames
+    future_length = sequence_length - context_frames
+
+    if args.num_samples:
+        if args.num_samples > dataset.num_examples_per_epoch():
+            raise ValueError('num_samples cannot be larger than the dataset')
+        num_examples_per_epoch = args.num_samples
+    else:
+        # Bing: dataset.num_examples_per_epoch() raises an error here, so as a
+        # workaround the epoch size is approximated as batch_size * 8
+        #num_examples_per_epoch = dataset.num_examples_per_epoch()
+        num_examples_per_epoch = args.batch_size * 8
+    if num_examples_per_epoch % args.batch_size != 0:
+        # Bing: the original check below is disabled so that the dataset size does
+        # not need to be an exact multiple of batch_size
+        #raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+        pass
+    # Bing: for ERA5 data, dataset.make_batch_v2 would be used instead of dataset.make_batch
+    inputs = dataset.make_batch(args.batch_size)
+    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
+    with tf.variable_scope(''):
+        model.build_graph(input_phs)
+
+    for output_dir in (args.output_gif_dir, args.output_png_dir):
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        with open(os.path.join(output_dir, "options.json"), "w") as f:
+            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
+        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
+    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+    sess = tf.Session(config=config)
+    sess.graph.as_default()
+    model.restore(sess, args.checkpoint)
+    sample_ind = 0
+    gen_images_all = []
+    input_images_all = []
+
+    while True:
+        if args.num_samples and sample_ind >= args.num_samples:
+            break
+        try:
+            input_results = sess.run(inputs)
+        except tf.errors.OutOfRangeError:
+            break
+        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        for stochastic_sample_ind in range(args.num_stochastic_samples):  # TODO: clarify why multiple stochastic samples are needed here
+            print("Stochastic sample id", stochastic_sample_ind)
+            gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
+            #input_images = sess.run(inputs["images"])
+            #Bing: Add evaluation metrics
+            # fetches = {'images': model.inputs['images']}
+            # fetches.update(model.eval_outputs.items())
+            # fetches.update(model.eval_metrics.items())
+            # results = sess.run(fetches, feed_dict = feed_dict)
+            # input_images = results["images"] #shape (batch_size,future_frames,height,width,channel)
+            # only keep the future frames
+            #gen_images = gen_images[:, -future_length:] #(8,10,64,64,1) (batch_size, sequences, height, width, channel)
+            #input_images = input_results["images"][:,-future_length:,:,:]
+            input_images = input_results["images"][:,1:,:,:,:]
+            #gen_mse_avg = results["eval_mse/avg"] #shape (batch_size,future_frames)
+            print("Finished stochastic sample", stochastic_sample_ind)
+            input_gen_diff_ = input_images - gen_images
+            #diff_image_range = pd.cut(input_gen_diff_.flatten(), bins = 4, labels = [-10, -5, 0, 5], right = False)
+            #diff_image_range = np.reshape(np.array(diff_image_range),input_gen_diff_.shape)
+            gen_images_all.extend(gen_images)
+            input_images_all.extend(input_images)
+
+            colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
+            cmap_name = 'my_list'
+            if sample_ind < 100:
+                for i in range(len(gen_images)):
+                    name = 'Batch_id_' + str(sample_ind) + " + Sample_" + str(i)
+                    gen_images_ = gen_images[i, :]
+                    gen_mse_avg_ = [np.mean(input_gen_diff_[i, frame, :, :, :]**2) for frame in
+                                    range(19)]  # per-frame MSE over the 19 time steps
+                    input_gen_diff = input_gen_diff_ [i,:,:,:,:]
+                    input_images_ = input_images[i, :]
+                    #gen_mse_avg_ = gen_mse_avg[i, :]
+
+                    # Bing: this checks the difference between consecutive generated images to debug the frozen-frame issue
+                    # gen_images_diff = []
+                    # for gen_idx in range(len(gen_images_) - 1):
+                    #     img_1 = gen_images_[gen_idx, :, :, :]
+                    #     img_2 = gen_images_[gen_idx + 1, :, :, :]
+                    #     img_diff = img_2 - img_1
+                    #     img_diff_nonzero = [e for img_idx, e in enumerate(img_diff.flatten()) if round(e,3) != 0.000]
+                    #     gen_images_diff.append(img_diff_nonzero)
+
+                    fig = plt.figure()
+                    gs = gridspec.GridSpec(4,6)
+                    gs.update(wspace = 0.7,hspace=0.8)
+                    ax1 = plt.subplot(gs[0:2,0:3])
+                    ax2 = plt.subplot(gs[0:2,3:],sharey=ax1)
+                    ax3 = plt.subplot(gs[2:4,0:3])
+                    ax4 = plt.subplot(gs[2:4,3:])
+                    xlabels = [round(i, 2) for i in list(np.linspace(np.min(lon), np.max(lon), 5))]
+                    ylabels = [round(i, 2) for i in list(np.linspace(np.max(lat), np.min(lat), 5))]
+                    plt.setp([ax1, ax2, ax3], xticks=list(np.linspace(0, 64, 5)), xticklabels=xlabels, yticks=list(np.linspace(0, 64, 5)), yticklabels=ylabels)
+                    ax1.title.set_text("(a) Ground Truth")
+                    ax2.title.set_text("(b) SAVP")
+                    ax3.title.set_text("(c) Diff.")
+                    ax4.title.set_text("(d) MSE")
+
+                    ax1.xaxis.set_tick_params(labelsize=7)
+                    ax1.yaxis.set_tick_params(labelsize = 7)
+                    ax2.xaxis.set_tick_params(labelsize=7)
+                    ax2.yaxis.set_tick_params(labelsize = 7)
+                    ax3.xaxis.set_tick_params(labelsize=7)
+                    ax3.yaxis.set_tick_params(labelsize = 7)
+
+                    init_images = np.zeros((input_images_.shape[1], input_images_.shape[2]))
+                    print("init images shape", init_images.shape)
+                    xdata, ydata = [], []
+                    plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+                    plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+                    #x = np.linspace(0, 64, 64)
+                    #y = np.linspace(0, 64, 64)
+                    #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+                    #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+                    fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7)
+                    fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7)
+
+                    # "bwr" would be iterated character-by-character, so list the colors explicitly
+                    cm = LinearSegmentedColormap.from_list(cmap_name, ["b", "w", "r"], N=5)
+
+                    plot3 = ax3.imshow(init_images, vmin=-10, vmax=10, cmap=cm)#cmap = 'PuBu_r',
+                    plot4, = ax4.plot([], [], color = "r")
+                    ax4.set_xlim(0, len(gen_mse_avg_)-1)
+                    ax4.set_ylim(0, 10)
+                    ax4.set_xlabel("Frames", fontsize=10)
+                    #ax4.set_ylabel("MSE", fontsize=10)
+                    ax4.xaxis.set_tick_params(labelsize=7)
+                    ax4.yaxis.set_tick_params(labelsize=7)
+
+
+                    plots = [plot1, plot2, plot3, plot4]
+
+                    #fig.colorbar(plots[1], ax = [ax1, ax2])
+
+                    fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7)
+                    #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7)
+                    #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7)
+
+                    def animation_sample(t):
+                        input_image = input_images_[t, :, :, 0]
+                        gen_image = gen_images_[t, :, :, 0]
+                        diff_image = input_gen_diff[t,:,:,0]
+
+                        data = gen_mse_avg_[:t + 1]
+                        # x = list(range(len(gen_mse_avg_)))[:t+1]
+                        xdata.append(t)
+                        print("xdata", xdata)
+                        ydata.append(gen_mse_avg_[t])
+
+                        print("ydata", ydata)
+                        # p = sns.lineplot(x=x,y=data,color="b")
+                        # p.tick_params(labelsize=17)
+                        # plt.setp(p.lines, linewidth=6)
+                        plots[0].set_data(input_image)
+                        plots[1].set_data(gen_image)
+                        #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+                        #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+                        plots[2].set_data(diff_image)
+                        plots[3].set_data(xdata, ydata)
+                        fig.suptitle("Frame " + str(t+1))
+
+                        return plots
+
+                    ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000,
+                                                  repeat_delay = 2000)
+                    ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
+
+            else:
+                pass
+
+    #         # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     plt.xlim(0,len(gen_mse_avg_))
+    #         #     plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg))
+    #         #     plt.xlabel("Frames")
+    #         #     plt.ylabel("MSE_AVG")
+    #         #     #X = list(range(len(gen_mse_avg_)))
+    #         #     #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
+    #         #     def animate_metric(j):
+    #         #         data = gen_mse_avg_[:(j+1)]
+    #         #         x = list(range(len(gen_mse_avg_)))[:(j+1)]
+    #         #         p = sns.lineplot(x=x,y=data,color="b")
+    #         #         p.tick_params(labelsize=17)
+    #         #         plt.setp(p.lines, linewidth=6)
+    #         #     ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000)
+    #         #     ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif"))
+    #         #
+    #         #
+    #         # for i, input_images_ in enumerate(input_images):
+    #         #     #context_images_ = (input_results['images'][i])
+    #         #     #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     for t, input_image in enumerate(input_images_):
+    #         #         im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')
+    #         #         ttl = plt.text(1.5, 2,"Frame_" + str(t))
+    #         #         ims.append([im,ttl])
+    #         #     ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000)
+    #         #     ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif"))
+    #         #     #plt.show()
+    #         #
+    #         # for i,gen_images_ in enumerate(gen_images):
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     for t, gen_image in enumerate(gen_images_):
+    #         #         im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
+    #         #         ttl = plt.text(1.5, 2, "Frame_" + str(t))
+    #         #         ims.append([im, ttl])
+    #         #     ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    #         #     ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif"))
+    #
+    #
+    #             # for i, gen_images_ in enumerate(gen_images):
+    #             #     #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
+    #             #     #gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
+    #             #     #bing
+    #             #     context_images_ = (input_results['images'][i])
+    #             #     gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+    #             #     context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
+    #             #     plt.figure(figsize = (10,2))
+    #             #     gs = gridspec.GridSpec(2,10)
+    #             #     gs.update(wspace=0.,hspace=0.)
+    #             #     for t, gen_image in enumerate(gen_images_):
+    #             #         gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1)))
+    #             #         gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
+    #             #         plt.subplot(gs[t])
+    #             #         plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')  # the last index sets the channel. 0 = t2
+    #             #         # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
+    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
+    #             #                         right = False, labelbottom = False, labelleft = False)
+    #             #         if t == 0: plt.ylabel('Actual', fontsize = 10)
+    #             #
+    #             #         plt.subplot(gs[t + 10])
+    #             #         plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
+    #             #         # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
+    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
+    #             #                         right = False, labelbottom = False, labelleft = False)
+    #             #         if t == 0: plt.ylabel('Predicted', fontsize = 10)
+    #             #     plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png')
+    #             #     plt.clf()
+    #
+    #             # if args.gif_length:
+    #             #     context_and_gen_images = context_and_gen_images[:args.gif_length]
+    #             # save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
+    #             #          context_and_gen_images, fps=args.fps)
+    #             #
+    #             # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
+    #             # for t, gen_image in enumerate(gen_images_):
+    #             #     gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
+    #             #     if gen_image.shape[-1] == 1:
+    #             #       gen_image = np.tile(gen_image, (1, 1, 3))
+    #             #     else:
+    #             #       gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
+    #             #     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
+
+        sample_ind += args.batch_size
+
+
+    with open(os.path.join(args.output_png_dir, "input_images_all"),"wb") as input_files:
+        pickle.dump(input_images_all,input_files)
+
+    with open(os.path.join(args.output_png_dir, "gen_images_all"),"wb") as gen_files:
+        pickle.dump(gen_images_all,gen_files)
+
+    with open(os.path.join(args.output_png_dir, "input_images_all"),"rb") as input_files:
+        input_images_all = pickle.load(input_files)
+
+    with open(os.path.join(args.output_png_dir, "gen_images_all"),"rb") as gen_files:
+        gen_images_all=pickle.load(gen_files)
+    ims = []
+    fig = plt.figure()
+    for frame in range(19):
+        input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :, :, 0]  # per-pixel mean squared error over all samples for this frame
+        #pix_mean = np.mean(input_gen_diff, axis = 0)
+        #pix_std = np.std(input_gen_diff, axis=0)
+        im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu')
+        if frame == 0:
+            fig.colorbar(im)
+        ttl = plt.text(1.5, 2, "Frame_" + str(frame +1))
+        ims.append([im, ttl])
+    ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4"))
+    plt.close("all")
+
+    ims = []
+    fig = plt.figure()
+    for frame in range(19):
+        pix_std = np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :, :, 0]  # per-pixel std of the squared error over all samples for this frame
+        #pix_mean = np.mean(input_gen_diff, axis = 0)
+        #pix_std = np.std(input_gen_diff, axis=0)
+        im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu')
+        if frame == 0:
+            fig.colorbar(im)
+        ttl = plt.text(1.5, 2, "Frame_" + str(frame+1))
+        ims.append([im, ttl])
+    ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4"))
+
+    gen_images_all = np.array(gen_images_all)
+    input_images_all = np.array(input_images_all)
+    # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
+    # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)
+    # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 )
+
+    mse_model = np.mean((input_images_all[:, :10, :, :, 0] - gen_images_all[:, :10, :, :, 0])**2)  # MSE over the first 10 time steps
+    mse_model_last = np.mean((input_images_all[:, 10, :, :, 0] - gen_images_all[:, 10, :, :, 0])**2)  # MSE at time step 10 only
+    mse_prev = np.mean((input_images_all[:, :9, :, :, 0] - input_images_all[:, 1:10, :, :, 0])**2)  # persistence baseline: previous frame used as the prediction
+
+    def psnr(img1, img2):
+        mse = np.mean((img1 - img2) ** 2)
+        if mse == 0: return 100
+        PIXEL_MAX = 1
+        return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
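+    # Worked example: with images scaled to [0, 1] (PIXEL_MAX = 1) and an MSE of
+    # 0.01, psnr returns 20 * log10(1 / sqrt(0.01)) = 20 * log10(10) = 20 dB.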
+
+    psnr_model = psnr(input_images_all[:, :10, :, :, 0],  gen_images_all[:, :10, :, :, 0])
+    psnr_model_last = psnr(input_images_all[:, 10, :, :, 0],  gen_images_all[:,10, :, :, 0])
+    psnr_prev = psnr(input_images_all[:, :9, :, :, 0],  input_images_all[:, 1:10, :, :, 0])
+    f = open(os.path.join(args.output_png_dir,'prediction_scores_4prediction.txt'), 'w')
+    f.write("Model MSE: %f\n" % mse_model)
+    f.write("Model MSE from only last prediction in sequence: %f\n" % mse_model_last)
+    f.write("Previous Frame MSE: %f\n" % mse_prev)
+    f.write("Model PSNR: %f\n" % psnr_model)
+    f.write("Model PSNR from only last prediction in sequence: %f\n" % psnr_model_last)
+    f.write("Previous frame PSNR: %f\n" % psnr_prev)
+    f.write("Shape of X_test: " + str(input_images_all.shape) + "\n")
+    f.write("Shape of X_hat: " + str(gen_images_all.shape) + "\n")
+    f.close()
+
+    seed(1)
+    s = random.sample(range(len(gen_images_all)), 100)
+    print("******KDP******")
+    # kernel density plot to check for model collapse
+    fig = plt.figure()
+    kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label="Generated Images")
+    kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label="Ground Truth")
+    kdp.set(xlabel='Temperature (K)', ylabel='Probability')
+    plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400)
+    plt.clf()
+
+    # scatter plots comparing predictions against ground truth for selected frames
+    for i in [0,3,6,9,12,15,18]:
+        fig = plt.figure()
+        plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3)
+        #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3)
+        plt.xlabel("Prediction")
+        plt.ylabel("Real values")
+        plt.title("Frame_{}".format(i+1))
+        plt.plot([250,300], [250,300],color="black")
+        plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i))))
+        plt.clf()
+
+
+    mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence)
+    x = [str(i+1) for i in list(range(19))]
+    fig,axis = plt.subplots()
+    mean_f = np.mean(mse_model_by_frames, axis = 0)
+    median = np.median(mse_model_by_frames, axis=0)
+    q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0)
+    q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0)
+    d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0)
+    d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0)
+    plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range")
+    plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range")
+    plt.plot(x, median, color="grey", linewidth=0.6, label="Median")
+    plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean")
+    plt.title('MSE percentiles')
+    plt.xlabel("Frames")
+    plt.legend(loc=2, fontsize=8)
+    plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png"))
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/generate_all.sh b/scripts/generate_all.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c3736b36df1840cbc46b89526cef0c908b500760
--- /dev/null
+++ b/scripts/generate_all.sh
@@ -0,0 +1,55 @@
+# BAIR action-free robot pushing dataset
+dataset=bair_action_free
+CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair --dataset bair \
+    --dataset_hparams sequence_length=30 --model ground_truth --mode test \
+    --output_gif_dir results_test_2afc/${dataset}/ground_truth \
+    --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    sv2p_time_invariant \
+; do
+    CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \
+        --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test \
+        --results_gif_dir results_test_2afc/${dataset} \
+        --results_png_dir results_test_samples/${dataset} --gif_length 10
+done
+
+# KTH human actions dataset
+# use batch_size=1 to ensure reproducibility when sampling subclips within a sequence
+dataset=kth
+CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/kth --dataset kth \
+    --dataset_hparams sequence_length=40 --model ground_truth --mode test \
+    --output_gif_dir results_test_2afc/${dataset}/ground_truth \
+    --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10 --batch_size 1
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    sv2p_time_invariant \
+    sv2p_time_variant \
+; do
+    CUDA_VISIBLE_DEVICES=1 python scripts/generate.py --input_dir data/kth \
+        --dataset_hparams sequence_length=40 --checkpoint models/${dataset}/${method_dir} --mode test \
+        --results_gif_dir results_test_2afc/${dataset} \
+        --results_png_dir results_test_samples/${dataset} --gif_length 10 --batch_size 1
+done
+
+# BAIR action-conditioned robot pushing dataset
+dataset=bair
+CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair --dataset bair \
+    --dataset_hparams sequence_length=30 --model ground_truth --mode test \
+    --output_gif_dir results_test_2afc/${dataset}/ground_truth \
+    --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    sv2p_time_variant \
+; do
+    CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \
+        --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test \
+        --results_gif_dir results_test_2afc/${dataset} \
+        --results_png_dir results_test_samples/${dataset} --gif_length 10
+done
diff --git a/scripts/generate_anomaly.py b/scripts/generate_anomaly.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c8e555c2e223c51bfb555c7dd0fa0eb089610d8
--- /dev/null
+++ b/scripts/generate_anomaly.py
@@ -0,0 +1,518 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import math
+import random
+import cv2
+import numpy as np
+import tensorflow as tf
+import seaborn as sns
+import pickle
+from random import seed
+#from six.moves import cPickle
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import matplotlib.animation as animation
+import pandas as pd
+from video_prediction import datasets, models
+from matplotlib.colors import LinearSegmentedColormap
+from matplotlib.ticker import MaxNLocator
+from video_prediction.utils.ffmpeg_gif import save_gif
+
+with open("./splits_size_64_64_1/geo_info.json","r") as json_file:
+    geo = json.load(json_file)
+    lat = [round(i,2) for i in geo["lat"]]
+    lon = [round(i,2) for i in geo["lon"]]
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified")
+    parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified")
+    parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified")
+    parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is "
+                                                 "results_gif_dir/model_fname")
+    parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is "
+                                                 "results_png_dir/model_fname")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+
+    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')
+
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
+
+    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
+    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
+    parser.add_argument("--num_epochs", type=int, default=1)
+
+    parser.add_argument("--num_stochastic_samples", type=int, default=1) #Bing original is 5, change to 1
+    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
+    parser.add_argument("--fps", type=int, default=4)
+
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", type=int, default=7)
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        tf.set_random_seed(args.seed)
+        np.random.seed(args.seed)
+        random.seed(args.seed)
+
+    args.results_gif_dir = args.results_gif_dir or args.results_dir
+    args.results_png_dir = args.results_png_dir or args.results_dir
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+    if args.checkpoint:
+        checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.isdir(args.checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            args.dataset = args.dataset or options['dataset']
+            args.model = args.model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1])
+        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1])
+    else:
+        if not args.dataset:
+            raise ValueError('dataset is required when checkpoint is not specified')
+        if not args.model:
+            raise ValueError('model is required when checkpoint is not specified')
+        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
+        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    VideoDataset = datasets.get_dataset_class(args.dataset)
+    dataset = VideoDataset(
+        args.input_dir,
+        mode=args.mode,
+        num_epochs=args.num_epochs,
+        seed=args.seed,
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
+    VideoPredictionModel = models.get_model_class(args.model)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        mode=args.mode,
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams)
+
+    sequence_length = model.hparams.sequence_length
+    context_frames = model.hparams.context_frames
+    future_length = sequence_length - context_frames
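+    # Illustrative example (an assumption, not enforced here): with sequence_length=20
+    # and context_frames=1, future_length is 19, which is what the hard-coded loops
+    # over range(19) further below assume.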
+
+    if args.num_samples:
+        if args.num_samples > dataset.num_examples_per_epoch():
+            raise ValueError('num_samples cannot be larger than the dataset')
+        num_examples_per_epoch = args.num_samples
+    else:
+        # Bing: dataset.num_examples_per_epoch() fails for this dataset, so fall back
+        # to a fixed number of batches as a workaround.
+        #num_examples_per_epoch = dataset.num_examples_per_epoch()
+        num_examples_per_epoch = args.batch_size * 8
+    if num_examples_per_epoch % args.batch_size != 0:
+        # Bing: the divisibility check is disabled because of the workaround above.
+        #raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+        pass
+    # Bing: for ERA5 data use dataset.make_batch_v2, which also returns the mean fields
+    #inputs = dataset.make_batch(args.batch_size)
+    inputs, inputs_mean = dataset.make_batch_v2(args.batch_size)
+    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
+    with tf.variable_scope(''):
+        model.build_graph(input_phs)
+
+    for output_dir in (args.output_gif_dir, args.output_png_dir):
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        with open(os.path.join(output_dir, "options.json"), "w") as f:
+            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
+        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
+    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+    sess = tf.Session(config=config)
+    sess.graph.as_default()
+    model.restore(sess, args.checkpoint)
+    sample_ind = 0
+    gen_images_all = []
+    input_images_all = []
+
+    while True:
+        if args.num_samples and sample_ind >= args.num_samples:
+            break
+        try:
+            input_results = sess.run(inputs)
+            input_mean_results = sess.run(inputs_mean)
+            input_final = input_results["images"] + input_mean_results["images"]
+
+        except tf.errors.OutOfRangeError:
+            break
+        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        for stochastic_sample_ind in range(args.num_stochastic_samples):
+            print("Stochastic sample id", stochastic_sample_ind)
+            gen_anomaly = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
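+            # The network predicts anomalies (deviations from the per-sample mean),
+            # so the mean of the corresponding frames is added back to recover
+            # absolute values, mirroring input_final above.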
+            gen_images = gen_anomaly + input_mean_results["images"][:,1:,:,:]
+            #input_images = sess.run(inputs["images"])
+            #Bing: Add evaluation metrics
+            # fetches = {'images': model.inputs['images']}
+            # fetches.update(model.eval_outputs.items())
+            # fetches.update(model.eval_metrics.items())
+            # results = sess.run(fetches, feed_dict = feed_dict)
+            # input_images = results["images"] #shape (batch_size,future_frames,height,width,channel)
+            # only keep the future frames
+            #gen_images = gen_images[:, -future_length:] #(8,10,64,64,1) (batch_size, sequences, height, width, channel)
+            #input_images = input_results["images"][:,-future_length:,:,:]
+            #input_images = input_results["images"][:,1:,:,:,:]
+            input_images = input_final[:,1:,:,:,:]
+            #gen_mse_avg = results["eval_mse/avg"] #shape (batch_size,future_frames)
+            print("Finish sample ind",stochastic_sample_ind)
+            input_gen_diff_ = input_images - gen_images
+            #diff_image_range = pd.cut(input_gen_diff_.flatten(), bins = 4, labels = [-10, -5, 0, 5], right = False)
+            #diff_image_range = np.reshape(np.array(diff_image_range),input_gen_diff_.shape)
+            gen_images_all.extend(gen_images)
+            input_images_all.extend(input_images)
+
+            colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
+            cmap_name = 'my_list'
+            if sample_ind < 100:
+                for i in range(len(gen_images)):
+                    name = 'Batch_id_' + str(sample_ind) + " + Sample_" + str(i)
+                    gen_images_ = gen_images[i, :]
+                    gen_mse_avg_ = [np.mean(input_gen_diff_[i, frame, :, :, :]**2) for frame in
+                                    range(19)]  # per-frame MSE over the 19 predicted frames
+
+                    input_gen_diff = input_gen_diff_[i,:,:,:,:]
+                    input_images_ = input_images[i, :]
+                    #gen_mse_avg_ = gen_mse_avg[i, :]
+                    fig = plt.figure()
+                    gs = gridspec.GridSpec(4,6)
+                    gs.update(wspace = 0.7,hspace=0.8)
+                    ax1 = plt.subplot(gs[0:2,0:3])
+                    ax2 = plt.subplot(gs[0:2,3:],sharey=ax1)
+                    ax3 = plt.subplot(gs[2:4,0:3])
+                    ax4 = plt.subplot(gs[2:4,3:])
+                    xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
+                    ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
+                    plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels)
+                    ax1.title.set_text("(a) Ground Truth")
+                    ax2.title.set_text("(b) SAVP")
+                    ax3.title.set_text("(c) Diff.")
+                    ax4.title.set_text("(d) MSE")
+
+                    ax1.xaxis.set_tick_params(labelsize=7)
+                    ax1.yaxis.set_tick_params(labelsize = 7)
+                    ax2.xaxis.set_tick_params(labelsize=7)
+                    ax2.yaxis.set_tick_params(labelsize = 7)
+                    ax3.xaxis.set_tick_params(labelsize=7)
+                    ax3.yaxis.set_tick_params(labelsize = 7)
+
+                    init_images = np.zeros((input_images_.shape[1], input_images_.shape[2]))
+                    print("inti images shape", init_images.shape)
+                    xdata, ydata = [], []
+                    plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+                    plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+                    #x = np.linspace(0, 64, 64)
+                    #y = np.linspace(0, 64, 64)
+                    #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+                    #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+                    fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7)
+                    fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7)
+
+                    cm = LinearSegmentedColormap.from_list(
+                        cmap_name, ["b", "w", "r"], N=5)  # blue-white-red diverging colormap
+
+                    plot3 = ax3.imshow(init_images, vmin=-10, vmax=10, cmap=cm)#cmap = 'PuBu_r',
+
+                    plot4, = ax4.plot([], [], color = "r")
+                    ax4.set_xlim(0, len(gen_mse_avg_)-1)
+                    ax4.set_ylim(0, 10)
+                    ax4.set_xlabel("Frames", fontsize=10)
+                    #ax4.set_ylabel("MSE", fontsize=10)
+                    ax4.xaxis.set_tick_params(labelsize=7)
+                    ax4.yaxis.set_tick_params(labelsize=7)
+
+
+                    plots = [plot1, plot2, plot3, plot4]
+
+                    #fig.colorbar(plots[1], ax = [ax1, ax2])
+
+                    fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7)
+                    #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7)
+                    #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7)
+
+                    def animation_sample(t):
+                        input_image = input_images_[t, :, :, 0]
+                        gen_image = gen_images_[t, :, :, 0]
+                        diff_image = input_gen_diff[t,:,:,0]
+
+                        data = gen_mse_avg_[:t + 1]
+                        # x = list(range(len(gen_mse_avg_)))[:t+1]
+                        xdata.append(t)
+                        print("xdata", xdata)
+                        ydata.append(gen_mse_avg_[t])
+
+                        print("ydata", ydata)
+                        # p = sns.lineplot(x=x,y=data,color="b")
+                        # p.tick_params(labelsize=17)
+                        # plt.setp(p.lines, linewidth=6)
+                        plots[0].set_data(input_image)
+                        plots[1].set_data(gen_image)
+                        #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+                        #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+                        plots[2].set_data(diff_image)
+                        plots[3].set_data(xdata, ydata)
+                        fig.suptitle("Frame " + str(t+1))
+
+                        return plots
+
+                    ani = animation.FuncAnimation(fig, animation_sample, frames = len(gen_mse_avg_), interval = 1000,
+                                                  repeat_delay = 2000)
+                    ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
+
+            else:
+                pass
+
+    #         # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     plt.xlim(0,len(gen_mse_avg_))
+    #         #     plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg))
+    #         #     plt.xlabel("Frames")
+    #         #     plt.ylabel("MSE_AVG")
+    #         #     #X = list(range(len(gen_mse_avg_)))
+    #         #     #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
+    #         #     def animate_metric(j):
+    #         #         data = gen_mse_avg_[:(j+1)]
+    #         #         x = list(range(len(gen_mse_avg_)))[:(j+1)]
+    #         #         p = sns.lineplot(x=x,y=data,color="b")
+    #         #         p.tick_params(labelsize=17)
+    #         #         plt.setp(p.lines, linewidth=6)
+    #         #     ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000)
+    #         #     ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif"))
+    #         #
+    #         #
+    #         # for i, input_images_ in enumerate(input_images):
+    #         #     #context_images_ = (input_results['images'][i])
+    #         #     #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     for t, input_image in enumerate(input_images_):
+    #         #         im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')
+    #         #         ttl = plt.text(1.5, 2,"Frame_" + str(t))
+    #         #         ims.append([im,ttl])
+    #         #     ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000)
+    #         #     ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif"))
+    #         #     #plt.show()
+    #         #
+    #         # for i,gen_images_ in enumerate(gen_images):
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     for t, gen_image in enumerate(gen_images_):
+    #         #         im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
+    #         #         ttl = plt.text(1.5, 2, "Frame_" + str(t))
+    #         #         ims.append([im, ttl])
+    #         #     ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    #         #     ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif"))
+    #
+    #
+    #             # for i, gen_images_ in enumerate(gen_images):
+    #             #     #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
+    #             #     #gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
+    #             #     #bing
+    #             #     context_images_ = (input_results['images'][i])
+    #             #     gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+    #             #     context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
+    #             #     plt.figure(figsize = (10,2))
+    #             #     gs = gridspec.GridSpec(2,10)
+    #             #     gs.update(wspace=0.,hspace=0.)
+    #             #     for t, gen_image in enumerate(gen_images_):
+    #             #         gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1)))
+    #             #         gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
+    #             #         plt.subplot(gs[t])
+    #             #         plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')  # the last index sets the channel. 0 = t2
+    #             #         # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
+    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
+    #             #                         right = False, labelbottom = False, labelleft = False)
+    #             #         if t == 0: plt.ylabel('Actual', fontsize = 10)
+    #             #
+    #             #         plt.subplot(gs[t + 10])
+    #             #         plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
+    #             #         # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
+    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
+    #             #                         right = False, labelbottom = False, labelleft = False)
+    #             #         if t == 0: plt.ylabel('Predicted', fontsize = 10)
+    #             #     plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png')
+    #             #     plt.clf()
+    #
+    #             # if args.gif_length:
+    #             #     context_and_gen_images = context_and_gen_images[:args.gif_length]
+    #             # save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
+    #             #          context_and_gen_images, fps=args.fps)
+    #             #
+    #             # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
+    #             # for t, gen_image in enumerate(gen_images_):
+    #             #     gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
+    #             #     if gen_image.shape[-1] == 1:
+    #             #       gen_image = np.tile(gen_image, (1, 1, 3))
+    #             #     else:
+    #             #       gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
+    #             #     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
+
+        sample_ind += args.batch_size
+
+
+    with open(os.path.join(args.output_png_dir, "input_images_all"),"wb") as input_files:
+        pickle.dump(input_images_all,input_files)
+
+    with open(os.path.join(args.output_png_dir, "gen_images_all"),"wb") as gen_files:
+        pickle.dump(gen_images_all,gen_files)
+
+    with open(os.path.join(args.output_png_dir, "input_images_all"),"rb") as input_files:
+        input_images_all = pickle.load(input_files)
+
+    with open(os.path.join(args.output_png_dir, "gen_images_all"),"rb") as gen_files:
+        gen_images_all=pickle.load(gen_files)
+    ims = []
+    fig = plt.figure()
+    for frame in range(19):
+        input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # per-pixel squared error averaged over the batch for this frame
+        #pix_mean = np.mean(input_gen_diff, axis = 0)
+        #pix_std = np.std(input_gen_diff, axis=0)
+        im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu')
+        if frame == 0:
+            fig.colorbar(im)
+        ttl = plt.text(1.5, 2, "Frame_" + str(frame +1))
+        ims.append([im, ttl])
+    ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4"))
+    plt.close("all")
+
+    ims = []
+    fig = plt.figure()
+    for frame in range(19):
+        pix_std = np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # per-pixel standard deviation of the squared error over the batch for this frame
+        #pix_mean = np.mean(input_gen_diff, axis = 0)
+        #pix_std = np.std(input_gen_diff, axis=0)
+        im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu')
+        if frame == 0:
+            fig.colorbar(im)
+        ttl = plt.text(1.5, 2, "Frame_" + str(frame+1))
+        ims.append([im, ttl])
+    ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4"))
+
+    gen_images_all = np.array(gen_images_all)
+    input_images_all = np.array(input_images_all)
+    # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
+    # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)
+    # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 )
+
+    mse_model = np.mean((input_images_all[:, :10,:,:,0] - gen_images_all[:, :10,:,:,0])**2)  # MSE over the first 10 predicted frames
+    mse_model_last = np.mean((input_images_all[:,10,:,:,0] - gen_images_all[:, 10,:,:,0])**2)
+    mse_prev = np.mean((input_images_all[:, :9, :, :, 0] - input_images_all[:, 1:10, :, :, 0])**2 )
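+    # mse_prev is a persistence baseline: each frame is "predicted" by the frame
+    # observed one step earlier.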
+
+    def psnr(img1, img2):
+        mse = np.mean((img1 - img2) ** 2)
+        if mse == 0: return 100
+        PIXEL_MAX = 1
+        return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
+
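+    # Worked example: psnr(np.zeros((2, 2)), np.full((2, 2), 0.1)) gives an MSE of
+    # 0.01 and hence 20*log10(1/0.1) = 20 dB; identical inputs return the cap of 100.
+    # Note that PIXEL_MAX=1 assumes values in [0, 1]; the fields here are in Kelvin,
+    # so the PSNR values below are only meaningful relative to each other.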
+    psnr_model = psnr(input_images_all[:, :10, :, :, 0],  gen_images_all[:, :10, :, :, 0])
+    psnr_model_last = psnr(input_images_all[:, 10, :, :, 0],  gen_images_all[:,10, :, :, 0])
+    psnr_prev = psnr(input_images_all[:, :9, :, :, 0],  input_images_all[:, 1:10, :, :, 0])
+    f = open(os.path.join(args.output_png_dir,'prediction_scores_4prediction.txt'), 'w')
+    f.write("Model MSE: %f\n" % mse_model)
+    f.write("Model MSE from only last prediction in sequence: %f\n" % mse_model_last)
+    f.write("Previous Frame MSE: %f\n" % mse_prev)
+    f.write("Model PSNR: %f\n" % psnr_model)
+    f.write("Model PSNR from only last prediction in sequence: %f\n" % psnr_model_last)
+    f.write("Previous frame PSNR: %f\n" % psnr_prev)
+    f.write("Shape of X_test: " + str(input_images_all.shape))
+    f.write("")
+    f.write("Shape of X_hat: " + str(gen_images_all.shape))
+    f.close()
+
+    seed(1)
+    s = random.sample(range(len(gen_images_all)), 100)
+    print("******KDP******")
+    # kernel density plot to check for mode collapse in the generated samples
+    fig = plt.figure()
+    kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generated Images")
+    kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground Truth")
+    kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability')
+    plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400)
+    plt.clf()
+
+    # scatter plots comparing predictions with ground-truth values for selected frames
+    for i in [0,3,6,9,12,15,18]:
+        fig = plt.figure()
+        plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3)
+        #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3)
+        plt.xlabel("Prediction")
+        plt.ylabel("Real values")
+        plt.title("Frame_{}".format(i+1))
+        plt.plot([250,300], [250,300],color="black")
+        plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i))))
+        plt.clf()
+
+
+    mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence)
+    x = [str(i+1) for i in list(range(19))]
+    fig,axis = plt.subplots()
+    mean_f = np.mean(mse_model_by_frames, axis = 0)
+    median = np.median(mse_model_by_frames, axis=0)
+    q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0)
+    q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0)
+    d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0)
+    d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0)
+    plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range")
+    plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range")
+    plt.plot(x, median, color="grey", linewidth=0.6, label="Median")
+    plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean")
+    plt.title("MSE percentiles")
+    plt.xlabel("Frames")
+    plt.legend(loc=2, fontsize=8)
+    plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png"))
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/generate_orig.py b/scripts/generate_orig.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a9e53037d247002743593a9bed36d7f1b9b4249
--- /dev/null
+++ b/scripts/generate_orig.py
@@ -0,0 +1,193 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import random
+
+import cv2
+import numpy as np
+import tensorflow as tf
+
+from video_prediction import datasets, models
+from video_prediction.utils.ffmpeg_gif import save_gif
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified")
+    parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified")
+    parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified")
+    parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is "
+                                                 "results_gif_dir/model_fname")
+    parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is "
+                                                 "results_png_dir/model_fname")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+
+    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')
+
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
+
+    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
+    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
+    parser.add_argument("--num_epochs", type=int, default=1)
+
+    parser.add_argument("--num_stochastic_samples", type=int, default=5)
+    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
+    parser.add_argument("--fps", type=int, default=4)
+
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", type=int, default=7)
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        tf.set_random_seed(args.seed)
+        np.random.seed(args.seed)
+        random.seed(args.seed)
+
+    args.results_gif_dir = args.results_gif_dir or args.results_dir
+    args.results_png_dir = args.results_png_dir or args.results_dir
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+    if args.checkpoint:
+        checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.isdir(args.checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            args.dataset = args.dataset or options['dataset']
+            args.model = args.model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1])
+        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1])
+    else:
+        if not args.dataset:
+            raise ValueError('dataset is required when checkpoint is not specified')
+        if not args.model:
+            raise ValueError('model is required when checkpoint is not specified')
+        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
+        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    VideoDataset = datasets.get_dataset_class(args.dataset)
+    dataset = VideoDataset(
+        args.input_dir,
+        mode=args.mode,
+        num_epochs=args.num_epochs,
+        seed=args.seed,
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
+
+    VideoPredictionModel = models.get_model_class(args.model)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        mode=args.mode,
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams)
+
+    sequence_length = model.hparams.sequence_length
+    context_frames = model.hparams.context_frames
+    future_length = sequence_length - context_frames
+
+    if args.num_samples:
+        if args.num_samples > dataset.num_examples_per_epoch():
+            raise ValueError('num_samples cannot be larger than the dataset')
+        num_examples_per_epoch = args.num_samples
+    else:
+        num_examples_per_epoch = dataset.num_examples_per_epoch()
+    if num_examples_per_epoch % args.batch_size != 0:
+        raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+
+    inputs = dataset.make_batch(args.batch_size)
+    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
+    with tf.variable_scope(''):
+        model.build_graph(input_phs)
+
+    for output_dir in (args.output_gif_dir, args.output_png_dir):
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        with open(os.path.join(output_dir, "options.json"), "w") as f:
+            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
+        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
+    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+    sess = tf.Session(config=config)
+    sess.graph.as_default()
+    model.restore(sess, args.checkpoint)
+
+    sample_ind = 0
+    while True:
+        if args.num_samples and sample_ind >= args.num_samples:
+            break
+        try:
+            input_results = sess.run(inputs)
+        except tf.errors.OutOfRangeError:
+            break
+        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
+
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        for stochastic_sample_ind in range(args.num_stochastic_samples):
+            gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
+            # only keep the future frames
+            gen_images = gen_images[:, -future_length:]
+            for i, gen_images_ in enumerate(gen_images):
+                #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
+                #gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
+                context_images_ = input_results['images'][i]
+
+                gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+                context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
+                if args.gif_length:
+                    context_and_gen_images = context_and_gen_images[:args.gif_length]
+                save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
+                         context_and_gen_images, fps=args.fps)
+                gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
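+                # e.g. with up to 10 predicted frames, sample_ind + i = 3,
+                # stochastic_sample_ind = 1 and t = 7 this yields
+                # 'gen_image_00003_01_07.png' (frame index zero-padded to at least two digits)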
+                for t, gen_image in enumerate(gen_images_):
+                    gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
+                    if gen_image.shape[-1] == 1:
+                        gen_image = np.tile(gen_image, (1, 1, 3))
+                    else:
+                        gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
+                    cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
+
+        sample_ind += args.batch_size
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..331559f6287a4f24c1c19ee9f7f4b03309a22abf
--- /dev/null
+++ b/scripts/generate_transfer_learning_finetune.py
@@ -0,0 +1,731 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import math
+import random
+import cv2
+import numpy as np
+import tensorflow as tf
+import pickle
+import hickle
+from random import seed
+#from six.moves import cPickle
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import matplotlib.animation as animation
+import pandas as pd
+import re
+from video_prediction import datasets, models
+from matplotlib.colors import LinearSegmentedColormap
+#from matplotlib.ticker import MaxNLocator
+#from video_prediction.utils.ffmpeg_gif import save_gif
+from skimage.metrics import structural_similarity as ssim
+import datetime
+# Scarlet 2020/05/28: access to statistical values in json file 
+from os import path
+import sys
+sys.path.append(path.abspath('../video_prediction/datasets/'))
+from era5_dataset_v2 import Norm_data
+from os.path import dirname
+
+with open("../geo_info.json","r") as json_file:
+    geo = json.load(json_file)
+    lat = [round(i,2) for i in geo["lat"]]
+    lon = [round(i,2) for i in geo["lon"]]
+
+
+def psnr(img1, img2):
+    mse = np.mean((img1 - img2) ** 2)
+    if mse == 0: return 100
+    PIXEL_MAX = 1
+    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type = str, required = True,
+                        help = "either a directory containing subdirectories "
+                               "train, val, test, etc, or a directory containing "
+                               "the tfrecords")
+    parser.add_argument("--results_dir", type = str, default = 'results',
+                        help = "ignored if output_gif_dir is specified")
+    parser.add_argument("--results_gif_dir", type = str,
+                        help = "default is results_dir. ignored if output_gif_dir is specified")
+    parser.add_argument("--results_png_dir", type = str,
+                        help = "default is results_dir. ignored if output_png_dir is specified")
+    parser.add_argument("--output_gif_dir", help = "output directory where samples are saved as gifs. default is "
+                                                   "results_gif_dir/model_fname")
+    parser.add_argument("--output_png_dir", help = "output directory where samples are saved as pngs. default is "
+                                                   "results_png_dir/model_fname")
+    parser.add_argument("--checkpoint",
+                        help = "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+
+    parser.add_argument("--mode", type = str, choices = ['train','val', 'test'], default = 'val',
+                        help = 'mode for dataset, val or test.')
+
+    parser.add_argument("--dataset", type = str, help = "dataset class name")
+    parser.add_argument("--dataset_hparams", type = str,
+                        help = "a string of comma separated list of dataset hyperparameters")
+    parser.add_argument("--model", type = str, help = "model class name")
+    parser.add_argument("--model_hparams", type = str,
+                        help = "a string of comma separated list of model hyperparameters")
+
+    parser.add_argument("--batch_size", type = int, default = 8, help = "number of samples in batch")
+    parser.add_argument("--num_samples", type = int, help = "number of samples in total (all of them by default)")
+    parser.add_argument("--num_epochs", type = int, default = 1)
+
+    parser.add_argument("--num_stochastic_samples", type = int, default = 1)
+    parser.add_argument("--gif_length", type = int, help = "default is sequence_length")
+    parser.add_argument("--fps", type = int, default = 4)
+
+    parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use")
+    parser.add_argument("--seed", type = int, default = 7)
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        tf.set_random_seed(args.seed)
+        np.random.seed(args.seed)
+        random.seed(args.seed)
+    #Bing:20200518 
+    input_dir = args.input_dir
+    temporal_dir = os.path.split(input_dir)[0] + "/hickle/splits/"
+    print ("temporal_dir:",temporal_dir)
+    args.results_gif_dir = args.results_gif_dir or args.results_dir
+    args.results_png_dir = args.results_png_dir or args.results_dir
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+    if args.checkpoint:
+        checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.isdir(args.checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            args.dataset = args.dataset or options['dataset']
+            args.model = args.model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir,
+                                                                  os.path.split(checkpoint_dir)[1])
+        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir,
+                                                                  os.path.split(checkpoint_dir)[1])
+    else:
+        if not args.dataset:
+            raise ValueError('dataset is required when checkpoint is not specified')
+        if not args.model:
+            raise ValueError('model is required when checkpoint is not specified')
+        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
+        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    VideoDataset = datasets.get_dataset_class(args.dataset)
+    dataset = VideoDataset(
+        args.input_dir,
+        mode = args.mode,
+        num_epochs = args.num_epochs,
+        seed = args.seed,
+        hparams_dict = dataset_hparams_dict,
+        hparams = args.dataset_hparams)
+
+    VideoPredictionModel = models.get_model_class(args.model)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        mode = args.mode,
+        hparams_dict = hparams_dict,
+        hparams = args.model_hparams)
+
+    sequence_length = model.hparams.sequence_length
+    context_frames = model.hparams.context_frames
+    future_length = sequence_length - context_frames
+
+    if args.num_samples:
+        if args.num_samples > dataset.num_examples_per_epoch():
+            raise ValueError('num_samples cannot be larger than the dataset')
+        num_examples_per_epoch = args.num_samples
+    else:
+        num_examples_per_epoch = dataset.num_examples_per_epoch()
+    if num_examples_per_epoch % args.batch_size != 0:
+        raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+
+    inputs = dataset.make_batch(args.batch_size)
+    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
+    with tf.variable_scope(''):
+        model.build_graph(input_phs)
+
+    for output_dir in (args.output_gif_dir, args.output_png_dir):
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        with open(os.path.join(output_dir, "options.json"), "w") as f:
+            f.write(json.dumps(vars(args), sort_keys = True, indent = 4))
+        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+            f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4))
+        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+            f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4))
+
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_mem_frac)
+    config = tf.ConfigProto(gpu_options = gpu_options, allow_soft_placement = True)
+    sess = tf.Session(config = config)
+    sess.graph.as_default()
+    sess.run(tf.global_variables_initializer())
+    sess.run(tf.local_variables_initializer())
+    model.restore(sess, args.checkpoint)
+
+    sample_ind = 0
+    gen_images_all = []
+    #Bing:20200410
+    persistent_images_all = []
+    input_images_all = []
+    #Bing:20201417
+    print ("temporal_dir:",temporal_dir)
+    test_temporal_pkl = pickle.load(open(os.path.join(temporal_dir,"T_test.pkl"),"rb"))
+    #val_temporal_pkl = pickle.load(open("/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/hickle/splits/T_val.pkl","rb"))
+    print("test temporal_pkl file looks like folowing", test_temporal_pkl)
+
+    #X_val = hickle.load("/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/hickle/splits/X_val.hkl")
+    X_test = hickle.load(os.path.join(temporal_dir,"X_test.hkl"))
+    is_first=True
+
+    #+++Scarlet:20200528    
+    norm_cls  = Norm_data('T2')
+    norm = 'minmax'
+    with open(os.path.join(dirname(input_dir),"hickle/splits/statistics.json")) as js_file:
+         norm_cls.check_and_set_norm(json.load(js_file),norm)
+    #---Scarlet:20200528    
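+    # Denormalisation sketch (assumption): with the 'minmax' norm, denorm_var is
+    # expected to map normalised values back via x * (max - min) + min, cf. the
+    # hard-coded T2 range 235.21...321.47 K in the commented-out line further below.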
+    while True:
+        print("Sample id", sample_ind)
+        if sample_ind <= 24:
+            pass
+        elif sample_ind >= len(X_test):
+            break
+        else:
+            gen_images_stochastic = []
+            if args.num_samples and sample_ind >= args.num_samples:
+                break
+            try:
+                input_results = sess.run(inputs)
+                input_images = input_results["images"]
+
+
+            except tf.errors.OutOfRangeError:
+                break
+
+            feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+            for stochastic_sample_ind in range(args.num_stochastic_samples):
+                input_images_all.extend(input_images)
+                with open(os.path.join(args.output_png_dir, "input_images_all.pkl"), "wb") as input_files:
+                    pickle.dump(list(input_images_all), input_files)
+                
+                gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)
+                gen_images_stochastic.append(gen_images)
+                #print("Stochastic_sample,", stochastic_sample_ind)
+                for i in range(args.batch_size):
+                    #bing:20200417
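+                    # Persistence baseline: look up the sample observed exactly one day
+                    # earlier and reuse its frames as the "forecast" for this sample.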
+                    t_stampe = test_temporal_pkl[sample_ind+i]
+                    print("timestamp:",type(t_stampe))
+                    persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1)
+                    print ("persistent ts",persistent_ts)
+                    persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts))
+                    persistent_X = X_test[persistent_idx:persistent_idx+context_frames + future_length]
+                    print("persistent index in test set:", persistent_idx)
+                    print("persistent_X.shape",persistent_X.shape)
+                    persistent_images_all.append(persistent_X)
+
+                        
+                #print("batch", i)
+                #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
+
+
+                    cmap_name = 'my_list'
+                    if sample_ind < 100:
+                        #name = '_Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
+                        #    sample_ind) + " + Sample_" + str(i)
+                        name = '_Stochastic_id_' + str(stochastic_sample_ind) + "_Time_"+ t_stampe[0].strftime("%Y%m%d-%H%M%S")
+                        print ("name",name)
+                        gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
+                        #gen_images_ =  gen_images[i, :]
+                        input_images_ = input_images[i, :]
+                        #Bing:20200417
+                        #persistent_images = ?
+                        #+++Scarlet:20200528   
+                        #print('Scarlet1')
+                        input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm)
+                        persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm)
+                        #---Scarlet:20200528    
+                        gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
+                                        range(sequence_length)]  # per-frame MSE of the prediction over the full sequence
+                        persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in
+                                        range(sequence_length)]  # per-frame MSE of the persistence forecast over the full sequence
+
+                        fig = plt.figure(figsize=(18,6))
+                        gs = gridspec.GridSpec(1, 10)
+                        gs.update(wspace = 0., hspace = 0.)
+                        ts = list(range(10,20)) #[10,11,12,..]
+                        xlables = [round(i,2) for i  in list(np.linspace(np.min(lon),np.max(lon),5))]
+                        ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
+
+                        for t in ts:
+
+                            #if t==0 : ax1=plt.subplot(gs[t])
+                            ax1 = plt.subplot(gs[ts.index(t)])
+                            #+++Scarlet:20200528
+                            #print('Scarlet2')
+                            input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm)
+                            #---Scarlet:20200528
+                            plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
+                            ax1.title.set_text("t = " + str(t+1-10))
+                            plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+                            if t == ts[0]:
+                                plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
+                                plt.ylabel("Ground Truth", fontsize=10)
+                        plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
+                        plt.clf()
+
+                        fig = plt.figure(figsize=(12,6))
+                        gs = gridspec.GridSpec(1, 10)
+                        gs.update(wspace = 0., hspace = 0.)
+
+                        for t in ts:
+                            #if t==0 : ax1=plt.subplot(gs[t])
+                            ax1 = plt.subplot(gs[ts.index(t)])
+                            #+++Scarlet:20200528
+                            #print('Scarlet3')
+                            gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm)
+                            #---Scarlet:20200528
+                            plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
+                            ax1.title.set_text("t = " + str(t+1-10))
+                            plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+
+                        plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
+                        plt.clf()
+
+
+                        fig = plt.figure(figsize=(12,6))
+                        gs = gridspec.GridSpec(1, 10)
+                        gs.update(wspace = 0., hspace = 0.)
+                        for t in ts:
+                            #if t==0 : ax1=plt.subplot(gs[t])
+                            ax1 = plt.subplot(gs[ts.index(t)])
+                            #persistent_image = persistent_X[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+                            plt.imshow(persistent_X[t, :, :, 0], cmap = 'jet', vmin=270, vmax=300)
+                            ax1.title.set_text("t = " + str(t+1-10))
+                            plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+
+                        plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg"))
+                        plt.clf()
+
+                        
+                with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files:
+                    pickle.dump(list(persistent_images_all), input_files)
+                    print ("Save persistent all")
+                if is_first:
+                    gen_images_all = gen_images_stochastic
+                    is_first = False
+                else:
+                    gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1)
+
+                if args.num_stochastic_samples == 1:
+                    with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files:
+                        pickle.dump(list(gen_images_all[0]), gen_files)
+                        print("Saved all generated images")
+                else:
+                    with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files:
+                        pickle.dump(list(gen_images_stochastic), gen_files)
+                    with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files:
+                        pickle.dump(list(gen_images_all), gen_files)
+##                
+##
+##                    # fig = plt.figure()
+##                    # gs = gridspec.GridSpec(4,6)
+##                    # gs.update(wspace = 0.7,hspace=0.8)
+##                    # ax1 = plt.subplot(gs[0:2,0:3])
+##                    # ax2 = plt.subplot(gs[0:2,3:],sharey=ax1)
+##                    # ax3 = plt.subplot(gs[2:4,0:3])
+##                    # ax4 = plt.subplot(gs[2:4,3:])
+##                    # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
+##                    # ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
+##                    # plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels)
+##                    # ax1.title.set_text("(a) Ground Truth")
+##                    # ax2.title.set_text("(b) SAVP")
+##                    # ax3.title.set_text("(c) Diff.")
+##                    # ax4.title.set_text("(d) MSE")
+##                    #
+##                    # ax1.xaxis.set_tick_params(labelsize=7)
+##                    # ax1.yaxis.set_tick_params(labelsize = 7)
+##                    # ax2.xaxis.set_tick_params(labelsize=7)
+##                    # ax2.yaxis.set_tick_params(labelsize = 7)
+##                    # ax3.xaxis.set_tick_params(labelsize=7)
+##                    # ax3.yaxis.set_tick_params(labelsize = 7)
+##                    #
+##                    # init_images = np.zeros((input_images_.shape[1], input_images_.shape[2]))
+##                    # print("inti images shape", init_images.shape)
+##                    # xdata, ydata = [], []
+##                    # #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
+##                    # #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
+##                    # plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+##                    # plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+##                    # #x = np.linspace(0, 64, 64)
+##                    # #y = np.linspace(0, 64, 64)
+##                    # #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+##                    # #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+##                    # fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7)
+##                    # fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7)
+##                    #
+##                    # cm = LinearSegmentedColormap.from_list(
+##                    #     cmap_name, "bwr", N = 5)
+##                    #
+##                    # plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r',
+##                    # #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm)  # cmap = 'PuBu_r',
+##                    # plot4, = ax4.plot([], [], color = "r")
+##                    # ax4.set_xlim(0, future_length-1)
+##                    # ax4.set_ylim(0, 20)
+##                    # #ax4.set_ylim(0, 0.5)
+##                    # ax4.set_xlabel("Frames", fontsize=10)
+##                    # #ax4.set_ylabel("MSE", fontsize=10)
+##                    # ax4.xaxis.set_tick_params(labelsize=7)
+##                    # ax4.yaxis.set_tick_params(labelsize=7)
+##                    #
+##                    #
+##                    # plots = [plot1, plot2, plot3, plot4]
+##                    #
+##                    # #fig.colorbar(plots[1], ax = [ax1, ax2])
+##                    #
+##                    # fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7)
+##                    # #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7)
+##                    # #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7)
+##                    #
+##                    # def animation_sample(t):
+##                    #     input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
+##                    #     gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
+##                    #     diff_image = input_gen_diff[t,:,:]
+##                    #     # p = sns.lineplot(x=x,y=data,color="b")
+##                    #     # p.tick_params(labelsize=17)
+##                    #     # plt.setp(p.lines, linewidth=6)
+##                    #     plots[0].set_data(input_image)
+##                    #     plots[1].set_data(gen_image)
+##                    #     #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+##                    #     #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+##                    #     plots[2].set_data(diff_image)
+##                    #
+##                    #     if t >= future_length:
+##                    #         #data = gen_mse_avg_[:t + 1]
+##                    #         # x = list(range(len(gen_mse_avg_)))[:t+1]
+##                    #         xdata.append(t-future_length)
+##                    #         print("xdata", xdata)
+##                    #         ydata.append(gen_mse_avg_[t])
+##                    #         print("ydata", ydata)
+##                    #         plots[3].set_data(xdata, ydata)
+##                    #         fig.suptitle("Predicted Frame " + str(t-future_length))
+##                    #     else:
+##                    #         #plots[3].set_data(xdata, ydata)
+##                    #         fig.suptitle("Context Frame " + str(t))
+##                    #     return plots
+##                    #
+##                    # ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000,
+##                    #                               repeat_delay=2000)
+##                    # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
+##
+####                else:
+####                    pass
+##
+        sample_ind += args.batch_size
+
+
+    #         # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     plt.xlim(0,len(gen_mse_avg_))
+    #         #     plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg))
+    #         #     plt.xlabel("Frames")
+    #         #     plt.ylabel("MSE_AVG")
+    #         #     #X = list(range(len(gen_mse_avg_)))
+    #         #     #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
+    #         #     def animate_metric(j):
+    #         #         data = gen_mse_avg_[:(j+1)]
+    #         #         x = list(range(len(gen_mse_avg_)))[:(j+1)]
+    #         #         p = sns.lineplot(x=x,y=data,color="b")
+    #         #         p.tick_params(labelsize=17)
+    #         #         plt.setp(p.lines, linewidth=6)
+    #         #     ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000)
+    #         #     ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif"))
+    #         #
+    #         #
+    #         # for i, input_images_ in enumerate(input_images):
+    #         #     #context_images_ = (input_results['images'][i])
+    #         #     #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     for t, input_image in enumerate(input_images_):
+    #         #         im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')
+    #         #         ttl = plt.text(1.5, 2,"Frame_" + str(t))
+    #         #         ims.append([im,ttl])
+    #         #     ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000)
+    #         #     ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif"))
+    #         #     #plt.show()
+    #         #
+    #         # for i,gen_images_ in enumerate(gen_images):
+    #         #     ims = []
+    #         #     fig = plt.figure()
+    #         #     for t, gen_image in enumerate(gen_images_):
+    #         #         im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
+    #         #         ttl = plt.text(1.5, 2, "Frame_" + str(t))
+    #         #         ims.append([im, ttl])
+    #         #     ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    #         #     ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif"))
+    #
+    #
+    #             # for i, gen_images_ in enumerate(gen_images):
+    #             #     #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
+    #             #     #gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
+    #             #     #bing
+    #             #     context_images_ = (input_results['images'][i])
+    #             #     gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+    #             #     context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
+    #             #     plt.figure(figsize = (10,2))
+    #             #     gs = gridspec.GridSpec(2,10)
+    #             #     gs.update(wspace=0.,hspace=0.)
+    #             #     for t, gen_image in enumerate(gen_images_):
+    #             #         gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1)))
+    #             #         gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
+    #             #         plt.subplot(gs[t])
+    #             #         plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')  # the last index sets the channel. 0 = t2
+    #             #         # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
+    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
+    #             #                         right = False, labelbottom = False, labelleft = False)
+    #             #         if t == 0: plt.ylabel('Actual', fontsize = 10)
+    #             #
+    #             #         plt.subplot(gs[t + 10])
+    #             #         plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
+    #             #         # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
+    #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
+    #             #                         right = False, labelbottom = False, labelleft = False)
+    #             #         if t == 0: plt.ylabel('Predicted', fontsize = 10)
+    #             #     plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png')
+    #             #     plt.clf()
+    #
+    #             # if args.gif_length:
+    #             #     context_and_gen_images = context_and_gen_images[:args.gif_length]
+    #             # save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
+    #             #          context_and_gen_images, fps=args.fps)
+    #             #
+    #             # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
+    #             # for t, gen_image in enumerate(gen_images_):
+    #             #     gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
+    #             #     if gen_image.shape[-1] == 1:
+    #             #       gen_image = np.tile(gen_image, (1, 1, 3))
+    #             #     else:
+    #             #       gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
+    #             #     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
+
+
+
+    with open(os.path.join(args.output_png_dir, "input_images_all.pkl"),"rb") as input_files:
+        input_images_all = pickle.load(input_files)
+
+    with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"),"rb") as gen_files:
+        gen_images_all = pickle.load(gen_files)
+
+    with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files:
+        persistent_images_all = pickle.load(gen_files)
+
+    #+++Scarlet:20200528
+    #print('Scarlet4')
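+    # Denormalise the 2-m temperature (T2) fields back to physical units (Kelvin) before computing the scores.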
+    input_images_all = np.array(input_images_all)
+    input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm)
+    #---Scarlet:20200528
+    persistent_images_all = np.array(persistent_images_all)
+    if len(np.array(gen_images_all).shape) == 6:
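+        # gen_images_all is 6-D when several stochastic samples were kept (sample, batch, time, height, width, channel),
+        # so the scores below are computed separately for each stochastic sample.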
+        for i in range(len(gen_images_all)):
+            #+++Scarlet:20200528
+            #print('Scarlet5')
+            gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:]
+            gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm)
+            #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+            #---Scarlet:20200528
+            mse_all = []
+            psnr_all = []
+            ssim_all = []
+            f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w')
+            for t in range(future_length):
+                # compare ground-truth frame t+10 with the corresponding predicted frame t+9 (the arrays are offset by one frame)
+                mse_model = np.mean((input_images_all[:, t + 10, :, :, 0] - gen_images_all_stochastic[:, t + 9, :, :, 0]) ** 2)
+                psnr_model = psnr(input_images_all[:, t + 10, :, :, 0], gen_images_all_stochastic[:, t + 9, :, :, 0])
+                ssim_model = ssim(input_images_all[:, t + 10, :, :, 0], gen_images_all_stochastic[:, t + 9, :, :, 0],
+                                  data_range=max(gen_images_all_stochastic[:, t + 9, :, :, 0].flatten()) - min(
+                                      input_images_all[:, t + 10, :, :, 0].flatten()))
+                mse_all.extend([mse_model])
+                psnr_all.extend([psnr_model])
+                ssim_all.extend([ssim_model])
+                results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
+                f.write("##########Predicted Frame {}\n".format(str(t + 1)))
+                f.write("Model MSE: %f\n" % mse_model)
+                # f.write("Previous Frame MSE: %f\n" % mse_prev)
+                f.write("Model PSNR: %f\n" % psnr_model)
+                f.write("Model SSIM: %f\n" % ssim_model)
+
+            # the inner loop uses t so that i still refers to the stochastic sample index below
+            pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb"))
+            # f.write("Previous frame PSNR: %f\n" % psnr_prev)
+            f.write("Shape of X_test: " + str(input_images_all.shape))
+            f.write("\n")
+            f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape))
+            f.close()
+
+    else:
+        #+++Scarlet:20200528
+        #print('Scarlet6')
+        gen_images_all = np.array(gen_images_all)
+        gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm)
+        #---Scarlet:20200528
+        
+        # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
+        # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)
+        # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 )
+        mse_all = []
+        psnr_all = []
+        ssim_all = []
+        persistent_mse_all = []
+        persistent_psnr_all = []
+        persistent_ssim_all = []
+        f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w')
+        for i in range(future_length):
+            # compare ground-truth frame i+10 with the corresponding predicted/persistence frame i+9 (the arrays are offset by one frame)
+            mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :, 0]) ** 2)
+            persistent_mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - persistent_images_all[:, i + 9, :, :, 0]) ** 2)
+
+            psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0])
+            ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0],
+                              data_range=max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min(
+                                  input_images_all[:, i + 10, :, :, 0].flatten()))
+            persistent_psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0])
+            persistent_ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0],
+                                         data_range=max(gen_images_all[:1268, i + 9, :, :, 0].flatten()) - min(
+                                             input_images_all[:1268, i + 10, :, :, 0].flatten()))
+            mse_all.extend([mse_model])
+            psnr_all.extend([psnr_model])
+            ssim_all.extend([ssim_model])
+            persistent_mse_all.extend([persistent_mse_model])
+            persistent_psnr_all.extend([persistent_psnr_model])
+            persistent_ssim_all.extend([persistent_ssim_model])
+            results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
+
+            persistent_results = {"mse": persistent_mse_all, "psnr": persistent_psnr_all, "ssim": persistent_ssim_all}
+            f.write("##########Predicted Frame {}\n".format(str(i + 1)))
+            f.write("Model MSE: %f\n" % mse_model)
+            # f.write("Previous Frame MSE: %f\n" % mse_prev)
+            f.write("Model PSNR: %f\n" % psnr_model)
+            f.write("Model SSIM: %f\n" % ssim_model)
+
+        pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb"))
+        pickle.dump(persistent_results, open(os.path.join(args.output_png_dir, "persistent_results.pkl"), "wb"))
+        # f.write("Previous frame PSNR: %f\n" % psnr_prev)
+        f.write("Shape of X_test: " + str(input_images_all.shape))
+        f.write("\n")
+        f.write("Shape of X_hat: " + str(gen_images_all.shape))
+        f.close()
+
+
+
+    #psnr_model = psnr(input_images_all[:, :10, :, :, 0],  gen_images_all[:, :10, :, :, 0])
+    #psnr_model_last = psnr(input_images_all[:, 10, :, :, 0],  gen_images_all[:,10, :, :, 0])
+    #psnr_prev = psnr(input_images_all[:, :, :, :, 0],  input_images_all[:, 1:10, :, :, 0])
+
+    # ims = []
+    # fig = plt.figure()
+    # for frame in range(20):
+    #     input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel)
+    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
+    #     #pix_std = np.std(input_gen_diff, axis=0)
+    #     im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu')
+    #     if frame == 0:
+    #         fig.colorbar(im)
+    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame +1))
+    #     ims.append([im, ttl])
+    # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000)
+    # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4"))
+    # plt.close("all")
+
+    # ims = []
+    # fig = plt.figure()
+    # for frame in range(19):
+    #     pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # Get the first prediction frame (batch,height, width, channel)
+    #     #pix_mean = np.mean(input_gen_diff, axis = 0)
+    #     #pix_std = np.std(input_gen_diff, axis=0)
+    #     im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu')
+    #     if frame == 0:
+    #         fig.colorbar(im)
+    #     ttl = plt.text(1.5, 2, "Frame_" + str(frame+1))
+    #     ims.append([im, ttl])
+    # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
+    # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4"))
+
+    # seed(1)
+    # s = random.sample(range(len(gen_images_all)), 100)
+    # print("******KDP******")
+    # #kernel density plot for checking the model collapse
+    # fig = plt.figure()
+    # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images")
+    # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True")
+    # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability')
+    # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400)
+    # plt.clf()
+
+    # line plot for evaluating the prediction against the ground truth
+    # for i in [0,3,6,9,12,15,18]:
+    #     fig = plt.figure()
+    #     plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3)
+    #     #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3)
+    #     plt.xlabel("Prediction")
+    #     plt.ylabel("Real values")
+    #     plt.title("Frame_{}".format(i+1))
+    #     plt.plot([250,300], [250,300],color="black")
+    #     plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i))))
+    #     plt.clf()
+    #
+    # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence)
+    # x = [str(i+1) for i in list(range(19))]
+    # fig,axis = plt.subplots()
+    # mean_f = np.mean(mse_model_by_frames, axis = 0)
+    # median = np.median(mse_model_by_frames, axis=0)
+    # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0)
+    # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0)
+    # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0)
+    # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0)
+    # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range")
+    # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range")
+    # plt.plot(x, median, color="grey", linewidth=0.6, label="Median")
+    # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean")
+    # plt.title(f'MSE percentile')
+    # plt.xlabel("Frames")
+    # plt.legend(loc=2, fontsize=8)
+    # plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png"))
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/plot_results.py b/scripts/plot_results.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddf49066db5c4276bfb23167bbcde35295c270d4
--- /dev/null
+++ b/scripts/plot_results.py
@@ -0,0 +1,254 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import glob
+import os
+
+import numpy as np
+
+
+def load_metrics(prefix_fname):
+    import csv
+    with open('%s.csv' % prefix_fname, newline='') as csvfile:
+        reader = csv.reader(csvfile, delimiter='\t', quotechar='|')
+        rows = list(reader)
+        # skip header (first row), indices (first column), and means (last column)
+        metrics = np.array(rows)[1:, 1:-1].astype(np.float32)
+    return metrics
+
+
+def plot_metric(metric, start_x=0, color=None, label=None, zorder=None):
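+    # Plot the per-timestep mean of `metric` with error bars given by the standard error of the mean across samples.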
+    import matplotlib.pyplot as plt
+    metric_mean = np.mean(metric, axis=0)
+    metric_se = np.std(metric, axis=0) / np.sqrt(len(metric))
+    kwargs = {}
+    if color:
+        kwargs['color'] = color
+    if zorder:
+        kwargs['zorder'] = zorder
+    plt.errorbar(np.arange(len(metric_mean)) + start_x,
+                 metric_mean, yerr=metric_se, linewidth=2,
+                 label=label, **kwargs)
+    # metric_std = np.std(metric, axis=0)
+    # plt.plot(np.arange(len(metric_mean)) + start_x, metric_mean,
+    #          linewidth=2, color=color, label=label)
+    # plt.fill_between(np.arange(len(metric_mean)) + start_x,
+    #                  metric_mean - metric_std, metric_mean + metric_std,
+    #                  color=color, alpha=0.5)
+
+
+def get_color(method_name):
+    import matplotlib.pyplot as plt
+    color_mapping = {
+        'ours_vae_gan': plt.cm.Vega20(0),
+        'ours_gan': plt.cm.Vega20(2),
+        'ours_vae': plt.cm.Vega20(4),
+        'ours_vae_l1': plt.cm.Vega20(4),
+        'ours_vae_l2': plt.cm.Vega20(14),
+        'ours_deterministic': plt.cm.Vega20(6),
+        'ours_deterministic_l1': plt.cm.Vega20(6),
+        'ours_deterministic_l2': plt.cm.Vega20(10),
+        'sna_l1': plt.cm.Vega20(8),
+        'sna_l2': plt.cm.Vega20(9),
+        'sv2p_time_variant': plt.cm.Vega20(16),
+        'sv2p_time_invariant': plt.cm.Vega20(16),
+        'svg_lp': plt.cm.Vega20(18),
+        'svg_fp': plt.cm.Vega20(18),
+        'svg_fp_resized_data_loader': plt.cm.Vega20(18),
+        'mathieu': plt.cm.Vega20(8),
+        'mcnet': plt.cm.Vega20(8),
+        'repeat': 'k',
+    }
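+    # Fall back to prefix matching so that method variants (e.g. names with extra suffixes) reuse the base method's color.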
+    if method_name in color_mapping:
+        color = color_mapping[method_name]
+    else:
+        color = None
+        for k, v in color_mapping.items():
+            if method_name.startswith(k):
+                color = v
+                break
+    return color
+
+
+def get_method_name(method_name):
+    method_name_mapping = {
+        'ours_vae_gan': 'Ours, SAVP',
+        'ours_gan': 'Ours, GAN-only',
+        'ours_vae': 'Ours, VAE-only',
+        'ours_vae_l1': r'Ours, VAE-only, $\mathcal{L}_1$',
+        'ours_vae_l2': r'Ours, VAE-only, $\mathcal{L}_2$',
+        'ours_deterministic': 'Ours, deterministic',
+        'ours_deterministic_l1': r'Ours, deterministic, $\mathcal{L}_1$',
+        'ours_deterministic_l2': r'Ours, deterministic, $\mathcal{L}_2$',
+        'sna_l1': r'SNA, $\mathcal{L}_1$ (Ebert et al.)',
+        'sna_l2': r'SNA, $\mathcal{L}_2$ (Ebert et al.)',
+        'sv2p_time_variant': 'SV2P time-variant (Babaeizadeh et al.)',
+        'sv2p_time_invariant': 'SV2P time-invariant (Babaeizadeh et al.)',
+        'mathieu': 'Mathieu et al.',
+        'mcnet': 'MCnet (Villegas et al.)',
+        'repeat': 'Copy last frame',
+    }
+    return method_name_mapping.get(method_name, method_name)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("results_dir", type=str)
+    parser.add_argument("--dataset_name", type=str)
+    parser.add_argument("--method_dirs", type=str, nargs='+', help='directories in results_dir (all of them by default)')
+    parser.add_argument("--method_names", type=str, nargs='+', help='method names for the header')
+    parser.add_argument("--web_dir", type=str, help='default is results_dir/web')
+    parser.add_argument("--plot_fname", type=str, default='metrics.pdf')
+    parser.add_argument('--usetex', '--use_tex', action='store_true')
+    parser.add_argument('--save', action='store_true')
+    parser.add_argument('--mode', choices=['paper', 'rebuttal'], default='paper')
+    parser.add_argument("--plot_metric_names", type=str, nargs='+')
+    args = parser.parse_args()
+
+    if args.save:
+        import matplotlib
+        matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!
+    import matplotlib.pyplot as plt
+
+    if args.usetex:
+        plt.rc('text', usetex=True)
+        plt.rc('text.latex', preview=True)
+        plt.rc('font', family='serif')
+
+    if args.web_dir is None:
+        args.web_dir = os.path.join(args.results_dir, 'web')
+
+    if args.method_dirs is None:
+        unsorted_method_dirs = os.listdir(args.results_dir)
+        # exclude web_dir and all directories that start with 'web'
+        if args.web_dir in unsorted_method_dirs:
+            unsorted_method_dirs.remove(args.web_dir)
+        unsorted_method_dirs = [method_dir for method_dir in unsorted_method_dirs if not os.path.basename(method_dir).startswith('web')]
+        # put ground_truth and repeat in the front (if any)
+        method_dirs = []
+        for first_method_dir in ['ground_truth', 'repeat']:
+            if first_method_dir in unsorted_method_dirs:
+                unsorted_method_dirs.remove(first_method_dir)
+                method_dirs.append(first_method_dir)
+        method_dirs.extend(sorted(unsorted_method_dirs))
+    else:
+        method_dirs = list(args.method_dirs)
+    if args.method_names is None:
+        method_names = [get_method_name(method_dir) for method_dir in method_dirs]
+    else:
+        method_names = list(args.method_names)
+    if args.usetex:
+        method_names = [method_name.replace('kl_weight', r'$\lambda_{\textsc{kl}}$') for method_name in method_names]
+    method_dirs = [os.path.join(args.results_dir, method_dir) for method_dir in method_dirs]
+
+    # infer task and metric names from first method
+    metric_fnames = sorted(glob.glob('%s/*_max/metrics/*.csv' % glob.escape(method_dirs[0])))
+    task_names = []
+    metric_names = []  # all the metric names inferred from file names
+    for metric_fname in metric_fnames:
+        head, tail = os.path.split(metric_fname)
+        task_name = head.split('/')[-2]
+        metric_name, _ = os.path.splitext(tail)
+        task_names.append(task_name)
+        metric_names.append(metric_name)
+
+    # save plots
+    dataset_name = args.dataset_name or os.path.split(os.path.normpath(args.results_dir))[1]
+    plots_dir = os.path.join(args.web_dir, 'plots')
+    if not os.path.exists(plots_dir):
+        os.makedirs(plots_dir)
+
+    if dataset_name in ('bair', 'bair_action_free'):
+        context_frames = 2
+        training_sequence_length = 12
+        plot_metric_names = ('psnr', 'ssim_finn', 'vgg_csim')
+    elif dataset_name == 'kth':
+        context_frames = 10
+        training_sequence_length = 20
+        plot_metric_names = ('psnr', 'ssim_scikit', 'vgg_csim')
+    elif dataset_name == 'ucf101':
+        context_frames = 4
+        training_sequence_length = 8
+        plot_metric_names = ('psnr', 'ssim_mcnet', 'vgg_csim')
+    else:
+        raise NotImplementedError
+    plot_metric_names = args.plot_metric_names or plot_metric_names  # metric names to plot
+
+    if args.mode == 'paper':
+        fig = plt.figure(figsize=(4 * len(plot_metric_names), 5))
+    elif args.mode == 'rebuttal':
+        fig = plt.figure(figsize=(4, 3 * len(plot_metric_names)))
+    else:
+        raise ValueError
+    i_task = 0
+    for task_name, metric_name in zip(task_names, metric_names):
+        if not task_name.endswith('max'):
+            continue
+        if metric_name not in plot_metric_names:
+            continue
+
+        if args.mode == 'paper':
+            plt.subplot(1, len(plot_metric_names), i_task + 1)
+        elif args.mode == 'rebuttal':
+            plt.subplot(len(plot_metric_names), 1, i_task + 1)
+
+        for method_name, method_dir in zip(method_names, method_dirs):
+            metric_fname = os.path.join(method_dir, task_name, 'metrics', metric_name)
+            if not os.path.isfile('%s.csv' % metric_fname):
+                print('Skipping', metric_fname)
+                continue
+            metric = load_metrics(metric_fname)
+            plot_metric(metric, context_frames + 1, color=get_color(os.path.basename(method_dir)), label=method_name)
+
+        plt.grid(axis='y')
+        plt.axvline(x=training_sequence_length, linewidth=1, color='k')
+        fontsize = 12 if args.mode == 'rebuttal' else 15
+        legend_fontsize = 10 if args.mode == 'rebuttal' else 15
+        labelsize = 10
+        if args.mode == 'paper':
+            plt.xlabel('Time Step', fontsize=fontsize)
+        plt.ylabel({
+            'psnr': 'Average PSNR',
+            'ssim': 'Average SSIM',
+            'ssim_scikit': 'Average SSIM',
+            'ssim_finn': 'Average SSIM',
+            'ssim_mcnet': 'Average SSIM',
+            'vgg_csim': 'Average VGG cosine similarity',
+        }[metric_name], fontsize=fontsize)
+        plt.xlim((context_frames + 1, metric.shape[1] + context_frames))
+        plt.tick_params(labelsize=labelsize)
+
+        if args.mode == 'paper':
+            if i_task == 1:
+                # plt.title({
+                #     'bair': 'Action-conditioned BAIR Dataset',
+                #     'bair_action_free': 'Action-free BAIR Dataset',
+                #     'kth': 'KTH Dataset',
+                # }[dataset_name], fontsize=16)
+                if len(method_names) <= 4 and sum([len(method_name) for method_name in method_names]) < 90:
+                    ncol = len(method_names)
+                else:
+                    ncol = (len(method_names) + 1) // 2
+                # ncol = 2
+                plt.legend(bbox_to_anchor=(0.5, -0.12), loc='upper center', ncol=ncol, fontsize=legend_fontsize)
+        elif args.mode == 'rebuttal':
+            if i_task == 0:
+                # plt.legend(fontsize=legend_fontsize)
+                plt.legend(bbox_to_anchor=(0.4, -0.12), loc='upper center', fontsize=legend_fontsize)
+            plt.ylim(ymin=0.8)
+            plt.xlim((context_frames + 1, metric.shape[1] + context_frames))
+        i_task += 1
+    fig.tight_layout(rect=(0, 0.1, 1, 1))
+
+    if args.save:
+        plt.show(block=False)
+        print("Saving to", os.path.join(plots_dir, args.plot_fname))
+        plt.savefig(os.path.join(plots_dir, args.plot_fname), bbox_inches='tight')
+    else:
+        plt.show()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/plot_results_all.sh b/scripts/plot_results_all.sh
new file mode 100644
index 0000000000000000000000000000000000000000..11045ca4831cea4c01ef355dd2c333418d83daeb
--- /dev/null
+++ b/scripts/plot_results_all.sh
@@ -0,0 +1,80 @@
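+# Reproduce the metric plots for all datasets and ablations. Assumes the per-method results
+# (results_test/<dataset>/<method>/*_max/metrics/*.csv) have already been generated.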
+python scripts/plot_results.py results_test/bair_action_free --method_dirs \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sv2p_time_invariant \
+    svg_lp \
+    --save --use_tex --plot_fname metrics_all.pdf
+
+python scripts/plot_results.py results_test/bair --method_dirs \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sna_l1 \
+    sna_l2 \
+    sv2p_time_variant \
+    --save --use_tex --plot_fname metrics_all.pdf
+
+python scripts/plot_results.py results_test/kth --method_dirs \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sv2p_time_variant \
+    sv2p_time_invariant \
+    svg_fp_resized_data_loader \
+    --save --use_tex --plot_fname metrics_all.pdf
+
+
+python scripts/plot_results.py results_test/bair_action_free --method_dirs \
+    sv2p_time_invariant \
+    svg_lp \
+    ours_vae_gan \
+    --save --use_tex --plot_fname metrics.pdf; \
+python scripts/plot_results.py results_test/bair_action_free --method_dirs \
+    ours_deterministic \
+    ours_vae \
+    ours_gan \
+    ours_vae_gan \
+    --save --use_tex --plot_fname metrics_ablation.pdf; \
+python scripts/plot_results.py results_test/bair_action_free --method_dirs \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    --save --use_tex --plot_fname metrics_ablation_l1_l2.pdf; \
+python scripts/plot_results.py results_test/kth --method_dirs \
+    sv2p_time_variant \
+    svg_fp_resized_data_loader \
+    ours_vae_gan \
+    --save --use_tex --plot_fname metrics.pdf; \
+python scripts/plot_results.py results_test/kth --method_dirs \
+    ours_deterministic \
+    ours_vae \
+    ours_gan \
+    ours_vae_gan \
+    --save --use_tex --plot_fname metrics_ablation.pdf; \
+python scripts/plot_results.py results_test/bair --method_dirs \
+    sv2p_time_variant \
+    ours_deterministic \
+    ours_vae \
+    ours_gan \
+    ours_vae_gan \
+    --save --use_tex --plot_fname metrics.pdf; \
+python scripts/plot_results.py results_test/bair --method_dirs \
+    sna_l1 \
+    sna_l2 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    --save --use_tex --plot_fname metrics_ablation_l1_l2.pdf
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..79f317ab1b6f30ab7708331ae4671de25b18167b
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,367 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import random
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from video_prediction import datasets, models
+
+
+def add_tag_suffix(summary, tag_suffix):
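+    # Append tag_suffix to the first component of every summary tag (e.g. "loss/total" -> "loss_1/total")
+    # so that the validation summaries appear in their own group in TensorBoard.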
+    summary_proto = tf.Summary()
+    summary_proto.ParseFromString(summary)
+    summary = summary_proto
+
+    for value in summary.value:
+        tag_split = value.tag.split('/')
+        value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:])
+    return summary.SerializeToString()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
+    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
+    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
+                                             "default is logs_dir/model_fname, where model_fname consists of "
+                                             "information from model and model_hparams")
+    parser.add_argument("--output_dir_postfix", default="")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+    parser.add_argument("--resume", action='store_true', help='resume from the latest checkpoint in output_dir.')
+
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
+    parser.add_argument("--dataset_hparams_dict", type=str, help="a json file of dataset hyperparameters")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
+    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
+
+    parser.add_argument("--summary_freq", type=int, default=1000, help="save frequency of summaries (except for image and eval summaries) for train/validation set")
+    parser.add_argument("--image_summary_freq", type=int, default=5000, help="save frequency of image summaries for train/validation set")
+    parser.add_argument("--eval_summary_freq", type=int, default=25000, help="save frequency of eval summaries for train/validation set")
+    parser.add_argument("--accum_eval_summary_freq", type=int, default=100000, help="save frequency of accumulated eval summaries for validation set only")
+    parser.add_argument("--progress_freq", type=int, default=100, help="display progress every progress_freq steps")
+    parser.add_argument("--save_freq", type=int, default=5000, help="save frequency of the model, 0 to disable")
+
+    parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", type=int)
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        tf.set_random_seed(args.seed)
+        np.random.seed(args.seed)
+        random.seed(args.seed)
+
+    if args.output_dir is None:
+        list_depth = 0
+        model_fname = ''
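+        # Build a filesystem-friendly run name from the model name and its hparams string:
+        # '=' and ',' become '.', commas inside [...] lists become '..', and the brackets themselves are dropped.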
+        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
+            if t == '[':
+                list_depth += 1
+            if t == ']':
+                list_depth -= 1
+            if list_depth and t == ',':
+                t = '..'
+            if t in '=,':
+                t = '.'
+            if t in '[]':
+                t = ''
+            model_fname += t
+        args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix
+
+    if args.resume:
+        if args.checkpoint:
+            raise ValueError('resume and checkpoint cannot both be specified')
+        args.checkpoint = args.output_dir
+
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+    if args.dataset_hparams_dict:
+        with open(args.dataset_hparams_dict) as f:
+            dataset_hparams_dict.update(json.loads(f.read()))
+    if args.model_hparams_dict:
+        with open(args.model_hparams_dict) as f:
+            model_hparams_dict.update(json.loads(f.read()))
+    if args.checkpoint:
+        checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.isdir(args.checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            args.dataset = args.dataset or options['dataset']
+            args.model = args.model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict.update(json.loads(f.read()))
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict.update(json.loads(f.read()))
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    VideoDataset = datasets.get_dataset_class(args.dataset)
+    train_dataset = VideoDataset(
+        args.input_dir,
+        mode='train',
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
+    val_dataset = VideoDataset(
+        args.val_input_dir or args.input_dir,
+        mode='val',
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
+    if val_dataset.hparams.long_sequence_length != val_dataset.hparams.sequence_length:
+        # the longer dataset is only used for the accum_eval_metrics
+        long_val_dataset = VideoDataset(
+            args.val_input_dir or args.input_dir,
+            mode='val',
+            hparams_dict=dataset_hparams_dict,
+            hparams=args.dataset_hparams)
+        long_val_dataset.set_sequence_length(val_dataset.hparams.long_sequence_length)
+    else:
+        long_val_dataset = None
+
+    variable_scope = tf.get_variable_scope()
+    variable_scope.set_use_resource(True)
+
+    VideoPredictionModel = models.get_model_class(args.model)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': train_dataset.hparams.context_frames,#Bing: TODO what is context_frames?
+        'sequence_length': train_dataset.hparams.sequence_length,#Bing: TODO what is sequence_frames
+        'repeat': train_dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams,
+        aggregate_nccl=args.aggregate_nccl)
+
+    batch_size = model.hparams.batch_size
+    train_tf_dataset = train_dataset.make_dataset(batch_size)  # Bing: adapt the meteorological data preparation here
+    train_iterator = train_tf_dataset.make_one_shot_iterator()  # Bing: for ERA5, problems in sess.run(fetches) originate here
+    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
+    # and used to feed the `handle` placeholder.
+    train_handle = train_iterator.string_handle()
+    val_tf_dataset = val_dataset.make_dataset(batch_size)
+    val_iterator = val_tf_dataset.make_one_shot_iterator()
+    val_handle = val_iterator.string_handle()
+    iterator = tf.data.Iterator.from_string_handle(
+        train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
+    inputs = iterator.get_next()
+    # Bing: debug check, print the shape of the image batches using a temporary session
+    with tf.Session() as sess:
+        for i in range(2):
+            print(sess.run(tf.shape(inputs["images"])))
+
+    # inputs comes from the training dataset by default, unless train_handle is remapped to the val_handles
+    model.build_graph(inputs)
+
+    if long_val_dataset is not None:
+        # separately build a model for the longer sequence.
+        # this is needed because the model doesn't support dynamic shapes.
+        long_hparams_dict = dict(hparams_dict)
+        long_hparams_dict['sequence_length'] = long_val_dataset.hparams.sequence_length
+        # use a smaller batch size for the longer model to prevent running out of memory
+        long_hparams_dict['batch_size'] = model.hparams.batch_size // 2
+        long_model = VideoPredictionModel(
+            mode="test",  # to not build the losses and discriminators
+            hparams_dict=long_hparams_dict,
+            hparams=args.model_hparams,
+            aggregate_nccl=args.aggregate_nccl)
+        tf.get_variable_scope().reuse_variables()
+        long_model.build_graph(long_val_dataset.make_batch(batch_size))
+    else:
+        long_model = None
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
+    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+
+    with tf.name_scope("parameter_count"):
+        # exclude trainable variables that are replicas (used in multi-gpu setting)
+        trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
+        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables])
+
+    saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)
+
+    # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero
+    if (args.summary_freq != 0 or args.image_summary_freq != 0 or
+            args.eval_summary_freq != 0 or args.accum_eval_summary_freq != 0):
+        summary_writer = tf.summary.FileWriter(args.output_dir)
+
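+    # allow_growth lets TensorFlow allocate GPU memory on demand rather than reserving it all up front.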
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
+    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+    global_step = tf.train.get_or_create_global_step()
+    max_steps = model.hparams.max_steps
+    with tf.Session(config=config) as sess:
+        print("parameter_count =", sess.run(parameter_count))
+        sess.run(tf.global_variables_initializer())
+        sess.run(tf.local_variables_initializer())
+        #coord = tf.train.Coordinator()
+        #threads = tf.train.start_queue_runners(sess = sess, coord = coord)
+        print("Variable initialization done")
+        model.restore(sess, args.checkpoint)
+        print("Checkpoint restore finished")
+        sess.run(model.post_init_ops)
+        print("Model run started")
+        val_handle_eval = sess.run(val_handle)
+        print("val handle done")
+        sess.graph.finalize()
+        print("graph finalize done")
+        start_step = sess.run(global_step)
+        print("global step done")
+
+        def should(step, freq):
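+            # freq=None means "only at the very last step"; freq=0 disables; otherwise trigger every freq steps,
+            # at the initial logging-only step (step == -1), and at the final step.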
+            if freq is None:
+                return (step + 1) == (max_steps - start_step)
+            else:
+                return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step))
+
+        def should_eval(step, freq):
+            # never run eval summaries at the beginning since it's expensive, unless it's the last iteration
+            return should(step, freq) and (step >= 0 or (step + 1) == (max_steps - start_step))
+
+        # start at one step earlier to log everything without doing any training
+        # step is relative to the start_step
+        for step in range(-1, max_steps - start_step):
+            if step == 1:
+                # skip step -1 and 0 for timing purposes (for warmstarting)
+                start_time = time.time()
+
+            fetches = {"global_step": global_step}
+            if step >= 0:
+                fetches["train_op"] = model.train_op
+            if should(step, args.progress_freq):
+                fetches['d_loss'] = model.d_loss
+                fetches['g_loss'] = model.g_loss
+                fetches['d_losses'] = model.d_losses
+                fetches['g_losses'] = model.g_losses
+                if isinstance(model.learning_rate, tf.Tensor):
+                    fetches["learning_rate"] = model.learning_rate
+            if should(step, args.summary_freq):
+                fetches["summary"] = model.summary_op
+            if should(step, args.image_summary_freq):
+                fetches["image_summary"] = model.image_summary_op
+            if should_eval(step, args.eval_summary_freq):
+                fetches["eval_summary"] = model.eval_summary_op
+
+            run_start_time = time.time()
+            results = sess.run(fetches)  # evaluate all tensors gathered in the fetches dictionary
+
+            run_elapsed_time = time.time() - run_start_time
+            if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
+                print('running train_op took too long (%0.1fs)' % run_elapsed_time)
+
+            if (should(step, args.summary_freq) or
+                    should(step, args.image_summary_freq) or
+                    should_eval(step, args.eval_summary_freq)):
+                val_fetches = {"global_step": global_step}
+                if should(step, args.summary_freq):
+                    val_fetches["summary"] = model.summary_op
+                if should(step, args.image_summary_freq):
+                    val_fetches["image_summary"] = model.image_summary_op
+                if should_eval(step, args.eval_summary_freq):
+                    val_fetches["eval_summary"] = model.eval_summary_op
+                val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
+                for name, summary in val_results.items():
+                    if name == 'global_step':
+                        continue
+                    val_results[name] = add_tag_suffix(summary, '_1')
+
+            if should(step, args.summary_freq):
+                print("recording summary")
+                summary_writer.add_summary(results["summary"], results["global_step"])
+                summary_writer.add_summary(val_results["summary"], val_results["global_step"])
+                print("done")
+            if should(step, args.image_summary_freq):
+                print("recording image summary")
+                summary_writer.add_summary(results["image_summary"], results["global_step"])
+                summary_writer.add_summary(val_results["image_summary"], val_results["global_step"])
+                print("done")
+            if should_eval(step, args.eval_summary_freq):
+                print("recording eval summary")
+                summary_writer.add_summary(results["eval_summary"], results["global_step"])
+                summary_writer.add_summary(val_results["eval_summary"], val_results["global_step"])
+                print("done")
+            if should_eval(step, args.accum_eval_summary_freq):
+                val_datasets = [val_dataset]
+                val_models = [model]
+                if long_model is not None:
+                    val_datasets.append(long_val_dataset)
+                    val_models.append(long_model)
+                for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)):
+                    sess.run(val_model.accum_eval_metrics_reset_op)
+                    # traverse (roughly up to rounding based on the batch size) all the validation dataset
+                    accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size
+                    val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op}
+                    for update_step in range(accum_eval_summary_num_updates):
+                        print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates))
+                        val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
+                    accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1))
+                    print("recording accum eval summary")
+                    summary_writer.add_summary(accum_eval_summary, val_results["global_step"])
+                    print("done")
+            if (should(step, args.summary_freq) or should(step, args.image_summary_freq) or
+                    should_eval(step, args.eval_summary_freq) or should_eval(step, args.accum_eval_summary_freq)):
+                summary_writer.flush()
+            if should(step, args.progress_freq):
+                # global_step will have the correct step count if we resume from a checkpoint
+                # global step is read before it's incremented
+                steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
+                train_epoch = results["global_step"] / steps_per_epoch
+                print("progress  global step %d  epoch %0.1f" % (results["global_step"] + 1, train_epoch))
+                if step > 0:
+                    elapsed_time = time.time() - start_time
+                    average_time = elapsed_time / step
+                    images_per_sec = batch_size / average_time
+                    remaining_time = (max_steps - (start_step + step + 1)) * average_time
+                    print("          image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
+                          (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
+
+                if results['d_losses']:
+                    print("d_loss", results["d_loss"])
+                for name, loss in results['d_losses'].items():
+                    print("  ", name, loss)
+                if results['g_losses']:
+                    print("g_loss", results["g_loss"])
+                for name, loss in results['g_losses'].items():
+                    print("  ", name, loss)
+                if isinstance(model.learning_rate, tf.Tensor):
+                    print("learning_rate", results["learning_rate"])
+
+            if should(step, args.save_freq):
+                print("saving model to", args.output_dir)
+                saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step)
+                print("done")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/train_all.sh b/scripts/train_all.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c695a8b2453956dcac54c5440c0058e5598fa03d
--- /dev/null
+++ b/scripts/train_all.sh
@@ -0,0 +1,40 @@
+#!/usr/bin/env bash
+# BAIR action-free robot pushing dataset
+for model in \
+  ours_deterministic_l1 \
+  ours_deterministic_l2 \
+  ours_vae_l1 \
+  ours_vae_l2 \
+  ours_gan \
+  ours_savp \
+; do
+  CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset bair --model savp --model_hparams_dict hparams/bair_action_free/${model}/model_hparams.json --output_dir logs/bair_action_free/${model}
+done
+
+# KTH human actions dataset
+for model in \
+  ours_deterministic_l1 \
+  ours_deterministic_l2 \
+  ours_vae_l1 \
+  ours_gan \
+  ours_savp \
+; do
+  CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/kth --dataset kth --model savp --model_hparams_dict hparams/kth/${model}/model_hparams.json --output_dir logs/kth/${model}
+done
+
+# BAIR action-conditioned robot pushing dataset
+for model in \
+  ours_deterministic_l1 \
+  ours_deterministic_l2 \
+  ours_vae_l1 \
+  ours_vae_l2 \
+  ours_gan \
+  ours_savp \
+; do
+  CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset bair --dataset_hparams use_state=True --model savp --model_hparams_dict hparams/bair/${model}/model_hparams.json --output_dir logs/bair/${model}
+done
+for model in \
+  sna_l1 \
+  sna_l2 \
+; do
+  CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset bair --dataset_hparams use_state=True --model sna --model_hparams_dict hparams/bair/${model}/model_hparams.json --output_dir logs/bair/${model}
+done
diff --git a/scripts/train_dummy.py b/scripts/train_dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f892f69c901f1eaa0a7ce2e57a3d0f6f131a7f9
--- /dev/null
+++ b/scripts/train_dummy.py
@@ -0,0 +1,274 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import random
+import time
+import numpy as np
+import tensorflow as tf
+from video_prediction import datasets, models
+
+
+def add_tag_suffix(summary, tag_suffix):
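+    """Append `tag_suffix` to the first path component of every tag in a
+    serialized summary and return the re-serialized summary."""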
+    summary_proto = tf.Summary()
+    summary_proto.ParseFromString(summary)
+    summary = summary_proto
+
+    for value in summary.value:
+        tag_split = value.tag.split('/')
+        value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:])
+    return summary.SerializeToString()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
+    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
+    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
+                                             "default is logs_dir/model_fname, where model_fname consists of "
+                                             "information from model and model_hparams")
+    parser.add_argument("--output_dir_postfix", default="")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+    parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')
+
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
+    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
+
+   # parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", type=int)
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        tf.set_random_seed(args.seed)
+        np.random.seed(args.seed)
+        random.seed(args.seed)
+
+    if args.output_dir is None:
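+        # build a filesystem-friendly run name from the model name and its
+        # hyperparameter string: '=' and ',' become dots and brackets are dropped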
+        list_depth = 0
+        model_fname = ''
+        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
+            if t == '[':
+                list_depth += 1
+            if t == ']':
+                list_depth -= 1
+            if list_depth and t == ',':
+                t = '..'
+            if t in '=,':
+                t = '.'
+            if t in '[]':
+                t = ''
+            model_fname += t
+        args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix
+
+    if args.resume:
+        if args.checkpoint:
+            raise ValueError('resume and checkpoint cannot both be specified')
+        args.checkpoint = args.output_dir
+
+
+    model_hparams_dict = {}
+    if args.model_hparams_dict:
+        with open(args.model_hparams_dict) as f:
+            model_hparams_dict.update(json.loads(f.read()))
+    if args.checkpoint:
+        checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.isdir(args.checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            args.dataset = args.dataset or options['dataset']
+            args.model = args.model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict.update(json.loads(f.read()))
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    VideoDataset = datasets.get_dataset_class(args.dataset)
+    train_dataset = VideoDataset(
+        args.input_dir,
+        mode='train')
+    val_dataset = VideoDataset(
+        args.val_input_dir or args.input_dir,
+        mode='val')
+
+    variable_scope = tf.get_variable_scope()
+    variable_scope.set_use_resource(True)
+
+    VideoPredictionModel = models.get_model_class(args.model)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': train_dataset.hparams.context_frames,
+        'sequence_length': train_dataset.hparams.sequence_length,
+        'repeat': train_dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams)
+
+    batch_size = model.hparams.batch_size
+    train_tf_dataset = train_dataset.make_dataset_v2(batch_size)
+    train_iterator = train_tf_dataset.make_one_shot_iterator()
+    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
+    # and used to feed the `handle` placeholder.
+    train_handle = train_iterator.string_handle()
+    val_tf_dataset = val_dataset.make_dataset_v2(batch_size)
+    val_iterator = val_tf_dataset.make_one_shot_iterator()
+    val_handle = val_iterator.string_handle()
+    #iterator = tf.data.Iterator.from_string_handle(
+    #    train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
+    inputs = train_iterator.get_next()
+    val = val_iterator.get_next()
+   
+    model.build_graph(inputs)
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
+    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+
+    with tf.name_scope("parameter_count"):
+        # exclude trainable variables that are replicas (used in multi-gpu setting)
+        trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
+        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables])
+
+    saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)
+
+    # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero
+    summary_writer = tf.summary.FileWriter(args.output_dir)
+
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
+    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+  
+ 
+    max_steps = model.hparams.max_steps
+    print ("max_steps",max_steps)
+    with tf.Session(config=config) as sess:
+        print("parameter_count =", sess.run(parameter_count))
+        sess.run(tf.global_variables_initializer())
+        sess.run(tf.local_variables_initializer())
+        #coord = tf.train.Coordinator()
+        #threads = tf.train.start_queue_runners(sess = sess, coord = coord)
+        print("Init done: {sess.run(tf.local_variables_initializer())}%")
+        model.restore(sess, args.checkpoint)
+
+        #sess.run(model.post_init_ops)
+
+        #val_handle_eval = sess.run(val_handle)
+        #print ("val_handle_val",val_handle_eval)
+        #print("val handle done")
+        sess.graph.finalize()
+        start_step = sess.run(model.global_step)
+
+
+        # start at one step earlier to log everything without doing any training
+        # step is relative to the start_step
+        for step in range(-1, max_steps - start_step):
+            global_step = sess.run(model.global_step)
+            print ("global_step:", global_step)
+            val_handle_eval = sess.run(val_handle)
+
+            if step == 1:
+                # skip step -1 and 0 for timing purposes (for warmstarting)
+                start_time = time.time()
+            
+            fetches = {"global_step":model.global_step}
+            fetches["train_op"] = model.train_op
+
+           # fetches["latent_loss"] = model.latent_loss
+            fetches["total_loss"] = model.total_loss
+            if model.__class__.__name__ == "McNetVideoPredictionModel":
+                fetches["L_p"] = model.L_p
+                fetches["L_gdl"] = model.L_gdl
+                fetches["L_GAN"]  =model.L_GAN
+          
+         
+
+            fetches["summary"] = model.summary_op
+
+            run_start_time = time.time()
+            #Run training results
+            #X = inputs["images"].eval(session=sess)           
+
+            results = sess.run(fetches)
+
+            run_elapsed_time = time.time() - run_start_time
+            if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
+                print('running train_op took too long (%0.1fs)' % run_elapsed_time)
+
+            #Run testing results
+            #val_fetches = {"global_step":global_step}
+            val_fetches = {}
+            #val_fetches["latent_loss"] = model.latent_loss
+            #val_fetches["total_loss"] = model.total_loss
+            val_fetches["summary"] = model.summary_op
+            val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
+
+            summary_writer.add_summary(results["summary"], results["global_step"])
+            summary_writer.add_summary(val_results["summary"], global_step)
+
+            val_datasets = [val_dataset]
+            val_models = [model]
+
+            # for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)):
+            #     sess.run(val_model.accum_eval_metrics_reset_op)
+            #     # traverse (roughly up to rounding based on the batch size) all the validation dataset
+            #     accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size
+            #     val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op}
+            #     for update_step in range(accum_eval_summary_num_updates):
+            #         print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates))
+            #         val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
+            #     accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1))
+            #     print("recording accum eval summary")
+            #     summary_writer.add_summary(accum_eval_summary, val_results["global_step"])
+            summary_writer.flush()
+
+            # global_step will have the correct step count if we resume from a checkpoint
+            # global step is read before it's incremented
+            steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
+            #train_epoch = results["global_step"] / steps_per_epoch
+            train_epoch = global_step / steps_per_epoch
+            print("progress  global step %d  epoch %0.1f" % (global_step + 1, train_epoch))
+            if step > 0:
+                elapsed_time = time.time() - start_time
+                average_time = elapsed_time / step
+                images_per_sec = batch_size / average_time
+                remaining_time = (max_steps - (start_step + step + 1)) * average_time
+                print("image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
+                      (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
+
+            # L_p / L_gdl / L_GAN are only fetched for McNetVideoPredictionModel
+            if "L_p" in results:
+                print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(
+                    results["total_loss"], results["L_p"], results["L_gdl"], results["L_GAN"]))
+            else:
+                print("Total_loss:{}".format(results["total_loss"]))
+
+            print("saving model to", args.output_dir)
+            saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)  # Bing: use the loop step here as a workaround for the global-step issue
+            print("done")
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/train_v2.py b/scripts/train_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..f83192d6af7d3953c666f40cc9e6d3766a78e92e
--- /dev/null
+++ b/scripts/train_v2.py
@@ -0,0 +1,362 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import random
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from video_prediction import datasets, models
+
+
+def add_tag_suffix(summary, tag_suffix):
+    summary_proto = tf.Summary()
+    summary_proto.ParseFromString(summary)
+    summary = summary_proto
+
+    for value in summary.value:
+        tag_split = value.tag.split('/')
+        value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:])
+    return summary.SerializeToString()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
+    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
+    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
+                                             "default is logs_dir/model_fname, where model_fname consists of "
+                                             "information from model and model_hparams")
+    parser.add_argument("--output_dir_postfix", default="")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+    parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')
+
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
+    parser.add_argument("--dataset_hparams_dict", type=str, help="a json file of dataset hyperparameters")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
+    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
+
+    parser.add_argument("--summary_freq", type=int, default=1000, help="save frequency of summaries (except for image and eval summaries) for train/validation set")
+    parser.add_argument("--image_summary_freq", type=int, default=5000, help="save frequency of image summaries for train/validation set")
+    parser.add_argument("--eval_summary_freq", type=int, default=25000, help="save frequency of eval summaries for train/validation set")
+    parser.add_argument("--accum_eval_summary_freq", type=int, default=100000, help="save frequency of accumulated eval summaries for validation set only")
+    parser.add_argument("--progress_freq", type=int, default=100, help="display progress every progress_freq steps")
+    parser.add_argument("--save_freq", type=int, default=5000, help="save frequence of model, 0 to disable")
+
+    parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", type=int)
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        tf.set_random_seed(args.seed)
+        np.random.seed(args.seed)
+        random.seed(args.seed)
+
+    if args.output_dir is None:
+        list_depth = 0
+        model_fname = ''
+        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
+            if t == '[':
+                list_depth += 1
+            if t == ']':
+                list_depth -= 1
+            if list_depth and t == ',':
+                t = '..'
+            if t in '=,':
+                t = '.'
+            if t in '[]':
+                t = ''
+            model_fname += t
+        args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix
+
+    if args.resume:
+        if args.checkpoint:
+            raise ValueError('resume and checkpoint cannot both be specified')
+        args.checkpoint = args.output_dir
+
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+    if args.dataset_hparams_dict:
+        with open(args.dataset_hparams_dict) as f:
+            dataset_hparams_dict.update(json.loads(f.read()))
+    if args.model_hparams_dict:
+        with open(args.model_hparams_dict) as f:
+            model_hparams_dict.update(json.loads(f.read()))
+    if args.checkpoint:
+        checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.isdir(args.checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            args.dataset = args.dataset or options['dataset']
+            args.model = args.model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict.update(json.loads(f.read()))
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict.update(json.loads(f.read()))
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    VideoDataset = datasets.get_dataset_class(args.dataset)
+    train_dataset = VideoDataset(
+        args.input_dir,
+        mode='train',
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
+    val_dataset = VideoDataset(
+        args.val_input_dir or args.input_dir,
+        mode='val',
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
+    if val_dataset.hparams.long_sequence_length != val_dataset.hparams.sequence_length:
+        # the longer dataset is only used for the accum_eval_metrics
+        long_val_dataset = VideoDataset(
+            args.val_input_dir or args.input_dir,
+            mode='val',
+            hparams_dict=dataset_hparams_dict,
+            hparams=args.dataset_hparams)
+        long_val_dataset.set_sequence_length(val_dataset.hparams.long_sequence_length)
+    else:
+        long_val_dataset = None
+
+    variable_scope = tf.get_variable_scope()
+    variable_scope.set_use_resource(True)
+
+    VideoPredictionModel = models.get_model_class(args.model)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': train_dataset.hparams.context_frames,
+        'sequence_length': train_dataset.hparams.sequence_length,
+        'repeat': train_dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams,
+        aggregate_nccl=args.aggregate_nccl)
+
+    batch_size = model.hparams.batch_size
+    train_tf_dataset = train_dataset.make_dataset_v2(batch_size)  # Bing: adopt the meteo data preparation here
+    train_iterator = train_tf_dataset.make_one_shot_iterator()  # Bing: for era5, the problems seen in sess.run(fetches) originate here
+    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
+    # and used to feed the `handle` placeholder.
+    train_handle = train_iterator.string_handle()
+    val_tf_dataset = val_dataset.make_dataset_v2(batch_size)
+    val_iterator = val_tf_dataset.make_one_shot_iterator()
+    val_handle = val_iterator.string_handle()
+    #iterator = tf.data.Iterator.from_string_handle(
+    #    train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
+    inputs = train_iterator.get_next()
+
+    # inputs comes from the training dataset by default, unless train_handle is remapped to the val_handles
+    model.build_graph(inputs, finetune=True)
+
+    if long_val_dataset is not None:
+        # separately build a model for the longer sequence.
+        # this is needed because the model doesn't support dynamic shapes.
+        long_hparams_dict = dict(hparams_dict)
+        long_hparams_dict['sequence_length'] = long_val_dataset.hparams.sequence_length
+        # use smaller batch size for longer model to prevent running out of memory
+        long_hparams_dict['batch_size'] = model.hparams.batch_size // 2
+        long_model = VideoPredictionModel(
+            mode="test",  # to not build the losses and discriminators
+            hparams_dict=long_hparams_dict,
+            hparams=args.model_hparams,
+            aggregate_nccl=args.aggregate_nccl)
+        tf.get_variable_scope().reuse_variables()
+        long_model.build_graph(long_val_dataset.make_batch(batch_size))
+    else:
+        long_model = None
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
+    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+
+    with tf.name_scope("parameter_count"):
+        # exclude trainable variables that are replicas (used in multi-gpu setting)
+        trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
+        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables])
+
+    saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)
+
+    # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero
+    if (args.summary_freq != 0 or args.image_summary_freq != 0 or
+            args.eval_summary_freq != 0 or args.accum_eval_summary_freq != 0):
+        summary_writer = tf.summary.FileWriter(args.output_dir)
+
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
+    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+    global_step = tf.train.get_or_create_global_step()
+    max_steps = model.hparams.max_steps
+    with tf.Session(config=config) as sess:
+        print("parameter_count =", sess.run(parameter_count))
+        sess.run(tf.global_variables_initializer())
+        sess.run(tf.local_variables_initializer())
+        #coord = tf.train.Coordinator()
+        #threads = tf.train.start_queue_runners(sess = sess, coord = coord)
+        print("Init done: {sess.run(tf.local_variables_initializer())}%")
+        model.restore(sess, args.checkpoint)
+        print("Restore processed finished")
+        sess.run(model.post_init_ops)
+        print("Model run started")
+        val_handle_eval = sess.run(val_handle)
+        print("val handle done")
+        sess.graph.finalize()
+        print("graph inalize done")
+        start_step = sess.run(global_step)
+        print("global step done")
+
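+        # freq semantics for the helpers below: None means "only at the very last
+        # step", 0 disables the event entirely, and any positive value fires every
+        # `freq` steps as well as at step -1 and at the last step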
+        def should(step, freq):
+            if freq is None:
+                return (step + 1) == (max_steps - start_step)
+            else:
+                return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step))
+
+        def should_eval(step, freq):
+            # never run eval summaries at the beginning since it's expensive, unless it's the last iteration
+            return should(step, freq) and (step >= 0 or (step + 1) == (max_steps - start_step))
+
+        # start at one step earlier to log everything without doing any training
+        # step is relative to the start_step
+        for step in range(-1, max_steps - start_step):
+            if step == 1:
+                # skip step -1 and 0 for timing purposes (for warmstarting)
+                start_time = time.time()
+
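+            # only fetch what this step actually needs (losses for progress prints,
+            # summaries at their configured frequencies) to keep sess.run calls cheap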
+            fetches = {"global_step": global_step}
+            if step >= 0:
+                fetches["train_op"] = model.train_op
+            if should(step, args.progress_freq):
+                fetches['d_loss'] = model.d_loss
+                fetches['g_loss'] = model.g_loss
+                fetches['d_losses'] = model.d_losses
+                fetches['g_losses'] = model.g_losses
+                if isinstance(model.learning_rate, tf.Tensor):
+                    fetches["learning_rate"] = model.learning_rate
+            if should(step, args.summary_freq):
+                fetches["summary"] = model.summary_op
+            if should(step, args.image_summary_freq):
+                fetches["image_summary"] = model.image_summary_op
+            if should_eval(step, args.eval_summary_freq):
+                fetches["eval_summary"] = model.eval_summary_op
+
+            run_start_time = time.time()
+            results = sess.run(fetches)  # fetch the elements in the fetches dictionary
+
+            run_elapsed_time = time.time() - run_start_time
+            if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
+                print('running train_op took too long (%0.1fs)' % run_elapsed_time)
+
+            if (should(step, args.summary_freq) or
+                    should(step, args.image_summary_freq) or
+                    should_eval(step, args.eval_summary_freq)):
+                val_fetches = {"global_step": global_step}
+                if should(step, args.summary_freq):
+                    val_fetches["summary"] = model.summary_op
+                if should(step, args.image_summary_freq):
+                    val_fetches["image_summary"] = model.image_summary_op
+                if should_eval(step, args.eval_summary_freq):
+                    val_fetches["eval_summary"] = model.eval_summary_op
+                val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
+                for name, summary in val_results.items():
+                    if name == 'global_step':
+                        continue
+                    val_results[name] = add_tag_suffix(summary, '_1')
+
+            if should(step, args.summary_freq):
+                print("recording summary")
+                summary_writer.add_summary(results["summary"], results["global_step"])
+                summary_writer.add_summary(val_results["summary"], val_results["global_step"])
+                print("done")
+            if should(step, args.image_summary_freq):
+                print("recording image summary")
+                summary_writer.add_summary(results["image_summary"], results["global_step"])
+                summary_writer.add_summary(val_results["image_summary"], val_results["global_step"])
+                print("done")
+            if should_eval(step, args.eval_summary_freq):
+                print("recording eval summary")
+                summary_writer.add_summary(results["eval_summary"], results["global_step"])
+                summary_writer.add_summary(val_results["eval_summary"], val_results["global_step"])
+                print("done")
+            if should_eval(step, args.accum_eval_summary_freq):
+                val_datasets = [val_dataset]
+                val_models = [model]
+                if long_model is not None:
+                    val_datasets.append(long_val_dataset)
+                    val_models.append(long_model)
+                for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)):
+                    sess.run(val_model.accum_eval_metrics_reset_op)
+                    # traverse (roughly up to rounding based on the batch size) all the validation dataset
+                    accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size
+                    val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op}
+                    for update_step in range(accum_eval_summary_num_updates):
+                        print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates))
+                        val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
+                    accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1))
+                    print("recording accum eval summary")
+                    summary_writer.add_summary(accum_eval_summary, val_results["global_step"])
+                    print("done")
+            if (should(step, args.summary_freq) or should(step, args.image_summary_freq) or
+                    should_eval(step, args.eval_summary_freq) or should_eval(step, args.accum_eval_summary_freq)):
+                summary_writer.flush()
+            if should(step, args.progress_freq):
+                # global_step will have the correct step count if we resume from a checkpoint
+                # global step is read before it's incremented
+                steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
+                train_epoch = results["global_step"] / steps_per_epoch
+                print("progress  global step %d  epoch %0.1f" % (results["global_step"] + 1, train_epoch))
+                if step > 0:
+                    elapsed_time = time.time() - start_time
+                    average_time = elapsed_time / step
+                    images_per_sec = batch_size / average_time
+                    remaining_time = (max_steps - (start_step + step + 1)) * average_time
+                    print("          image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
+                          (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
+
+                if results['d_losses']:
+                    print("d_loss", results["d_loss"])
+                for name, loss in results['d_losses'].items():
+                    print("  ", name, loss)
+                if results['g_losses']:
+                    print("g_loss", results["g_loss"])
+                for name, loss in results['g_losses'].items():
+                    print("  ", name, loss)
+                if isinstance(model.learning_rate, tf.Tensor):
+                    print("learning_rate", results["learning_rate"])
+
+            if should(step, args.save_freq):
+                print("saving model to", args.output_dir)
+                saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step)
+                print("done")
+
+if __name__ == '__main__':
+    main()
diff --git a/video_prediction/.DS_Store b/video_prediction/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..c43ee62a22d53cc75b181ee2096e26474c200495
Binary files /dev/null and b/video_prediction/.DS_Store differ
diff --git a/video_prediction/__init__.py b/video_prediction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3089b251bbbcbe086a03a4ea63571ce9427b2cb9
--- /dev/null
+++ b/video_prediction/__init__.py
@@ -0,0 +1,3 @@
+from . import losses
+from . import metrics
+from . import ops
diff --git a/video_prediction/datasets/__init__.py b/video_prediction/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..736b8202172051f36586db5579c545863c72e14d
--- /dev/null
+++ b/video_prediction/datasets/__init__.py
@@ -0,0 +1,29 @@
+from .base_dataset import BaseVideoDataset
+from .base_dataset import VideoDataset, SequenceExampleVideoDataset, VarLenFeatureVideoDataset
+from .google_robot_dataset import GoogleRobotVideoDataset
+from .sv2p_dataset import SV2PVideoDataset
+from .softmotion_dataset import SoftmotionVideoDataset
+from .kth_dataset import KTHVideoDataset
+from .ucf101_dataset import UCF101VideoDataset
+from .cartgripper_dataset import CartgripperVideoDataset
+from .era5_dataset_v2 import ERA5Dataset_v2
+from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly
+
+def get_dataset_class(dataset):
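+    """Map a short dataset name (e.g. 'bair', 'kth', 'era5') to its dataset class.
+    Names not present in the mapping are interpreted as class names directly."""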
+    dataset_mappings = {
+        'google_robot': 'GoogleRobotVideoDataset',
+        'sv2p': 'SV2PVideoDataset',
+        'softmotion': 'SoftmotionVideoDataset',
+        'bair': 'SoftmotionVideoDataset',  # alias of softmotion
+        'kth': 'KTHVideoDataset',
+        'ucf101': 'UCF101VideoDataset',
+        'cartgripper': 'CartgripperVideoDataset',
+        "era5":"ERA5Dataset_v2",
+        "era5_anomaly":"ERA5Dataset_v2_anomaly",
+    }
+    dataset_class = dataset_mappings.get(dataset, dataset)
+    print("datset_class",dataset_class)
+    dataset_class = globals().get(dataset_class)
+    if dataset_class is None or not issubclass(dataset_class, BaseVideoDataset):
+        raise ValueError('Invalid dataset %s' % dataset)
+    return dataset_class
diff --git a/video_prediction/datasets/base_dataset.py b/video_prediction/datasets/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fb6b00550e178ce845f585f9f07a255220c26da
--- /dev/null
+++ b/video_prediction/datasets/base_dataset.py
@@ -0,0 +1,521 @@
+import glob
+import os
+import random
+import re
+from collections import OrderedDict
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.training import HParams
+
+
+class BaseVideoDataset(object):
+    def __init__(self, input_dir, mode='train', num_epochs=None, seed=None,
+                 hparams_dict=None, hparams=None):
+        """
+        Args:
+            input_dir: either a directory containing subdirectories train,
+                val, test, etc, or a directory containing the tfrecords.
+            mode: either train, val, or test
+            num_epochs: if None, dataset is iterated indefinitely.
+            seed: random seed for the op that samples subsequences.
+            hparams_dict: a dict of `name=value` pairs, where `name` must be
+                defined in `self.get_default_hparams()`.
+            hparams: a string of comma separated list of `name=value` pairs,
+                where `name` must be defined in `self.get_default_hparams()`.
+                These values override any values in hparams_dict (if any).
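+                For example, hparams="sequence_length=20,time_shift=2" overrides
+                just those two entries and keeps the remaining defaults.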
+        Note:
+            self.input_dir is the directory containing the tfrecords.
+        """
+
+        self.input_dir = os.path.normpath(os.path.expanduser(input_dir))
+        self.mode = mode
+        self.num_epochs = num_epochs
+        self.seed = seed
+
+        if self.mode not in ('train', 'val', 'test'):
+            raise ValueError('Invalid mode %s' % self.mode)
+
+        if not os.path.exists(self.input_dir):
+            raise FileNotFoundError("input_dir %s does not exist" % self.input_dir)
+        self.filenames = None
+        # look for tfrecords in input_dir and input_dir/mode directories
+        for input_dir in [self.input_dir, os.path.join(self.input_dir, self.mode)]:
+            filenames = glob.glob(os.path.join(input_dir, '*.tfrecord*'))
+            if filenames:
+                self.input_dir = input_dir
+                self.filenames = sorted(filenames)  # ensures order is the same across systems
+                break
+        if not self.filenames:
+            raise FileNotFoundError('No tfrecords were found in %s.' % self.input_dir)
+        self.dataset_name = os.path.basename(os.path.split(self.input_dir)[0])
+
+        self.state_like_names_and_shapes = OrderedDict()
+        self.action_like_names_and_shapes = OrderedDict()
+
+        self.hparams = self.parse_hparams(hparams_dict, hparams)
+        #Bing: add this for anomaly
+        if os.path.exists(input_dir+"_mean"):
+            input_mean_dir = input_dir+"_mean"
+            self.filenames_mean = sorted(glob.glob(os.path.join(input_mean_dir, '*.tfrecord*')))
+        else:
+            self.filenames_mean = None
+
+
+    def get_default_hparams_dict(self):
+        """
+        Returns:
+            A dict with the following hyperparameters.
+
+            crop_size: crop image into a square with sides of this length.
+            scale_size: resize image to this size after it has been cropped.
+            context_frames: the number of ground-truth frames to pass in at
+                start.
+            sequence_length: the number of frames in the video sequence, so
+                state-like sequences are of length sequence_length and
+                action-like sequences are of length sequence_length - 1.
+                This number includes the context frames.
+            long_sequence_length: the number of frames for the long version.
+                The default is the same as sequence_length.
+            frame_skip: number of frames to skip in between outputted frames,
+                so frame_skip=0 denotes no skipping.
+            time_shift: shift in time by multiples of this, so time_shift=1
+                denotes all possible shifts. time_shift=0 denotes no shifting.
+                It is ignored (equiv. to time_shift=0) when mode != 'train'.
+            force_time_shift: whether to do the shift in time regardless of
+                mode.
+            shuffle_on_val: whether to shuffle the samples regardless if mode
+                is 'train' or 'val'. Shuffle never happens when mode is 'test'.
+            use_state: whether to load and return state and actions.
+        """
+        hparams = dict(
+            crop_size=0,
+            scale_size=0,
+            context_frames=1,
+            sequence_length=0,
+            long_sequence_length=0,
+            frame_skip=0,
+            time_shift=1,
+            force_time_shift=False,
+            shuffle_on_val=False,
+            use_state=False,
+        )
+        return hparams
+
+    def get_default_hparams(self):
+        return HParams(**self.get_default_hparams_dict())
+
+    def parse_hparams(self, hparams_dict, hparams):
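+        """Merge hyperparameters with increasing precedence: the defaults from
+        get_default_hparams(), then hparams_dict, then the comma-separated
+        hparams string(s)."""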
+        parsed_hparams = self.get_default_hparams().override_from_dict(hparams_dict or {})
+        if hparams:
+            if not isinstance(hparams, (list, tuple)):
+                hparams = [hparams]
+            for hparam in hparams:
+                parsed_hparams.parse(hparam)
+        if parsed_hparams.long_sequence_length == 0:
+            parsed_hparams.long_sequence_length = parsed_hparams.sequence_length
+        return parsed_hparams
+
+    @property
+    def jpeg_encoding(self):
+        raise NotImplementedError
+
+    def set_sequence_length(self, sequence_length):
+        self.hparams.sequence_length = sequence_length
+
+    def filter(self, serialized_example):
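+        # default filter keeps every example; subclasses (e.g. VarLenFeatureVideoDataset)
+        # override this to drop examples that are too short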
+        return tf.convert_to_tensor(True)
+
+    def parser(self, serialized_example):
+        """
+        Parses a single tf.train.Example or tf.train.SequenceExample into
+        images, states, actions, etc tensors.
+        """
+
+
+        raise NotImplementedError
+
+    def make_dataset(self, batch_size):
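+        """Build a tf.data pipeline over the tfrecord files: optional shuffling and
+        repetition, parsing into sequence dicts, batching and prefetching."""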
+        filenames = self.filenames
+        shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
+        if shuffle:
+            random.shuffle(filenames)
+
+        dataset = tf.data.TFRecordDataset(filenames, buffer_size=8 * 1024 * 1024)  # 8 MiB read buffer for the TFRecord reader
+        dataset = dataset.filter(self.filter)
+        if shuffle:
+            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1024, count=self.num_epochs))
+        else:
+            dataset = dataset.repeat(self.num_epochs)
+
+        def _parser(serialized_example):
+            state_like_seqs, action_like_seqs = self.parser(serialized_example)
+            seqs = OrderedDict(list(state_like_seqs.items()) + list(action_like_seqs.items()))
+            return seqs
+
+        num_parallel_calls = None if shuffle else 1  # for reproducibility (e.g. sampled subclips from the test set)
+        dataset = dataset.apply(tf.contrib.data.map_and_batch(
+            _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))  # Bing: parallel data mapping; num_parallel_calls depends on the hardware and should usually match the number of usable CPUs
+        dataset = dataset.prefetch(batch_size)  # Bing: prefetch into a buffer so the GPU does not have to wait for data
+        return dataset
+
+    def make_batch(self, batch_size):
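+        # return the get_next() tensors of a one-shot iterator over make_dataset()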
+        dataset = self.make_dataset(batch_size)
+        iterator = dataset.make_one_shot_iterator()
+        return iterator.get_next()
+
+
+    def decode_and_preprocess_images(self, image_buffers, image_shape):
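+        """Decode each encoded image buffer (JPEG or raw bytes), optionally crop and
+        resize it, and return the images converted to float32."""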
+        def decode_and_preprocess_image(image_buffer):
+            print("image buffer", tf.shape(image_buffer))
+
+            image_buffer = tf.reshape(image_buffer, [], name="reshape_1")
+
+            if self.jpeg_encoding:
+                image = tf.image.decode_jpeg(image_buffer)
+                print("14********image decode_jpeg********", image)
+            else:
+                image = tf.decode_raw(image_buffer, tf.uint8)
+                print("15 ********image decode_raw********", tf.shape(image))
+            print("16 ******** image shape", image_shape)
+
+            image = tf.reshape(image, image_shape, name="reshape_4") ##Bing:the bug #issue 1 is here
+            crop_size = self.hparams.crop_size
+            scale_size = self.hparams.scale_size
+            if crop_size or scale_size:
+                if not crop_size:
+                    crop_size = min(image_shape[0], image_shape[1])
+                image = tf.image.resize_image_with_crop_or_pad(image, crop_size, crop_size)
+                image = tf.reshape(image, [crop_size, crop_size, 3], name="reshape_3")
+                if scale_size:
+                    # upsample with bilinear interpolation but downsample with area interpolation
+                    if crop_size < scale_size:
+                        image = tf.image.resize_images(image, [scale_size, scale_size],
+                                                       method=tf.image.ResizeMethod.BILINEAR)
+                    elif crop_size > scale_size:
+                        image = tf.image.resize_images(image, [scale_size, scale_size],
+                                                       method=tf.image.ResizeMethod.AREA)
+                    else:
+                        # image remains unchanged
+                        pass
+            return image
+
+        if not isinstance(image_buffers, (list, tuple)):
+            image_buffers = tf.unstack(image_buffers)
+        print("17 **************image buffer", image_buffers[0])
+        images = [decode_and_preprocess_image(image_buffer) for image_buffer in image_buffers]
+        images = tf.image.convert_image_dtype(images, dtype=tf.float32)
+        return images
+
+    def slice_sequences(self, state_like_seqs, action_like_seqs, example_sequence_length):
+        """
+        Slices sequences of length `example_sequence_length` into subsequences
+        of length `sequence_length`. The dicts of sequences are updated
+        in-place and the same dicts are returned.
+        """
+        # handle random shifting and frame skip
+        sequence_length = self.hparams.sequence_length  # desired sequence length
+        frame_skip = self.hparams.frame_skip
+        time_shift = self.hparams.time_shift
+        print("22***********example sequence_length",example_sequence_length)
+        if (time_shift and self.mode == 'train') or self.hparams.force_time_shift:
+            print("23***********I am here")
+            assert time_shift > 0 and isinstance(time_shift, int)
+            if isinstance(example_sequence_length, tf.Tensor):
+                example_sequence_length = tf.cast(example_sequence_length, tf.int32)
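+            # number of valid start offsets (in multiples of time_shift) that still
+            # leave room for a subsequence spanning (sequence_length - 1) * (frame_skip + 1) + 1 frames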
+            num_shifts = ((example_sequence_length - 1) - (sequence_length - 1) * (frame_skip + 1)) // time_shift
+            assert_message = ('example_sequence_length has to be at least %d when '
+                              'sequence_length=%d, frame_skip=%d.' %
+                              ((sequence_length - 1) * (frame_skip + 1) + 1,
+                               sequence_length, frame_skip))
+            with tf.control_dependencies([tf.assert_greater_equal(num_shifts, 0,
+                    data=[example_sequence_length, num_shifts], message=assert_message)]):
+                t_start = tf.random_uniform([], 0, num_shifts + 1, dtype=tf.int32, seed=self.seed) * time_shift
+        else:
+            t_start = 0
+        print("20:**********************sequence_len: {}, t_start:{}, frame_skip:{}".format(sequence_length,tf.shape(t_start),frame_skip))
+        state_like_t_slice = slice(t_start, t_start + (sequence_length - 1) * (frame_skip + 1) + 1, frame_skip + 1)
+        action_like_t_slice = slice(t_start, t_start + (sequence_length - 1) * (frame_skip + 1))
+
+        for example_name, seq in state_like_seqs.items():
+            print("21*****************seq*******",seq)
+            seq = tf.convert_to_tensor(seq)[state_like_t_slice]
+            print("25**************ses.shape", [self.hparams.sequence_length] + seq.shape.as_list()[1:])
+            seq.set_shape([sequence_length] + seq.shape.as_list()[1:])
+            state_like_seqs[example_name] = seq
+        for example_name, seq in action_like_seqs.items():
+            seq = tf.convert_to_tensor(seq)[action_like_t_slice]
+            seq.set_shape([(sequence_length - 1) * (frame_skip + 1)] + seq.shape.as_list()[1:])
+            # concatenate actions of skipped frames into single macro actions
+            seq = tf.reshape(seq, [sequence_length - 1, -1])
+            action_like_seqs[example_name] = seq
+        return state_like_seqs, action_like_seqs
+
+    def num_examples_per_epoch(self):
+        raise NotImplementedError
+
+
+
+class VideoDataset(BaseVideoDataset):
+    """
+    This class supports reading tfrecords where a sequence is stored as
+    multiple tf.train.Example and each of them is stored under a different
+    feature name (which is indexed by the time step).
+    """
+    def __init__(self, *args, **kwargs):
+        super(VideoDataset, self).__init__(*args, **kwargs)
+        self._max_sequence_length = None
+        self._dict_message = None
+
+    def _check_or_infer_shapes(self):
+        """
+        Should be called after state_like_names_and_shapes and
+        action_like_names_and_shapes have been finalized.
+        """
+        state_like_names_and_shapes = OrderedDict([(k, list(v)) for k, v in self.state_like_names_and_shapes.items()])
+        action_like_names_and_shapes = OrderedDict([(k, list(v)) for k, v in self.action_like_names_and_shapes.items()])
+        from google.protobuf.json_format import MessageToDict
+        example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
+        self._dict_message = MessageToDict(tf.train.Example.FromString(example))
+        for example_name, name_and_shape in (list(state_like_names_and_shapes.items()) +
+                                             list(action_like_names_and_shapes.items())):
+            name, shape = name_and_shape
+            feature = self._dict_message['features']['feature']
+            names = [name_ for name_ in feature.keys() if re.search(name.replace('%d', r'\d+'), name_) is not None]
+            if not names:
+                raise ValueError('Could not find any feature with name pattern %s.' % name)
+            if example_name in self.state_like_names_and_shapes:
+                sequence_length = len(names)
+            else:
+                sequence_length = len(names) + 1
+            if self._max_sequence_length is None:
+                self._max_sequence_length = sequence_length
+            else:
+                self._max_sequence_length = min(sequence_length, self._max_sequence_length)
+            name = names[0]
+            feature = feature[name]
+            list_type, = feature.keys()
+            if list_type == 'floatList':
+                inferred_shape = (len(feature[list_type]['value']),)
+                if shape is None:
+                    name_and_shape[1] = inferred_shape
+                else:
+                    if inferred_shape != shape:
+                        raise ValueError('Inferred shape for feature %s is %r but instead got shape %r.' %
+                                         (name, inferred_shape, shape))
+            elif list_type == 'bytesList':
+                image_str, = feature[list_type]['value']
+                # try to infer image shape
+                inferred_shape = None
+                if not self.jpeg_encoding:
+                    spatial_size = len(image_str) // 4
+                    height = width = int(np.sqrt(spatial_size))  # assume square image
+                    if len(image_str) == (height * width * 4):
+                        inferred_shape = (height, width, 3)
+                if shape is None:
+                    if inferred_shape is not None:
+                        name_and_shape[1] = inferred_shape
+                    else:
+                        raise ValueError('Unable to infer shape for feature %s of size %d.' % (name, len(image_str)))
+                else:
+                    if inferred_shape is not None and inferred_shape != shape:
+                        raise ValueError('Inferred shape for feature %s is %r but instead got shape %r.' %
+                                         (name, inferred_shape, shape))
+            else:
+                raise NotImplementedError
+        self.state_like_names_and_shapes = OrderedDict([(k, tuple(v)) for k, v in state_like_names_and_shapes.items()])
+        self.action_like_names_and_shapes = OrderedDict([(k, tuple(v)) for k, v in action_like_names_and_shapes.items()])
+
+        # set sequence_length to the longest possible if it is not specified
+        if not self.hparams.sequence_length:
+            self.hparams.sequence_length = (self._max_sequence_length - 1) // (self.hparams.frame_skip + 1) + 1
+
+    def set_sequence_length(self, sequence_length):
+        if not sequence_length:
+            sequence_length = (self._max_sequence_length - 1) // (self.hparams.frame_skip + 1) + 1
+        self.hparams.sequence_length = sequence_length
+
+    def parser(self, serialized_example):
+        """
+        Parses a single tf.train.Example into images, states, actions, etc tensors.
+        """
+        features = dict()
+        for i in range(self._max_sequence_length):
+            for example_name, (name, shape) in self.state_like_names_and_shapes.items():
+                if example_name == 'images':  # special handling for image
+                    features[name % i] = tf.FixedLenFeature([1], tf.string)
+                else:
+                    features[name % i] = tf.FixedLenFeature(shape, tf.float32)
+        for i in range(self._max_sequence_length - 1):
+            for example_name, (name, shape) in self.action_like_names_and_shapes.items():
+                features[name % i] = tf.FixedLenFeature(shape, tf.float32)
+
+        # check that the features are in the tfrecord
+        for name in features.keys():
+            if name not in self._dict_message['features']['feature']:
+                raise ValueError('Feature with name %s not found in tfrecord. Possible feature names are:\n%s' %
+                                 (name, '\n'.join(sorted(self._dict_message['features']['feature'].keys()))))
+
+        # parse all the features of all time steps together
+        features = tf.parse_single_example(serialized_example, features=features)
+
+
+        state_like_seqs = OrderedDict([(example_name, []) for example_name in self.state_like_names_and_shapes])
+        action_like_seqs = OrderedDict([(example_name, []) for example_name in self.action_like_names_and_shapes])
+        for i in range(self._max_sequence_length):
+            for example_name, (name, shape) in self.state_like_names_and_shapes.items():
+                state_like_seqs[example_name].append(features[name % i])
+        for i in range(self._max_sequence_length - 1):
+            for example_name, (name, shape) in self.action_like_names_and_shapes.items():
+                action_like_seqs[example_name].append(features[name % i])
+
+        # for this class, it's much faster to decode and preprocess the entire sequence before sampling a slice
+        _, image_shape = self.state_like_names_and_shapes['images']
+        state_like_seqs['images'] = self.decode_and_preprocess_images(state_like_seqs['images'], image_shape)
+
+        state_like_seqs, action_like_seqs = \
+            self.slice_sequences(state_like_seqs, action_like_seqs, self._max_sequence_length)
+        return state_like_seqs, action_like_seqs
+
+
+class SequenceExampleVideoDataset(BaseVideoDataset):
+    """
+    This class supports reading tfrecords where an entire sequence is stored as
+    a single tf.train.SequenceExample.
+    """
+    def parser(self, serialized_example):
+        """
+        Parses a single tf.train.SequenceExample into images, states, actions, etc tensors.
+        """
+        sequence_features = dict()
+        for example_name, (name, shape) in self.state_like_names_and_shapes.items():
+            if example_name == 'images':  # special handling for image
+                sequence_features[name] = tf.FixedLenSequenceFeature([1], tf.string)
+            else:
+                sequence_features[name] = tf.FixedLenSequenceFeature(shape, tf.float32)
+        for example_name, (name, shape) in self.action_like_names_and_shapes.items():
+            sequence_features[name] = tf.FixedLenSequenceFeature(shape, tf.float32)
+
+        _, sequence_features = tf.parse_single_sequence_example(
+            serialized_example, sequence_features=sequence_features)
+
+        state_like_seqs = OrderedDict()
+        action_like_seqs = OrderedDict()
+        for example_name, (name, shape) in self.state_like_names_and_shapes.items():
+            state_like_seqs[example_name] = sequence_features[name]
+        for example_name, (name, shape) in self.action_like_names_and_shapes.items():
+            action_like_seqs[example_name] = sequence_features[name]
+
+        # the sequence_length of this example is determined by the shortest sequence
+        example_sequence_length = []
+        for example_name, seq in state_like_seqs.items():
+            example_sequence_length.append(tf.shape(seq)[0])
+        for example_name, seq in action_like_seqs.items():
+            example_sequence_length.append(tf.shape(seq)[0] + 1)
+        example_sequence_length = tf.reduce_min(example_sequence_length)
+        state_like_seqs, action_like_seqs = \
+            self.slice_sequences(state_like_seqs, action_like_seqs, example_sequence_length)
+
+        # decode and preprocess images on the sampled slice only
+        _, image_shape = self.state_like_names_and_shapes['images']
+        state_like_seqs['images'] = self.decode_and_preprocess_images(state_like_seqs['images'], image_shape)
+        return state_like_seqs, action_like_seqs
+
+
+class VarLenFeatureVideoDataset(BaseVideoDataset):
+    """
+    This class supports reading tfrecords where an entire sequence is stored as
+    a single tf.train.Example.
+
+    https://github.com/tensorflow/tensorflow/issues/15977
+    """
+    def filter(self, serialized_example):
+        features = dict()
+        features['sequence_length'] = tf.FixedLenFeature((), tf.int64)
+        features = tf.parse_single_example(serialized_example, features=features)
+        example_sequence_length = features['sequence_length']
+        return tf.greater_equal(example_sequence_length, self.hparams.sequence_length)
+
+    def parser(self, serialized_example):
+        """
+        Parses a single tf.train.SequenceExample into images, states, actions, etc tensors.
+        """
+        print("1.***parser function from class VarLenFeatureVideoDatase")
+        features = dict()
+        features['sequence_length'] = tf.FixedLenFeature((), tf.int64)
+        for example_name, (name, shape) in self.state_like_names_and_shapes.items():
+            if example_name == 'images':
+                # images are stored as a variable-length list of encoded frame strings
+                features[name] = tf.VarLenFeature(tf.string)
+            else:
+                features[name] = tf.VarLenFeature(tf.float32)
+        for example_name, (name, shape) in self.action_like_names_and_shapes.items():
+            features[name] = tf.VarLenFeature(tf.float32)
+
+        features = tf.parse_single_example(serialized_example, features=features)
+        example_sequence_length = features['sequence_length']
+
+        state_like_seqs = OrderedDict()
+        action_like_seqs = OrderedDict()
+        for example_name, (name, shape) in self.state_like_names_and_shapes.items():
+            if example_name == 'images':
+                seq = tf.sparse_tensor_to_dense(features[name], '')
+            else:
+                seq = tf.sparse_tensor_to_dense(features[name])
+                seq = tf.reshape(seq, [example_sequence_length] + list(shape))
+
+            state_like_seqs[example_name] = seq
+
+
+        for example_name, (name, shape) in self.action_like_names_and_shapes.items():
+
+            seq = tf.sparse_tensor_to_dense(features[name])
+            seq = tf.reshape(seq, [example_sequence_length - 1] + list(shape))
+            action_like_seqs[example_name] = seq
+
+        # sample a slice of length hparams.sequence_length from the full sequences
+        state_like_seqs, action_like_seqs = \
+            self.slice_sequences(state_like_seqs, action_like_seqs, example_sequence_length)
+        # decode and preprocess images on the sampled slice only
+        _, image_shape = self.state_like_names_and_shapes['images']
+
+        state_like_seqs['images'] = self.decode_and_preprocess_images(state_like_seqs['images'], image_shape)
+        return state_like_seqs, action_like_seqs
+
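+# Hedged sketch (not used by the pipeline): illustrates the VarLenFeature round trip that
+# VarLenFeatureVideoDataset.parser relies on, i.e. a whole variable-length sequence stored in a
+# single tf.train.Example. The feature names and sizes below are made up for the example; the
+# real names come from state_like_names_and_shapes.
+def _varlen_feature_roundtrip_sketch(num_frames=5, frame_size=4):
+    # serialize one toy sequence of flattened float frames into a single tf.train.Example
+    flat_values = [float(i) for i in range(num_frames * frame_size)]
+    example = tf.train.Example(features=tf.train.Features(feature={
+        'sequence_length': tf.train.Feature(int64_list=tf.train.Int64List(value=[num_frames])),
+        'frames': tf.train.Feature(float_list=tf.train.FloatList(value=flat_values)),
+    }))
+    serialized = example.SerializeToString()
+    # parse it back: the variable-length float list is returned as a SparseTensor, which is
+    # densified and reshaped with the stored sequence_length, as in parser() above
+    parsed = tf.parse_single_example(serialized, {
+        'sequence_length': tf.FixedLenFeature((), tf.int64),
+        'frames': tf.VarLenFeature(tf.float32),
+    })
+    seq = tf.sparse_tensor_to_dense(parsed['frames'])
+    seq = tf.reshape(seq, [parsed['sequence_length'], frame_size])
+    return seq
+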
+
+if __name__ == '__main__':
+    import cv2
+    from video_prediction import datasets
+
+    datasets = [
+        datasets.SV2PVideoDataset('data/shape', mode='val'),
+        datasets.SV2PVideoDataset('data/humans', mode='val'),
+        datasets.SoftmotionVideoDataset('data/bair', mode='val'),
+        datasets.KTHVideoDataset('data/kth', mode='val'),
+        datasets.KTHVideoDataset('data/kth_128', mode='val'),
+        datasets.UCF101VideoDataset('data/ucf101', mode='val'),
+    ]
+    batch_size = 4
+
+    sess = tf.Session()
+
+    for dataset in datasets:
+        inputs = dataset.make_batch(batch_size)
+        images = inputs['images']
+        images = tf.reshape(images, [-1] + images.get_shape().as_list()[2:])
+        images = sess.run(images)
+        images = (images * 255).astype(np.uint8)
+        for image in images:
+            if image.shape[-1] == 1:
+                image = np.tile(image, [1, 1, 3])
+            else:
+                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+            cv2.imshow(dataset.input_dir, image)
+            cv2.waitKey(50)
diff --git a/video_prediction/datasets/cartgripper_dataset.py b/video_prediction/datasets/cartgripper_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..68275cb75945c9327290d3759e4e9912d5f5c4a3
--- /dev/null
+++ b/video_prediction/datasets/cartgripper_dataset.py
@@ -0,0 +1,24 @@
+import itertools
+
+from .base_dataset import VideoDataset
+from .softmotion_dataset import SoftmotionVideoDataset
+
+
+class CartgripperVideoDataset(SoftmotionVideoDataset):
+    def __init__(self, *args, **kwargs):
+        VideoDataset.__init__(self, *args, **kwargs)
+        self.state_like_names_and_shapes['images'] = '%d/image_view0/encoded', (48, 64, 3)
+        if self.hparams.use_state:
+            self.state_like_names_and_shapes['states'] = '%d/endeffector_pos', (6,)
+            self.action_like_names_and_shapes['actions'] = '%d/action', (3,)
+        self._check_or_infer_shapes()
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(CartgripperVideoDataset, self).get_default_hparams_dict()
+        hparams = dict(
+            context_frames=2,
+            sequence_length=15,
+            time_shift=3,
+            use_state=True,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
diff --git a/video_prediction/datasets/era5_dataset_v2.py b/video_prediction/datasets/era5_dataset_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e32e0c638b4b39c588389e906ba29be5144ee35
--- /dev/null
+++ b/video_prediction/datasets/era5_dataset_v2.py
@@ -0,0 +1,331 @@
+import argparse
+import glob
+import itertools
+import os
+import pickle
+import random
+import re
+import hickle as hkl
+import numpy as np
+import json
+import tensorflow as tf
+from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset
+# ML 2020/04/14: hack for getting functions of process_netCDF_v2:
+from os import path
+import sys
+sys.path.append(path.abspath('../../workflow_parallel_frame_prediction/'))
+import DataPreprocess.process_netCDF_v2 
+from DataPreprocess.process_netCDF_v2 import get_unique_vars
+from DataPreprocess.process_netCDF_v2 import Calc_data_stat
+#from base_dataset import VarLenFeatureVideoDataset
+from collections import OrderedDict
+from tensorflow.contrib.training import HParams
+
+class ERA5Dataset_v2(VarLenFeatureVideoDataset):
+    def __init__(self, *args, **kwargs):
+        super(ERA5Dataset_v2, self).__init__(*args, **kwargs)
+        from google.protobuf.json_format import MessageToDict
+        example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
+        dict_message = MessageToDict(tf.train.Example.FromString(example))
+        feature = dict_message['features']['feature']
+        self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
+        self.image_shape = self.video_shape[1:]
+        self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(ERA5Dataset_v2, self).get_default_hparams_dict()
+        hparams = dict(
+            context_frames=10,
+            sequence_length=20,
+            long_sequence_length=20,
+            force_time_shift=True,
+            shuffle_on_val=True, 
+            use_state=False,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+
+    @property
+    def jpeg_encoding(self):
+        return False
+
+
+
+    def num_examples_per_epoch(self):
+        with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
+            sequence_lengths = sequence_lengths_file.readlines()
+        sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
+        return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
+
+
+    def filter(self, serialized_example):
+        return tf.convert_to_tensor(True)
+
+
+    def make_dataset_v2(self, batch_size):
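+        """
+        Build a tf.data pipeline over the ERA5 TFRecord files: parse one sequence per example,
+        shuffle and repeat when in training mode (or on validation if shuffle_on_val is set),
+        batch with drop_remainder=True and prefetch. Yields dicts with the key 'images' of
+        shape [batch_size, sequence_length, height, width, channels].
+        """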
+        def parser(serialized_example):
+            seqs = OrderedDict()
+            keys_to_features = {
+                'width': tf.FixedLenFeature([], tf.int64),
+                'height': tf.FixedLenFeature([], tf.int64),
+                'sequence_length': tf.FixedLenFeature([], tf.int64),
+                'channels': tf.FixedLenFeature([],tf.int64),
+                # 'images/encoded':  tf.FixedLenFeature([], tf.string)
+                'images/encoded': tf.VarLenFeature(tf.float32)
+            }
+
+            parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
+            # 'images/encoded' is stored as a variable-length float list, so it comes back as a
+            # SparseTensor and has to be densified before reshaping
+            seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
+            print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]))
+            images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new")
+            seqs["images"] = images
+            return seqs
+        filenames = self.filenames
+        print ("FILENAMES",filenames)
+	#TODO:
+	#temporal_filenames = self.temporal_filenames
+        shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
+        if shuffle:
+            random.shuffle(filenames)
+        dataset = tf.data.TFRecordDataset(filenames, buffer_size=8 * 1024 * 1024)  # read buffer size in bytes
+        print("files", self.filenames)
+        print("mode", self.mode)
+        dataset = dataset.filter(self.filter)
+        if shuffle:
+            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs))
+        else:
+            dataset = dataset.repeat(self.num_epochs)
+
+        # use a single parser call when not shuffling, for reproducibility
+        # (e.g. sampled subclips from the test set)
+        num_parallel_calls = None if shuffle else 1
+        dataset = dataset.apply(tf.contrib.data.map_and_batch(
+            parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))
+        dataset = dataset.prefetch(batch_size)  # prefetch batches so the GPU does not wait on input data
+
+        return dataset
+
+
+
+    def make_batch(self, batch_size):
+        dataset = self.make_dataset_v2(batch_size)
+        iterator = dataset.make_one_shot_iterator()
+        return iterator.get_next()
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _bytes_list_feature(values):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
+
+def _floats_feature(value):
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+def save_tf_record(output_fname, sequences):
+    print('saving sequences to %s' % output_fname)
+    with tf.python_io.TFRecordWriter(output_fname) as writer:
+        for sequence in sequences:
+            num_frames = len(sequence)
+            height, width, channels = sequence[0].shape
+            encoded_sequence = np.array([list(image) for image in sequence])
+
+            features = tf.train.Features(feature={
+                'sequence_length': _int64_feature(num_frames),
+                'height': _int64_feature(height),
+                'width': _int64_feature(width),
+                'channels': _int64_feature(channels),
+                'images/encoded': _floats_feature(encoded_sequence.flatten()),
+            })
+            example = tf.train.Example(features=features)
+            writer.write(example.SerializeToString())
+            
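+# Each record written by save_tf_record holds one sequence per tf.train.Example with the
+# int64 features 'sequence_length', 'height', 'width' and 'channels' plus 'images/encoded',
+# the whole sequence flattened into a single float list; this is the layout that the parser
+# in ERA5Dataset_v2.make_dataset_v2 above expects.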
+class Norm_data:
+    """
+     Class for normalizing data. The statistical data for normalization (minimum, maximum, average, standard deviation etc.) is expected to be available from a statistics-dictionary
+     created with the calc_data_stat-class (see 'process_netCDF_v2.py'.
+    """
+    
+    ### set known norms and the requested statistics (to be retrieved from statistics.json) here ###
+    known_norms = {}
+    known_norms["minmax"] = ["min","max"]
+    known_norms["znorm"]  = ["avg","sigma"]
+   
+    def __init__(self,varnames):
+        """Initialize the instance by setting the variable names to be handled and the status (for sanity checks only) as attributes."""
+        varnames_uni, _, nvars = get_unique_vars(varnames)
+        
+        self.varnames = varnames_uni
+        self.status_ok= False
+            
+    def check_and_set_norm(self,stat_dict,norm):
+        """
+         Checks if the statistics-dictionary provides the required data for selected normalization method and expands the instance's attributes accordingly.
+         Example: minmax-normalization requires the minimum and maximum value of a variable named var1. 
+                 If the requested values are provided by the statistics-dictionary, the instance gets the attributes 'var1min' and 'var1max',respectively.
+        """
+        
+        # some sanity checks
+        if norm not in self.known_norms:  # valid normalization requested?
+            print("Please select one of the following known normalizations: ")
+            for norm_avail in self.known_norms.keys():
+                print(norm_avail)
+            raise ValueError("Passed normalization '"+norm+"' is unknown.")
+       
+        if not all(varname in stat_dict for varname in self.varnames):  # all variables found in dictionary?
+            print("Keys in stat_dict:")
+            print(stat_dict.keys())
+            
+            print("Requested variables:")
+            print(self.varnames)
+            raise ValueError("Could not find all requested variables in statistics dictionary.")   
+
+        # create all attributes for the instance
+        for varname in self.varnames:
+            for stat_name in self.known_norms[norm]:
+                #setattr(self,varname+stat_name,stat_dict[varname][0][stat_name])
+                setattr(self,varname+stat_name,Calc_data_stat.get_stat_vars(stat_dict,stat_name,varname))
+                
+        self.status_ok = True           # set status for normalization -> ready
+                
+    def norm_var(self,data,varname,norm):
+        """ 
+         Performs given normalization on input data (given that the instance is already set up)
+        """
+        
+        # some sanity checks
+        if not self.status_ok: raise ValueError("Norm_data-instance needs to be initialized and checked first.") # status ready?
+        
+        if norm not in self.known_norms:  # valid normalization requested?
+            print("Please select one of the following known normalizations: ")
+            for norm_avail in self.known_norms.keys():
+                print(norm_avail)
+            raise ValueError("Passed normalization '"+norm+"' is unknown.")
+        
+        # do the normalization and return
+        if norm == "minmax":
+            return((data[...] - getattr(self,varname+"min"))/(getattr(self,varname+"max") - getattr(self,varname+"min")))
+        elif norm == "znorm":
+            return((data[...] - getattr(self,varname+"avg"))/getattr(self,varname+"sigma")**2)
+        
+    def denorm_var(self,data,varname,norm):
+        """ 
+         Performs given denormalization on input data (given that the instance is already set up), i.e. inverse method to norm_var
+        """
+        
+        # some sanity checks
+        if not self.status_ok: raise ValueError("Norm_data-instance needs to be initialized and checked first.") # status ready?        
+        
+        if norm not in self.known_norms:  # valid normalization requested?
+            print("Please select one of the following known normalizations: ")
+            for norm_avail in self.known_norms.keys():
+                print(norm_avail)
+            raise ValueError("Passed normalization '"+norm+"' is unknown.")
+        
+        # do the denormalization and return
+        if norm == "minmax":
+            return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"min"))
+        elif norm == "znorm":
+            return(data[...] * getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg"))
+        
+
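+# Hedged usage sketch for Norm_data (not called anywhere in the pipeline). It assumes that
+# get_unique_vars accepts a plain list of variable names; the variable name 'T2' and the
+# statistics values are invented. check_and_set_norm is bypassed here because the exact
+# structure of statistics.json is defined in process_netCDF_v2.py, not in this file.
+def _norm_data_usage_sketch():
+    norm_cls = Norm_data(["T2"])
+    setattr(norm_cls, "T2min", 230.0)   # attributes that check_and_set_norm would normally set
+    setattr(norm_cls, "T2max", 320.0)
+    norm_cls.status_ok = True
+    data = np.random.uniform(230.0, 320.0, size=(20, 64, 64))
+    data_norm = norm_cls.norm_var(data, "T2", "minmax")         # scaled to the range [0, 1]
+    data_back = norm_cls.denorm_var(data_norm, "T2", "minmax")  # recovers the original values
+    return np.allclose(data, data_back)
+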
+def read_frames_and_save_tf_records(output_dir, input_dir, partition_name, vars_in, seq_length=20,
+                                    sequences_per_file=128, height=64, width=64, channels=3, **kwargs):
+    # ML 2020/04/08:
+    # Include vars_in for more flexible data handling (normalization and reshaping)
+    # and an optional keyword argument selecting the kind of normalization
+    
+    if 'norm' in kwargs:
+        norm = kwargs.get("norm")
+    else:
+        norm = "minmax"
+        print("Make use of default minmax-normalization...")
+
+    output_dir = os.path.join(output_dir,partition_name)
+    os.makedirs(output_dir,exist_ok=True)
+    
+    norm_cls  = Norm_data(vars_in)       # init normalization-instance
+    nvars     = len(vars_in)
+    
+    # open statistics file and feed it to norm-instance
+    with open(os.path.join(input_dir,"statistics.json")) as js_file:
+        norm_cls.check_and_set_norm(json.load(js_file),norm)        
+    
+    sequences = []
+    sequence_iter = 0
+    sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
+    X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl"))
+    X_possible_starts = [i for i in range(len(X_train) - seq_length)]
+    for X_start in X_possible_starts:
+        print("Interation", sequence_iter)
+        X_end = X_start + seq_length
+        #seq = X_train[X_start:X_end, :, :,:]
+        seq = X_train[X_start:X_end,:,:]
+        #print("*****len of seq ***.{}".format(len(seq)))
+        #seq = list(np.array(seq).reshape((len(seq), 64, 64, 3)))
+        seq = list(np.array(seq).reshape((seq_length, height, width, nvars)))
+        if not sequences:
+            last_start_sequence_iter = sequence_iter
+            print("reading sequences starting at sequence %d" % sequence_iter)
+        sequences.append(seq)
+        sequence_iter += 1
+        sequence_lengths_file.write("%d\n" % len(seq))
+
+        if len(sequences) == sequences_per_file:
+            # normalize each selected variable (channel) individually; here the same
+            # temperature variable may be duplicated across channels
+            sequences = np.array(sequences)
+            for i in range(nvars):
+                sequences[:, :, :, :, i] = norm_cls.norm_var(sequences[:, :, :, :, i], vars_in[i], norm)
+
+            output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1)
+            output_fname = os.path.join(output_dir, output_fname)
+            save_tf_record(output_fname, list(sequences))
+            sequences = []
+    sequence_lengths_file.close()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking")
+    parser.add_argument("output_dir", type=str)
+    # ML 2020/04/08 S
+    # Add vars for ensuring proper normalization and reshaping of sequences
+    parser.add_argument("-vars","--variables",dest="variables", nargs='+', type=str, help="Names of input variables.")
+    parser.add_argument("-height",type=int,default=64)
+    parser.add_argument("-width",type = int,default=64)
+    parser.add_argument("-seq_length",type=int,default=20)
+    args = parser.parse_args()
+    partition_names = ['train', 'val', 'test']
+  
+    for partition_name in partition_names:
+        read_frames_and_save_tf_records(output_dir=args.output_dir, input_dir=args.input_dir,
+                                        vars_in=args.variables, partition_name=partition_name,
+                                        seq_length=args.seq_length, height=args.height, width=args.width,
+                                        sequences_per_file=2)  # TODO: revisit sequences_per_file (original default is 128)
+
+if __name__ == '__main__':
+    main()
+
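+# Example invocation (hypothetical paths and variable names; adjust to your setup):
+#   python era5_dataset_v2.py /path/to/preprocessed /path/to/tfrecords \
+#       -vars T2 T2 T2 -height 64 -width 64 -seq_length 20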
diff --git a/video_prediction/datasets/era5_dataset_v2_anomaly.py b/video_prediction/datasets/era5_dataset_v2_anomaly.py
new file mode 100644
index 0000000000000000000000000000000000000000..0daa13b332a50f848b1d41fb8bd5c75079c88b7e
--- /dev/null
+++ b/video_prediction/datasets/era5_dataset_v2_anomaly.py
@@ -0,0 +1,274 @@
+import argparse
+import glob
+import itertools
+import os
+import pickle
+import random
+import re
+import netCDF4
+import hickle as hkl
+import numpy as np
+import tensorflow as tf
+import pandas as pd
+from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset
+from collections import OrderedDict
+from tensorflow.contrib.training import HParams
+
+units = "hours since 2000-01-01 00:00:00"
+calendar = "gregorian"
+
+class ERA5Dataset_v2_anomaly(VarLenFeatureVideoDataset):
+    def __init__(self, *args, **kwargs):
+        super(ERA5Dataset_v2_anomaly, self).__init__(*args, **kwargs)
+        from google.protobuf.json_format import MessageToDict
+        example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
+        dict_message = MessageToDict(tf.train.Example.FromString(example))
+        feature = dict_message['features']['feature']
+        image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels'])
+        self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(ERA5Dataset_v2_anomaly, self).get_default_hparams_dict()
+        hparams = dict(
+            context_frames=10,
+            sequence_length=20,
+            long_sequence_length=40,
+            force_time_shift=True,
+            shuffle_on_val=True,
+            use_state=False,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+    @property
+    def jpeg_encoding(self):
+        return False
+
+
+    def num_examples_per_epoch(self):
+        with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
+            sequence_lengths = sequence_lengths_file.readlines()
+        sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
+        return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
+
+
+    def filter(self, serialized_example):
+        return tf.convert_to_tensor(True)
+
+
+
+    def make_dataset_v2(self, batch_size):
+        def parser(serialized_example):
+            seqs = OrderedDict()
+            keys_to_features = {
+                # 'width': tf.FixedLenFeature([], tf.int64),
+                # 'height': tf.FixedLenFeature([], tf.int64),
+                'sequence_length': tf.FixedLenFeature([], tf.int64),
+                # 'channels': tf.FixedLenFeature([],tf.int64),
+                # 'images/encoded':  tf.FixedLenFeature([], tf.string)
+                'images/encoded': tf.VarLenFeature(tf.float32)
+            }
+            parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
+            seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
+            # note: the sequence length and frame size are hard-coded here (20 frames of 64x64x1)
+            images = tf.reshape(seq, [20, 64, 64, 1], name="reshape_new")
+            seqs["images"] = images
+            return seqs
+        filenames = self.filenames
+        filenames_mean = self.filenames_mean
+        shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
+        if shuffle:
+            random.shuffle(filenames)
+        dataset = tf.data.TFRecordDataset(filenames, buffer_size=8 * 1024 * 1024)  # read buffer size in bytes
+        dataset = dataset.filter(self.filter)
+        #Bing: for Anomaly
+        dataset_mean = tf.data.TFRecordDataset(filenames_mean, buffer_size = 8 * 1024 * 1024)
+        dataset_mean = dataset_mean.filter(self.filter)
+        if shuffle:
+            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size = 1024, count = self.num_epochs))
+            dataset_mean = dataset_mean.apply(tf.contrib.data.shuffle_and_repeat(buffer_size = 1024, count = self.num_epochs))
+        else:
+            dataset = dataset.repeat(self.num_epochs)
+            dataset_mean = dataset_mean.repeat(self.num_epochs)
+
+        num_parallel_calls = None if shuffle else 1
+        dataset = dataset.apply(tf.contrib.data.map_and_batch(
+            parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))
+        dataset_mean = dataset_mean.apply(tf.contrib.data.map_and_batch(
+            parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))
+        dataset = dataset.prefetch(batch_size)  # prefetch batches so the GPU does not wait on input data
+        dataset_mean = dataset_mean.prefetch(batch_size)
+        return dataset, dataset_mean
+
+    def make_batch_v2(self, batch_size):
+        dataset, dataset_mean = self.make_dataset_v2(batch_size)
+        iterator = dataset.make_one_shot_iterator()
+        iterator_mean = dataset_mean.make_one_shot_iterator()
+        return iterator.get_next(), iterator_mean.get_next()
+
+
+    def make_data_mean(self,batch_size):
+        pass
+
+
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _bytes_list_feature(values):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
+
+def _floats_feature(value):
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+
+def save_tf_record(output_fname, sequences):
+    print('saving sequences to %s' % output_fname)
+    with tf.python_io.TFRecordWriter(output_fname) as writer:
+        for sequence in sequences:
+            num_frames = len(sequence)
+            height, width, channels = sequence[0].shape
+            encoded_sequence = np.array([list(image) for image in sequence])
+
+            features = tf.train.Features(feature={
+                'sequence_length': _int64_feature(num_frames),
+                'height': _int64_feature(height),
+                'width': _int64_feature(width),
+                'channels': _int64_feature(channels),
+                'images/encoded': _floats_feature(encoded_sequence.flatten()),
+            })
+            example = tf.train.Example(features=features)
+            writer.write(example.SerializeToString())
+
+
+def extract_anomaly_one_pixel(X, X_timestamps,pixel):
+    print("Processing Pixel {}, {}".format(pixel[0],pixel[1]))
+    dates = [x.date() for x in X_timestamps]
+    df = pd.DataFrame(data = X[:, pixel[0], pixel[1]], index = dates)
+    df_mean = df.groupby(df.index).mean()
+    df2 = pd.merge(df, df_mean, left_index = True, right_index = True)
+    df2.columns = ["Real","Daily_mean"]
+    df2["Anomaly"] = df2["Real"] - df2["Daily_mean"]
+    daily_mean = df2["Daily_mean"].values
+    anomaly = df2["Anomaly"].values
+    return daily_mean, anomaly
+
+def extract_anomaly_all_pixels(X, X_timestamps):
+    #daily_mean, anomaly = extract_anomaly_one_pixel(X, X_timestamps, pixel = [0, 0])
+    daily_mean_pixels = np.zeros((X.shape[0], X.shape[1], X.shape[2]))
+    anomaly_pixels = np.zeros((X.shape[0], X.shape[1], X.shape[2]))
+    #daily_mean_all_pixels = [extract_anomaly_one_pixel(X, X_timestamps, pixel = [i, j])[0] for i in range(X.shape[1]) for j in range(X.shape[2])]
+    #anomaly_all_pixels = [extract_anomaly_one_pixel(X, X_timestamps, pixel = [i, j])[1] for i in range(X.shape[1]) for j in range(X.shape[2])]
+    for i in range(X.shape[1]):
+        for j in range(X.shape[2]):
+            daily_mean, anomaly = extract_anomaly_one_pixel(X, X_timestamps, pixel = [i, j])
+            daily_mean_pixels[:,i,j] = daily_mean
+            anomaly_pixels[:,i,j] = anomaly
+    return daily_mean_pixels, anomaly_pixels
+
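+# Hedged sketch (not used by the pipeline): the daily-mean/anomaly decomposition that
+# extract_anomaly_one_pixel performs for a single pixel, shown on a toy series. All
+# timestamps and values below are invented for illustration.
+def _anomaly_decomposition_sketch():
+    import datetime
+    timestamps = [datetime.datetime(2000, 1, 1, h) for h in (0, 6, 12, 18)] + \
+                 [datetime.datetime(2000, 1, 2, h) for h in (0, 6, 12, 18)]
+    values = np.array([1.0, 2.0, 3.0, 4.0, 10.0, 10.0, 10.0, 10.0])
+    dates = [t.date() for t in timestamps]
+    df = pd.DataFrame(data=values, index=dates)
+    df_mean = df.groupby(df.index).mean()             # 2.5 for Jan 1, 10.0 for Jan 2
+    df2 = pd.merge(df, df_mean, left_index=True, right_index=True)
+    df2.columns = ["Real", "Daily_mean"]
+    df2["Anomaly"] = df2["Real"] - df2["Daily_mean"]   # departure from that day's mean
+    return df2["Daily_mean"].values, df2["Anomaly"].values
+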
+
+def read_frames_and_save_tf_records(output_dir, input_dir, partition_name, N_seq, sequences_per_file=128):
+    output_orig_dir = os.path.join(output_dir,partition_name + "_orig")
+    output_time_dir = os.path.join(output_dir,partition_name + "_time")
+    output_mean_dir = os.path.join(output_dir,partition_name + "_mean")
+    output_anomaly_dir = os.path.join(output_dir, partition_name )
+
+
+    if not os.path.exists(output_orig_dir): os.mkdir(output_orig_dir)
+    if not os.path.exists(output_time_dir): os.mkdir(output_time_dir)
+    if not os.path.exists(output_mean_dir): os.mkdir(output_mean_dir)
+    if not os.path.exists(output_anomaly_dir): os.mkdir(output_anomaly_dir)
+    sequences = []
+    sequences_time = []
+    sequences_mean = []
+    sequences_anomaly = []
+
+    sequence_iter = 0
+    sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
+    X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl"))
+    X_time = hkl.load(os.path.join(input_dir, "Time_time_" + partition_name + ".hkl"))
+    print ("X shape", X_train.shape)
+    X_timestamps = [netCDF4.num2date(x, units = units, calendar = calendar) for x in X_time]
+
+    print("X_time example", X_time[:10])
+    print("X_time after to date", X_timestamps[:10])
+    daily_mean_all_pixels, anomaly_all_pixels = extract_anomaly_all_pixels(X_train, X_timestamps)
+
+    X_possible_starts = [i for i in range(len(X_train) - N_seq)]
+    for X_start in X_possible_starts:
+        print("Interation", sequence_iter)
+        X_end = X_start + N_seq
+        #seq = X_train[X_start:X_end, :, :,:]
+        seq = X_train[X_start:X_end,:,:]
+        seq_time = X_time[X_start:X_end]
+        seq_mean = daily_mean_all_pixels[X_start:X_end,:,:]
+        seq_anomaly = anomaly_all_pixels[X_start:X_end,:,:]
+        #print("*****len of seq ***.{}".format(len(seq)))
+        seq = list(np.array(seq).reshape((len(seq), 64, 64, 1)))
+        seq_time = list(np.array(seq_time))
+        seq_mean = list(np.array(seq_mean).reshape((len(seq_mean), 64, 64, 1)))
+        seq_anomaly = list(np.array(seq_anomaly).reshape((len(seq_anomaly), 64, 64, 1)))
+        if not sequences:
+            last_start_sequence_iter = sequence_iter
+            print("reading sequences starting at sequence %d" % sequence_iter)
+        sequences.append(seq)
+        sequences_time.append(seq_time)
+        sequences_mean.append(seq_mean)
+        sequences_anomaly.append(seq_anomaly)
+        sequence_iter += 1
+        sequence_lengths_file.write("%d\n" % len(seq))
+
+        if len(sequences) == sequences_per_file:
+            output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1)
+            output_orig_fname = os.path.join(output_orig_dir, output_fname)
+            output_time_fname = os.path.join(output_time_dir,'sequence_{0}_to_{1}.hkl'.format(last_start_sequence_iter, sequence_iter - 1))
+            output_mean_fname = os.path.join(output_mean_dir, output_fname)
+            output_anomaly_fname = os.path.join(output_anomaly_dir, output_fname)
+
+            save_tf_record(output_orig_fname, sequences)
+            hkl.dump(sequences_time,output_time_fname )
+            #save_tf_record(output_time_fname,sequences_time)
+            save_tf_record(output_mean_fname, sequences_mean)
+            save_tf_record(output_anomaly_fname, sequences_anomaly)
+            sequences[:] = []
+            sequences_time[:] = []
+            sequences_mean[:] = []
+            sequences_anomaly[:] = []
+    sequence_lengths_file.close()
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking")
+    parser.add_argument("output_dir", type=str)
+    # parser.add_argument("image_size_h", type=int)
+    # parser.add_argument("image_size_v", type = int)
+    args = parser.parse_args()
+    partition_names = ['train', 'val', 'test']
+    for partition_name in partition_names:
+        read_frames_and_save_tf_records(output_dir=args.output_dir, input_dir=args.input_dir,
+                                        partition_name=partition_name, N_seq=20)  # TODO: revisit N_seq
+
+if __name__ == '__main__':
+    main()
+
diff --git a/video_prediction/datasets/google_robot_dataset.py b/video_prediction/datasets/google_robot_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd2e975cf5bd83e7c342d7effbead68c1063498a
--- /dev/null
+++ b/video_prediction/datasets/google_robot_dataset.py
@@ -0,0 +1,40 @@
+import itertools
+import os
+
+from .base_dataset import VideoDataset
+
+
+class GoogleRobotVideoDataset(VideoDataset):
+    """
+    https://sites.google.com/site/brainrobotdata/home/push-dataset
+    """
+    def __init__(self, *args, **kwargs):
+        super(GoogleRobotVideoDataset, self).__init__(*args, **kwargs)
+        self.state_like_names_and_shapes['images'] = 'move/%d/image/encoded', (512, 640, 3)
+        if self.hparams.use_state:
+            self.state_like_names_and_shapes['states'] = 'move/%d/endeffector/vec_pitch_yaw', (5,)
+            self.action_like_names_and_shapes['actions'] = 'move/%d/commanded_pose/vec_pitch_yaw', (5,)
+        self._check_or_infer_shapes()
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(GoogleRobotVideoDataset, self).get_default_hparams_dict()
+        hparams = dict(
+            context_frames=2,
+            sequence_length=15,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def num_examples_per_epoch(self):
+        if os.path.basename(self.input_dir) == 'push_train':
+            count = 51615
+        elif os.path.basename(self.input_dir) == 'push_testseen':
+            count = 1038
+        elif os.path.basename(self.input_dir) == 'push_testnovel':
+            count = 995
+        else:
+            raise NotImplementedError
+        return count
+
+    @property
+    def jpeg_encoding(self):
+        return True
diff --git a/video_prediction/datasets/kth_dataset.py b/video_prediction/datasets/kth_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..40fb6bf57b8219fc7e75c3759df9e6b38fffeb30
--- /dev/null
+++ b/video_prediction/datasets/kth_dataset.py
@@ -0,0 +1,159 @@
+import argparse
+import glob
+import itertools
+import os
+import pickle
+import random
+import re
+import tensorflow as tf
+import numpy as np
+import skimage.io
+
+from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset
+
+
+class KTHVideoDataset(VarLenFeatureVideoDataset):
+    def __init__(self, *args, **kwargs):
+        super(KTHVideoDataset, self).__init__(*args, **kwargs)
+        from google.protobuf.json_format import MessageToDict
+        example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
+        dict_message = MessageToDict(tf.train.Example.FromString(example))
+        feature = dict_message['features']['feature']
+        image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels'])
+
+        self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(KTHVideoDataset, self).get_default_hparams_dict()
+        hparams = dict(
+            context_frames=10,
+            sequence_length=20,
+            long_sequence_length=40,
+            force_time_shift=True,
+            shuffle_on_val=True,
+            use_state=False,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    @property
+    def jpeg_encoding(self):
+        return False
+
+    def num_examples_per_epoch(self):
+        with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
+            sequence_lengths = sequence_lengths_file.readlines()
+        sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
+        return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
+
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _bytes_list_feature(values):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
+
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def partition_data(input_dir):
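+    """
+    Split the KTH video directories by person ID: persons 1-16 are used for training and
+    validation (random 95/5 split within the shuffled list), the remaining persons for testing.
+    """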
+    # List files and corresponding person IDs
+    fnames = glob.glob(os.path.join(input_dir, '*/*'))
+    fnames = [fname for fname in fnames if os.path.isdir(fname)]
+    print("frames",fnames[0])
+
+    persons = [re.match('person(\d+)_\w+_\w+', os.path.split(fname)[1]).group(1) for fname in fnames]
+    persons = np.array([int(person) for person in persons])
+
+    train_mask = persons <= 16
+
+    train_fnames = [fnames[i] for i in np.where(train_mask)[0]]
+    test_fnames = [fnames[i] for i in np.where(~train_mask)[0]]
+
+    random.shuffle(train_fnames)
+
+    pivot = int(0.95 * len(train_fnames))
+    train_fnames, val_fnames = train_fnames[:pivot], train_fnames[pivot:]
+    return train_fnames, val_fnames, test_fnames
+
+
+def save_tf_record(output_fname, sequences):
+    print('saving sequences to %s' % output_fname)
+    with tf.python_io.TFRecordWriter(output_fname) as writer:
+        for sequence in sequences:
+            num_frames = len(sequence)
+            height, width, channels = sequence[0].shape
+            encoded_sequence = [image.tostring() for image in sequence]
+            features = tf.train.Features(feature={
+                'sequence_length': _int64_feature(num_frames),
+                'height': _int64_feature(height),
+                'width': _int64_feature(width),
+                'channels': _int64_feature(channels),
+                'images/encoded': _bytes_list_feature(encoded_sequence),
+            })
+            example = tf.train.Example(features=features)
+            writer.write(example.SerializeToString())
+
+
+def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequences_per_file=128):
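+    """
+    Read the grayscale frame sequences listed in the per-partition meta pickle for each video
+    directory and write them to TFRecord files holding up to sequences_per_file sequences each,
+    together with a sequence_lengths.txt file in output_dir.
+    """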
+    partition_name = os.path.split(output_dir)[1] #Get the folder name train, val or test
+    sequences = []
+    sequence_iter = 0
+    sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
+    for video_iter, video_dir in enumerate(video_dirs):  # iterate over the video directories (one per person and action, e.g. walking)
+        meta_partition_name = partition_name if partition_name == 'test' else 'train'
+        meta_fname = os.path.join(os.path.split(video_dir)[0], '%s_meta%dx%d.pkl' %
+                                  (meta_partition_name, image_size, image_size))
+        with open(meta_fname, "rb") as f:
+            data = pickle.load(f)  # list of dicts (one per video) with the keys 'vid', 'n' and 'files'; 'files' holds the per-subsequence lists of 64x64 png frame filenames
+
+        vid = os.path.split(video_dir)[1]
+        (d,) = [d for d in data if d['vid'] == vid]
+        for frame_fnames_iter, frame_fnames in enumerate(d['files']):
+            frame_fnames = [os.path.join(video_dir, frame_fname) for frame_fname in frame_fnames]
+            frames = skimage.io.imread_collection(frame_fnames)
+            # they are grayscale images, so just keep one of the channels
+            frames = [frame[..., 0:1] for frame in frames]
+
+            if not sequences:  # the sequences collected so far may have different lengths
+                last_start_sequence_iter = sequence_iter
+                print("reading sequences starting at sequence %d" % sequence_iter)
+
+            sequences.append(frames)
+            sequence_iter += 1
+            sequence_lengths_file.write("%d\n" % len(frames))
+
+            if (len(sequences) == sequences_per_file or
+                    (video_iter == (len(video_dirs) - 1) and frame_fnames_iter == (len(d['files']) - 1))):
+                output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1)
+                output_fname = os.path.join(output_dir, output_fname)
+                save_tf_record(output_fname, sequences)
+                sequences[:] = []
+    sequence_lengths_file.close()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, help="directory containing the processed directories "
+                                                    "boxing, handclapping, handwaving, "
+                                                    "jogging, running, walking")
+    parser.add_argument("--output_dir", type=str)
+    parser.add_argument("--image_size", type=int)
+    args = parser.parse_args()
+
+    partition_names = ['train', 'val', 'test']
+    print("input dir", args.input_dir)
+    partition_fnames = partition_data(args.input_dir)
+    print("partiotion_fnames[0]", partition_fnames[0])
+
+    for partition_name, partition_fnames in zip(partition_names, partition_fnames):
+        partition_dir = os.path.join(args.output_dir, partition_name)
+        if not os.path.exists(partition_dir):
+            os.makedirs(partition_dir)
+        read_frames_and_save_tf_records(partition_dir, partition_fnames, args.image_size)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/video_prediction/datasets/softmotion_dataset.py b/video_prediction/datasets/softmotion_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..106869ca46d857d56fdd8d0f0d20ae5e2b69c3b5
--- /dev/null
+++ b/video_prediction/datasets/softmotion_dataset.py
@@ -0,0 +1,82 @@
+import itertools
+import os
+import re
+
+import tensorflow as tf
+
+from video_prediction.utils import tf_utils
+from .base_dataset import VideoDataset
+
+
+class SoftmotionVideoDataset(VideoDataset):
+    """
+    https://sites.google.com/view/sna-visual-mpc
+    """
+    def __init__(self, *args, **kwargs):
+        super(SoftmotionVideoDataset, self).__init__(*args, **kwargs)
+        # infer name of image feature and check if object_pos feature is present
+        from google.protobuf.json_format import MessageToDict
+        example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
+        dict_message = MessageToDict(tf.train.Example.FromString(example))
+        feature = dict_message['features']['feature']
+        image_names = set()
+        for name in feature.keys():
+            m = re.search(r'\d+/(\w+)/encoded', name)
+            if m:
+                image_names.add(m.group(1))
+        # look for image_aux1 and image_view0 in that order of priority
+        image_name = None
+        for name in ['image_aux1', 'image_view0']:
+            if name in image_names:
+                image_name = name
+                break
+        if not image_name:
+            if len(image_names) == 1:
+                image_name = image_names.pop()
+            else:
+                raise ValueError('The examples have images under more than one name.')
+        self.state_like_names_and_shapes['images'] = '%%d/%s/encoded' % image_name, None
+        if self.hparams.use_state:
+            self.state_like_names_and_shapes['states'] = '%d/endeffector_pos', (3,)
+            self.action_like_names_and_shapes['actions'] = '%d/action', (4,)
+        if any(re.search(r'\d+/object_pos', name) for name in feature.keys()):
+            self.state_like_names_and_shapes['object_pos'] = '%d/object_pos', None  # shape is (2 * num_designated_pixels)
+        self._check_or_infer_shapes()
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(SoftmotionVideoDataset, self).get_default_hparams_dict()
+        hparams = dict(
+            context_frames=2,
+            sequence_length=12,
+            long_sequence_length=30,
+            time_shift=2,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    @property
+    def jpeg_encoding(self):
+        return False
+
+    def parser(self, serialized_example):
+        state_like_seqs, action_like_seqs = super(SoftmotionVideoDataset, self).parser(serialized_example)
+        if 'object_pos' in state_like_seqs:
+            object_pos = state_like_seqs['object_pos']
+            height, width, _ = self.state_like_names_and_shapes['images'][1]
+            object_pos = tf.reshape(object_pos, [object_pos.shape[0].value, -1, 2])
+            pix_distribs = tf.stack([tf_utils.pixel_distribution(object_pos_, height, width)
+                                     for object_pos_ in tf.unstack(object_pos, axis=1)], axis=-1)
+            state_like_seqs['pix_distribs'] = pix_distribs
+        return state_like_seqs, action_like_seqs
+
+    def num_examples_per_epoch(self):
+        # extract information from filename to count the number of trajectories in the dataset
+        count = 0
+        for filename in self.filenames:
+            match = re.search(r'traj_(\d+)_to_(\d+).tfrecords', os.path.basename(filename))
+            start_traj_iter = int(match.group(1))
+            end_traj_iter = int(match.group(2))
+            count += end_traj_iter - start_traj_iter + 1
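+            # e.g. 'traj_256_to_511.tfrecords' contributes 512 - 256 + 1 = 256 trajectories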
+
+        # alternatively, the dataset size can be determined like this, but it's very slow
+        # count = sum(sum(1 for _ in tf.python_io.tf_record_iterator(filename)) for filename in filenames)
+        return count
diff --git a/video_prediction/datasets/sv2p_dataset.py b/video_prediction/datasets/sv2p_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed072297a7c853b2b389d7c1816885743bad0cf
--- /dev/null
+++ b/video_prediction/datasets/sv2p_dataset.py
@@ -0,0 +1,65 @@
+import itertools
+import os
+
+from .base_dataset import VideoDataset
+
+
+class SV2PVideoDataset(VideoDataset):
+    def __init__(self, *args, **kwargs):
+        super(SV2PVideoDataset, self).__init__(*args, **kwargs)
+        self.dataset_name = os.path.basename(os.path.split(self.input_dir)[0])
+        self.state_like_names_and_shapes['images'] = 'image_%d', (64, 64, 3)
+        if self.dataset_name == 'shape':
+            if self.hparams.use_state:
+                self.state_like_names_and_shapes['states'] = 'state_%d', (2,)
+                self.action_like_names_and_shapes['actions'] = 'action_%d', (2,)
+        elif self.dataset_name == 'humans':
+            if self.hparams.use_state:
+                raise ValueError('SV2PVideoDataset does not have states, use_state should be False')
+        else:
+            raise NotImplementedError
+        self._check_or_infer_shapes()
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(SV2PVideoDataset, self).get_default_hparams_dict()
+        if self.dataset_name == 'shape':
+            hparams = dict(
+                context_frames=1,
+                sequence_length=6,
+                time_shift=0,
+                use_state=False,
+            )
+        elif self.dataset_name == 'humans':
+            hparams = dict(
+                context_frames=10,
+                sequence_length=20,
+                use_state=False,
+            )
+        else:
+            raise NotImplementedError
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def num_examples_per_epoch(self):
+        if self.dataset_name == 'shape':
+            if os.path.basename(self.input_dir) == 'train':
+                count = 43415
+            elif os.path.basename(self.input_dir) == 'val':
+                count = 2898
+            else:  # shape dataset doesn't have a test set
+                raise NotImplementedError
+        elif self.dataset_name == 'humans':
+            if os.path.basename(self.input_dir) == 'train':
+                count = 23910
+            elif os.path.basename(self.input_dir) == 'val':
+                count = 10472
+            elif os.path.basename(self.input_dir) == 'test':
+                count = 7722
+            else:
+                raise NotImplementedError
+        else:
+            raise NotImplementedError
+        return count
+
+    @property
+    def jpeg_encoding(self):
+        return True
diff --git a/video_prediction/datasets/ucf101_dataset.py b/video_prediction/datasets/ucf101_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cba078ab4f66acb3fd80546a016fb9dd94b1d551
--- /dev/null
+++ b/video_prediction/datasets/ucf101_dataset.py
@@ -0,0 +1,212 @@
+import argparse
+import glob
+import itertools
+import os
+import random
+import re
+from multiprocessing import Pool
+import cv2
+import tensorflow as tf
+
+from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset
+
+
+class UCF101VideoDataset(VarLenFeatureVideoDataset):
+    def __init__(self, *args, **kwargs):
+        super(UCF101VideoDataset, self).__init__(*args, **kwargs)
+        self.state_like_names_and_shapes['images'] = 'images/encoded', (240, 320, 3)
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(UCF101VideoDataset, self).get_default_hparams_dict()
+        hparams = dict(
+            context_frames=4,
+            sequence_length=8,
+            random_crop_size=0,
+            use_state=False,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    @property
+    def jpeg_encoding(self):
+        return True
+
+    def decode_and_preprocess_images(self, image_buffers, image_shape):
+        if self.hparams.crop_size:
+            raise NotImplementedError
+        if self.hparams.scale_size:
+            raise NotImplementedError
+        image_buffers = tf.reshape(image_buffers, [-1])
+        if not isinstance(image_buffers, (list, tuple)):
+            image_buffers = tf.unstack(image_buffers)
+        image_size = tf.image.extract_jpeg_shape(image_buffers[0])[:2]  # should be the same as image_shape[:2]
+        if self.hparams.random_crop_size:
+            random_crop_size = [self.hparams.random_crop_size] * 2
+            crop_y = tf.random_uniform([], minval=0, maxval=image_size[0] - random_crop_size[0], dtype=tf.int32)
+            crop_x = tf.random_uniform([], minval=0, maxval=image_size[1] - random_crop_size[1], dtype=tf.int32)
+            crop_window = [crop_y, crop_x] + random_crop_size
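+            # crop_window is [y_offset, x_offset, crop_height, crop_width], the layout expected by tf.image.decode_and_crop_jpeg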
+            images = [tf.image.decode_and_crop_jpeg(image_buffer, crop_window) for image_buffer in image_buffers]
+            images = tf.image.convert_image_dtype(images, dtype=tf.float32)
+            images.set_shape([None] + random_crop_size + [image_shape[-1]])
+        else:
+            images = [tf.image.decode_jpeg(image_buffer) for image_buffer in image_buffers]
+            images = tf.image.convert_image_dtype(images, dtype=tf.float32)
+            images.set_shape([None] + list(image_shape))
+        # TODO: only random crop for training
+        return images
+
+    def num_examples_per_epoch(self):
+        # extract information from filename to count the number of trajectories in the dataset
+        count = 0
+        for filename in self.filenames:
+            match = re.search(r'sequence_(\d+)_to_(\d+)\.tfrecords', os.path.basename(filename))
+            start_traj_iter = int(match.group(1))
+            end_traj_iter = int(match.group(2))
+            count += end_traj_iter - start_traj_iter + 1
+
+        # alternatively, the dataset size can be determined like this, but it's very slow
+        # count = sum(sum(1 for _ in tf.python_io.tf_record_iterator(filename)) for filename in filenames)
+        return count
+
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _bytes_list_feature(values):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
+
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def partition_data(input_dir, train_test_list_dir):
+    train_list_fnames = glob.glob(os.path.join(train_test_list_dir, 'trainlist*.txt'))
+    test_list_fnames = glob.glob(os.path.join(train_test_list_dir, 'testlist*.txt'))
+    test_list_fnames_mathieu = [os.path.join(train_test_list_dir, 'testlist01.txt')]
+
+    def read_fnames(list_fnames):
+        fnames = []
+        for list_fname in sorted(list_fnames):
+            with open(list_fname, 'r') as f:
+                while True:
+                    fname = f.readline()
+                    if not fname:
+                        break
+                    fnames.append(fname.split('\n')[0].split(' ')[0])
+        return fnames
+
+    train_fnames = read_fnames(train_list_fnames)
+    test_fnames = read_fnames(test_list_fnames)
+    test_fnames_mathieu = read_fnames(test_list_fnames_mathieu)
+
+    train_fnames = [os.path.join(input_dir, train_fname) for train_fname in train_fnames]
+    test_fnames = [os.path.join(input_dir, test_fname) for test_fname in test_fnames]
+    test_fnames_mathieu = [os.path.join(input_dir, test_fname) for test_fname in test_fnames_mathieu]
+    # only use every 10th video, as in Mathieu et al.
+    test_fnames_mathieu = test_fnames_mathieu[::10]
+
+    random.shuffle(train_fnames)
+
+    pivot = int(0.95 * len(train_fnames))
+    train_fnames, val_fnames = train_fnames[:pivot], train_fnames[pivot:]
+    return train_fnames, val_fnames, test_fnames, test_fnames_mathieu
+
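+# Example (illustrative sketch): partition_data expects the standard UCF-101 layout, i.e. an
+# input directory of class subdirectories containing avi files plus the official
+# trainlist*.txt / testlist*.txt split files. The paths below are made-up placeholders.
+#
+#   train, val, test, test_mathieu = partition_data('/data/UCF-101', '/data/ucfTrainTestlist')
+#   print(len(train), len(val), len(test), len(test_mathieu))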
+
+def read_video(fname):
+    if not os.path.isfile(fname):
+        raise FileNotFoundError
+    vidcap = cv2.VideoCapture(fname)
+    frames, (success, image) = [], vidcap.read()
+    while success:
+        frames.append(image)
+        success, image = vidcap.read()
+    return frames
+
+
+def save_tf_record(output_fname, sequences, preprocess_image):
+    print('saving sequences to %s' % output_fname)
+    with tf.python_io.TFRecordWriter(output_fname) as writer:
+        for sequence in sequences:
+            num_frames = len(sequence)
+            height, width, channels = sequence[0].shape
+            encoded_sequence = [preprocess_image(image) for image in sequence]
+            features = tf.train.Features(feature={
+                'sequence_length': _int64_feature(num_frames),
+                'height': _int64_feature(height),
+                'width': _int64_feature(width),
+                'channels': _int64_feature(channels),
+                'images/encoded': _bytes_list_feature(encoded_sequence),
+            })
+            example = tf.train.Example(features=features)
+            writer.write(example.SerializeToString())
+
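+# Illustrative sketch of how a record written by save_tf_record could be parsed back (the
+# dataset classes are presumably responsible for the real parsing); the feature names match
+# the ones serialized above, everything else is made up:
+#
+#   features = tf.parse_single_example(serialized_example, features={
+#       'sequence_length': tf.FixedLenFeature([], tf.int64),
+#       'images/encoded': tf.VarLenFeature(tf.string),
+#   })
+#   images = tf.map_fn(tf.image.decode_jpeg, features['images/encoded'].values, dtype=tf.uint8)
+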
+
+def read_videos_and_save_tf_records(output_dir, fnames, start_sequence_iter=None,
+                                    end_sequence_iter=None, sequences_per_file=128):
+    print('started process with PID:', os.getpid())
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    if start_sequence_iter is None:
+        start_sequence_iter = 0
+    if end_sequence_iter is None:
+        end_sequence_iter = len(fnames)
+
+    def preprocess_image(image):
+        if image.shape != (240, 320, 3):
+            image = cv2.resize(image, (320, 240), interpolation=cv2.INTER_LINEAR)
+        return tf.compat.as_bytes(cv2.imencode(".jpg", image)[1].tobytes())
+
+    print('reading and saving sequences {0} to {1}'.format(start_sequence_iter, end_sequence_iter))
+
+    sequences = []
+    for sequence_iter in range(start_sequence_iter, end_sequence_iter):
+        if not sequences:
+            last_start_sequence_iter = sequence_iter
+            print("reading sequences starting at sequence %d" % sequence_iter)
+
+        sequences.append(read_video(fnames[sequence_iter]))
+
+        if len(sequences) == sequences_per_file or sequence_iter == (end_sequence_iter - 1):
+            output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter)
+            output_fname = os.path.join(output_dir, output_fname)
+            save_tf_record(output_fname, sequences, preprocess_image)
+            sequences[:] = []
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_dir", type=str, help="directory containing the directories of "
+                                                    "classes, each of which contains avi files.")
+    parser.add_argument("train_test_list_dir", type=str, help='directory containing trainlist*.txt'
+                                                              'and testlist*.txt files.')
+    parser.add_argument("output_dir", type=str)
+    parser.add_argument('--num_workers', type=int, default=1, help='number of parallel workers')
+    args = parser.parse_args()
+
+    partition_names = ['train', 'val', 'test', 'test_mathieu']
+    all_partition_fnames = partition_data(args.input_dir, args.train_test_list_dir)
+
+    for partition_name, partition_fnames in zip(partition_names, all_partition_fnames):
+        partition_dir = os.path.join(args.output_dir, partition_name)
+        if not os.path.exists(partition_dir):
+            os.makedirs(partition_dir)
+
+        if args.num_workers > 1:
+            num_seqs_per_worker = len(partition_fnames) // args.num_workers
+            start_seq_iters = [num_seqs_per_worker * i for i in range(args.num_workers)]
+            # end indices are exclusive, so each worker handles num_seqs_per_worker sequences
+            end_seq_iters = [num_seqs_per_worker * (i + 1) for i in range(args.num_workers)]
+            end_seq_iters[-1] = len(partition_fnames)
+
+            p = Pool(args.num_workers)
+            p.starmap(read_videos_and_save_tf_records, zip([partition_dir] * args.num_workers,
+                                                           [partition_fnames] * args.num_workers,
+                                                           start_seq_iters, end_seq_iters))
+        else:
+            read_videos_and_save_tf_records(partition_dir, partition_fnames)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/video_prediction/flow_ops.py b/video_prediction/flow_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..caf9bad772621d6fbbf8a28895f16d03bed0561a
--- /dev/null
+++ b/video_prediction/flow_ops.py
@@ -0,0 +1,79 @@
+import tensorflow as tf
+
+
+def image_warp(im, flow):
+    """Performs a backward warp of an image using the predicted flow.
+
+    Args:
+        im: Batch of images. [num_batch, height, width, channels]
+        flow: Batch of flow vectors. [num_batch, height, width, 2]
+    Returns:
+        warped: transformed image of the same shape as the input image.
+
+    Implementation taken from here: https://github.com/simonmeister/UnFlow
+    """
+    with tf.variable_scope('image_warp'):
+
+        num_batch, height, width, channels = tf.unstack(tf.shape(im))
+        max_x = tf.cast(width - 1, 'int32')
+        max_y = tf.cast(height - 1, 'int32')
+        zero = tf.zeros([], dtype='int32')
+
+        # We have to flatten our tensors to vectorize the interpolation
+        im_flat = tf.reshape(im, [-1, channels])
+        flow_flat = tf.reshape(flow, [-1, 2])
+
+        # Floor the flow, as the final indices are integers
+        # The fractional part is used to control the bilinear interpolation.
+        flow_floor = tf.to_int32(tf.floor(flow_flat))
+        bilinear_weights = flow_flat - tf.floor(flow_flat)
+
+        # Construct base indices which are displaced with the flow
+        pos_x = tf.tile(tf.range(width), [height * num_batch])
+        grid_y = tf.tile(tf.expand_dims(tf.range(height), 1), [1, width])
+        pos_y = tf.tile(tf.reshape(grid_y, [-1]), [num_batch])
+
+        x = flow_floor[:, 0]
+        y = flow_floor[:, 1]
+        xw = bilinear_weights[:, 0]
+        yw = bilinear_weights[:, 1]
+
+        # Compute interpolation weights for 4 adjacent pixels
+        # expand to num_batch * height * width x 1 for broadcasting in add_n below
+        wa = tf.expand_dims((1 - xw) * (1 - yw), 1) # top left pixel
+        wb = tf.expand_dims((1 - xw) * yw, 1) # bottom left pixel
+        wc = tf.expand_dims(xw * (1 - yw), 1) # top right pixel
+        wd = tf.expand_dims(xw * yw, 1) # bottom right pixel
+
+        x0 = pos_x + x
+        x1 = x0 + 1
+        y0 = pos_y + y
+        y1 = y0 + 1
+
+        x0 = tf.clip_by_value(x0, zero, max_x)
+        x1 = tf.clip_by_value(x1, zero, max_x)
+        y0 = tf.clip_by_value(y0, zero, max_y)
+        y1 = tf.clip_by_value(y1, zero, max_y)
+
+        dim1 = width * height
+        batch_offsets = tf.range(num_batch) * dim1
+        base_grid = tf.tile(tf.expand_dims(batch_offsets, 1), [1, dim1])
+        base = tf.reshape(base_grid, [-1])
+
+        base_y0 = base + y0 * width
+        base_y1 = base + y1 * width
+        idx_a = base_y0 + x0
+        idx_b = base_y1 + x0
+        idx_c = base_y0 + x1
+        idx_d = base_y1 + x1
+
+        Ia = tf.gather(im_flat, idx_a)
+        Ib = tf.gather(im_flat, idx_b)
+        Ic = tf.gather(im_flat, idx_c)
+        Id = tf.gather(im_flat, idx_d)
+
+        warped_flat = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
+        warped = tf.reshape(warped_flat, [num_batch, height, width, channels])
+        warped.set_shape(im.shape)
+
+        return warped
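+
+
+# Example (illustrative sketch, TF 1.x): backward-warping a batch of images with a constant
+# flow of one pixel in x; all shapes and values are made up for illustration.
+#
+#   im = tf.random_uniform([4, 64, 64, 3])
+#   flow = tf.ones([4, 64, 64, 2]) * [1.0, 0.0]
+#   warped = image_warp(im, flow)
+#   with tf.Session() as sess:
+#       print(sess.run(warped).shape)  # (4, 64, 64, 3)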
diff --git a/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction/layers/BasicConvLSTMCell.py
new file mode 100644
index 0000000000000000000000000000000000000000..321f6cc7e05320cf83e1173d8004429edf07ec24
--- /dev/null
+++ b/video_prediction/layers/BasicConvLSTMCell.py
@@ -0,0 +1,148 @@
+
+import tensorflow as tf
+from .layer_def import *
+
+class ConvRNNCell(object):
+    """Abstract object representing an Convolutional RNN cell.
+    """
+
+    def __call__(self, inputs, state, scope=None):
+        """Run this RNN cell on inputs, starting from the given state.
+        """
+        raise NotImplementedError("Abstract method")
+
+    @property
+    def state_size(self):
+        """size(s) of state(s) used by this cell.
+        """
+        raise NotImplementedError("Abstract method")
+
+    @property
+    def output_size(self):
+        """Integer or TensorShape: size of outputs produced by this cell."""
+        raise NotImplementedError("Abstract method")
+
+    def zero_state(self, input, dtype):
+        """Return zero-filled state tensor(s).
+        Args:
+          input: 4D input tensor from which the (possibly dynamic) batch size is inferred.
+          dtype: the data type to use for the state (currently unused; the zeros are float32).
+        Returns:
+          tensor of shape [batch_size, shape[0], shape[1], num_features * 2]
+          filled with zeros
+        """
+
+        shape = self.shape
+        num_features = self.num_features
+        #x= tf.placeholder(tf.float32, shape=[input.shape[0], shape[0], shape[1], num_features * 2])#Bing: add this to
+        zeros = tf.zeros([tf.shape(input)[0], shape[0], shape[1], num_features * 2])
+        #zeros = tf.zeros_like(x)
+        return zeros
+
+
+class BasicConvLSTMCell(ConvRNNCell):
+    """Basic Conv LSTM recurrent network cell. The
+    """
+
+    def __init__(self, shape, filter_size, num_features, forget_bias=1.0, input_size=None,
+                 state_is_tuple=False, activation=tf.nn.tanh):
+        """Initialize the basic Conv LSTM cell.
+        Args:
+          shape: int tuple that is the height and width of the cell
+          filter_size: int tuple that is the height and width of the filter
+          num_features: int that is the depth of the cell
+          forget_bias: float, The bias added to forget gates (see above).
+          input_size: Deprecated and unused.
+          state_is_tuple: If True, accepted and returned states are 2-tuples of
+            the `c_state` and `m_state`.  If False, they are concatenated
+            along the column axis.  The latter behavior will soon be deprecated.
+          activation: Activation function of the inner states.
+        """
+        # if not state_is_tuple:
+        # logging.warn("%s: Using a concatenated state is slower and will soon be "
+        #             "deprecated.  Use state_is_tuple=True.", self)
+        if input_size is not None:
+            tf.logging.warning("%s: The input_size parameter is deprecated.", self)
+        self.shape = shape
+        self.filter_size = filter_size
+        self.num_features = num_features
+        self._forget_bias = forget_bias
+        self._state_is_tuple = state_is_tuple
+        self._activation = activation
+
+    @property
+    def state_size(self):
+        return (tf.nn.rnn_cell.LSTMStateTuple(self.num_features, self.num_features)
+                if self._state_is_tuple else 2 * self.num_features)
+
+    @property
+    def output_size(self):
+        return self.num_features
+
+    def __call__(self, inputs, state, scope=None,reuse=None):
+        """Long short-term memory cell (LSTM)."""
+        with tf.variable_scope(scope or type(self).__name__,reuse=reuse):  # "BasicLSTMCell"
+            # Parameters of gates are concatenated into one multiply for efficiency.
+            if self._state_is_tuple:
+                c, h = state
+            else:
+                c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state)
+            concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True)
+
+            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+            i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat)
+
+            new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) *
+                     self._activation(j))
+            new_h = self._activation(new_c) * tf.nn.sigmoid(o)
+
+            if self._state_is_tuple:
+                new_state = tf.nn.rnn_cell.LSTMStateTuple(new_c, new_h)
+            else:
+                new_state = tf.concat(axis = 3, values = [new_c, new_h])
+            return new_h, new_state
+
+
+def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=None):
+    """convolution:
+    Args:
+      args: a 4D Tensor or a list of 4D, batch x n, Tensors.
+      filter_size: int tuple of filter height and width.
+      num_features: int, number of features.
+      bias_start: starting value to initialize the bias; 0 by default.
+      scope: VariableScope for the created subgraph; defaults to "Linear".
+    Returns:
+      A 4D Tensor with shape [batch h w num_features]
+    Raises:
+      ValueError: if some of the arguments has unspecified or wrong shape.
+    """
+
+    # Calculate the total size of arguments on dimension 1.
+    total_arg_size_depth = 0
+    shapes = [a.get_shape().as_list() for a in args]
+    for shape in shapes:
+        if len(shape) != 4:
+            raise ValueError("Linear is expecting 4D arguments: %s" % str(shapes))
+        if not shape[3]:
+            raise ValueError("Linear expects a defined channel dimension shape[3]: %s" % str(shapes))
+        else:
+            total_arg_size_depth += shape[3]
+
+    dtype = [a.dtype for a in args][0]
+
+    # Now the computation.
+    with tf.variable_scope(scope or "Conv"):
+        matrix = tf.get_variable(
+            "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype = dtype)
+        if len(args) == 1:
+            res = tf.nn.conv2d(args[0], matrix, strides = [1, 1, 1, 1], padding = 'SAME')
+        else:
+            res = tf.nn.conv2d(tf.concat(axis = 3, values = args), matrix, strides = [1, 1, 1, 1], padding = 'SAME')
+        if not bias:
+            return res
+        bias_term = tf.get_variable(
+            "Bias", [num_features],
+            dtype = dtype,
+            initializer = tf.constant_initializer(
+                bias_start, dtype = dtype))
+    return res + bias_term
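+
+
+# Example (illustrative sketch, TF 1.x): unrolling the cell over a short sequence. The shapes
+# below are arbitrary and `frames` stands in for a [time, batch, height, width, channels] tensor.
+#
+#   cell = BasicConvLSTMCell(shape=[16, 16], filter_size=[3, 3], num_features=32)
+#   frames = tf.random_uniform([5, 8, 16, 16, 3])
+#   state = cell.zero_state(frames[0], tf.float32)
+#   outputs = []
+#   for t in range(5):
+#       output, state = cell(frames[t], state, scope='conv_lstm', reuse=(t > 0))
+#       outputs.append(output)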
diff --git a/video_prediction/layers/__init__.py b/video_prediction/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8530ffd70f0899c8d8e0832d0dcd377b78bbe349
--- /dev/null
+++ b/video_prediction/layers/__init__.py
@@ -0,0 +1 @@
+from .normalization import fused_instance_norm
diff --git a/video_prediction/layers/layer_def.py b/video_prediction/layers/layer_def.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b7f4387001c9318507ad809d7176071312742d0
--- /dev/null
+++ b/video_prediction/layers/layer_def.py
@@ -0,0 +1,160 @@
+"""functions used to construct different architectures
+"""
+
+import tensorflow as tf
+import numpy as np
+
+weight_decay = 0.0005
+def _activation_summary(x):
+    """Helper to create summaries for activations.
+    Creates a summary that provides a histogram of activations.
+    Creates a summary that measures the sparsity of activations.
+    Args:
+      x: Tensor
+    Returns:
+      nothing
+    """
+    tensor_name = x.op.name
+    tf.summary.histogram(tensor_name + '/activations', x)
+    tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
+
+def _variable_on_cpu(name, shape, initializer):
+    """Helper to create a Variable stored on CPU memory.
+    Args:
+      name: name of the variable
+      shape: list of ints
+      initializer: initializer for Variable
+    Returns:
+      Variable Tensor
+    """
+    with tf.device('/cpu:0'):
+        var = tf.get_variable(name, shape, initializer=initializer)
+    return var
+
+
+def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.layers.xavier_initializer()):
+    """Helper to create an initialized Variable with weight decay.
+    Note that the Variable is initialized with a truncated normal distribution.
+    A weight decay is added only if one is specified.
+    Args:
+      name: name of the variable
+      shape: list of ints
+      stddev: standard deviation of a truncated Gaussian
+      wd: add L2Loss weight decay multiplied by this float. If None, weight
+          decay is not added for this Variable.
+    Returns:
+      Variable Tensor
+    """
+    #var = _variable_on_cpu(name, shape,tf.truncated_normal_initializer(stddev = stddev))
+    var = _variable_on_cpu(name, shape, initializer)
+    if wd:
+        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name = 'weight_loss')
+        weight_decay.set_shape([])
+        tf.add_to_collection('losses', weight_decay)
+    return var
+
+
+def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer() , activate="relu"):
+    print("conv_layer activation function",activate)
+    
+    with tf.variable_scope('{0}_conv'.format(idx)) as scope:
+ 
+        input_channels = inputs.get_shape()[-1]
+        weights = _variable_with_weight_decay('weights',shape = [kernel_size, kernel_size, 
+                                                                 input_channels, num_features],
+                                              stddev = 0.01, wd = weight_decay)
+        biases = _variable_on_cpu('biases', [num_features], initializer)
+        conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding = 'SAME')
+        conv_biased = tf.nn.bias_add(conv, biases)
+        if activate == "linear":
+            return conv_biased
+        elif activate == "relu":
+            conv_rect = tf.nn.relu(conv_biased, name = '{0}_conv'.format(idx))  
+        elif activate == "elu":
+            conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx))   
+        elif activate == "leaky_relu":
+            conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx))
+        else:
+            raise ("activation function is not correct")
+        return conv_rect
+
+
+def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer(),activate="relu"):
+    with tf.variable_scope('{0}_trans_conv'.format(idx)) as scope:
+        input_channels = inputs.get_shape()[3]
+        input_shape = inputs.get_shape().as_list()
+
+
+        weights = _variable_with_weight_decay('weights',
+                                              shape = [kernel_size, kernel_size, num_features, input_channels],
+                                              stddev = 0.1, wd = weight_decay)
+        biases = _variable_on_cpu('biases', [num_features],initializer)
+        batch_size = tf.shape(inputs)[0]
+
+        output_shape = tf.stack(
+            [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features])
+        print ("output_shape",output_shape)
+        conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides = [1, stride, stride, 1], padding = 'SAME')
+        conv_biased = tf.nn.bias_add(conv, biases)
+        if activate == "linear":
+            return conv_biased
+        elif activate == "elu":
+            return tf.nn.elu(conv_biased, name = '{0}_transpose_conv'.format(idx))       
+        elif activate == "relu":
+            return tf.nn.relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
+        elif activate == "leaky_relu":
+            return tf.nn.leaky_relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
+        elif activate == "sigmoid":
+            return tf.nn.sigmoid(conv_biased, name ='sigmoid') 
+        else:
+            return conv_biased
+    
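+
+# Example (illustrative sketch, TF 1.x): a strided conv layer followed by a transpose-conv
+# layer that restores the spatial resolution; shapes and layer indices are arbitrary.
+#
+#   x = tf.random_uniform([8, 32, 32, 3])
+#   h = conv_layer(x, kernel_size=3, stride=2, num_features=16, idx=1, activate="leaky_relu")
+#   y = transpose_conv_layer(h, kernel_size=3, stride=2, num_features=3, idx=1, activate="sigmoid")
+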
+
+def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01,initializer=tf.contrib.layers.xavier_initializer()):
+    with tf.variable_scope('{0}_fc'.format(idx)) as scope:
+        input_shape = inputs.get_shape().as_list()
+        if flat:
+            dim = input_shape[1] * input_shape[2] * input_shape[3]
+            inputs_processed = tf.reshape(inputs, [-1, dim])
+        else:
+            dim = input_shape[1]
+            inputs_processed = inputs
+
+        weights = _variable_with_weight_decay('weights', shape = [dim, hiddens], stddev = weight_init,
+                                              wd = weight_decay)
+        biases = _variable_on_cpu('biases', [hiddens],initializer)
+        if activate == "linear":
+            return tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')
+        elif activate == "sigmoid":
+            return tf.nn.sigmoid(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
+        elif activate == "softmax":
+            return tf.nn.softmax(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
+        elif activate == "relu":
+            return tf.nn.relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
+        elif activate == "leaky_relu":
+            return tf.nn.leaky_relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))        
+        else:
+            ip = tf.add(tf.matmul(inputs_processed, weights), biases)
+            return tf.nn.elu(ip, name = str(idx) + '_fc')
+        
+def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None):
+    with tf.variable_scope('{0}_bn'.format(idx)) as scope:
+        #Calculate batch mean and variance
+        shape = inputs.get_shape().as_list()
+        scale = tf.get_variable("gamma", shape[-1], initializer=tf.constant_initializer(1.0), trainable=is_training)
+        beta = tf.get_variable("beta", shape[-1], initializer=tf.constant_initializer(0.0), trainable=is_training)
+        pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
+        pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)
+        
+        if is_training:
+            batch_mean, batch_var = tf.nn.moments(inputs,[0])
+            train_mean = tf.assign(pop_mean,pop_mean * decay + batch_mean * (1 - decay))
+            train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
+            with tf.control_dependencies([train_mean,train_var]):
+                 return tf.nn.batch_normalization(inputs,batch_mean,batch_var,beta,scale,epsilon)
+        else:
+             return tf.nn.batch_normalization(inputs,pop_mean,pop_var,beta,scale,epsilon)
+
+def bn_layers_wrapper(inputs, is_training):
+    pass
+   
\ No newline at end of file
diff --git a/video_prediction/layers/mcnet_ops.py b/video_prediction/layers/mcnet_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..656f66c0df1cf199fff319f7b81b01594f96332c
--- /dev/null
+++ b/video_prediction/layers/mcnet_ops.py
@@ -0,0 +1,178 @@
+import math
+import numpy as np 
+import tensorflow as tf
+
+from tensorflow.python.framework import ops
+from video_prediction.utils.mcnet_utils import *
+
+
+def batch_norm(inputs, name, train=True, reuse=False):
+    return tf.contrib.layers.batch_norm(inputs=inputs,is_training=train,
+                                      reuse=reuse,scope=name,scale=True)
+
+
+def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name="conv2d", reuse=False, padding='SAME'):
+    with tf.variable_scope(name, reuse=reuse):
+        w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
+                         initializer=tf.contrib.layers.xavier_initializer())
+        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding=padding)
+ 
+        biases = tf.get_variable('biases', [output_dim],
+                              initializer=tf.constant_initializer(0.0))
+        conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
+ 
+    return conv
+
+
+def deconv2d(input_, output_shape,k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 
+             name="deconv2d", reuse=False, with_w=False, padding='SAME'):
+        with tf.variable_scope(name, reuse=reuse):
+          # filter : [height, width, output_channels, in_channels]
+            w = tf.get_variable('w', [k_h, k_w, output_shape[-1],
+                                input_.get_shape()[-1]],
+                                initializer=tf.contrib.layers.xavier_initializer())
+    
+            try:
+                deconv = tf.nn.conv2d_transpose(input_, w,
+                                      output_shape=output_shape,
+                                      strides=[1, d_h, d_w, 1],
+                                      padding=padding)
+
+            # Support for versions of TensorFlow before 0.7.0
+            except AttributeError:
+                deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
+                          strides=[1, d_h, d_w, 1])
+            biases = tf.get_variable('biases', [output_shape[-1]],
+                             initializer=tf.constant_initializer(0.0))
+            deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
+
+            if with_w:
+                return deconv, w, biases
+            else:
+                return deconv
+
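+# Example (illustrative sketch, TF 1.x): a strided conv2d followed by a deconv2d that restores
+# the input resolution; the static output_shape below is made up for illustration.
+#
+#   x = tf.random_uniform([8, 64, 64, 3])
+#   h = conv2d(x, output_dim=32, name='enc_conv1')                    # -> [8, 32, 32, 32]
+#   y = deconv2d(h, output_shape=[8, 64, 64, 3], name='dec_deconv1')  # -> [8, 64, 64, 3]
+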
+
+def lrelu(x, leak=0.2, name="lrelu"):
+    with tf.variable_scope(name):
+        f1 = 0.5 * (1 + leak)
+        f2 = 0.5 * (1 - leak)
+    return f1 * x + f2 * abs(x)
+
+
+def relu(x):
+    return tf.nn.relu(x)
+
+
+def tanh(x):
+    return tf.nn.tanh(x)
+
+
+def shape2d(a):
+    """
+    a: an int or a tuple/list of length 2
+    """
+    if type(a) == int:
+        return [a, a]
+    if isinstance(a, (list, tuple)):
+        assert len(a) == 2
+        return list(a)
+    raise RuntimeError("Illegal shape: {}".format(a))
+
+
+def shape4d(a):
+    # for use with tensorflow
+    return [1] + shape2d(a) + [1]
+
+
+def UnPooling2x2ZeroFilled(x):
+    out = tf.concat(axis=3, values=[x, tf.zeros_like(x)])
+    out = tf.concat(axis=2, values=[out, tf.zeros_like(out)])
+
+    sh = x.get_shape().as_list()
+    if None not in sh[1:]:
+        out_size = [-1, sh[1] * 2, sh[2] * 2, sh[3]]
+        return tf.reshape(out, out_size)
+    else:
+        sh = tf.shape(x)
+        return tf.reshape(out, [-1, sh[1] * 2, sh[2] * 2, sh[3]])
+
+
+def MaxPooling(x, shape, stride=None, padding='VALID'):
+    """
+    MaxPooling on images.
+    :param x: NHWC tensor.
+    :param shape: int or [h, w]
+    :param stride: int or [h, w]; defaults to shape.
+    :param padding: 'VALID' or 'SAME' (case-insensitive); defaults to 'VALID'.
+    :returns: NHWC tensor.
+    """
+    padding = padding.upper()
+    shape = shape4d(shape)
+    if stride is None:
+        stride = shape
+    else:
+        stride = shape4d(stride)
+
+    return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding)
+
+
+#@layer_register()
+def FixedUnPooling(x, shape):
+    """
+    Unpool the input with a fixed mat to perform kronecker product with.
+    :param input: NHWC tensor
+    :param shape: int or [h, w]
+    :returns: NHWC tensor
+    """
+    shape = shape2d(shape)
+    # only 2x2 unpooling is supported; use the fast zero-filled implementation for it
+    assert shape == [2, 2], shape
+    return UnPooling2x2ZeroFilled(x)
+
+
+def gdl(gen_frames, gt_frames, alpha):
+    """
+    Calculates the sum of GDL losses between the predicted and gt frames.
+    @param gen_frames: The predicted frames at each scale.
+    @param gt_frames: The ground truth frames at each scale
+    @param alpha: The power to which each gradient term is raised.
+    @return: The GDL loss.
+    """
+    # create filters [-1, 1] and [[1],[-1]]
+    # for diffing to the left and down respectively.
+    pos = tf.constant(np.identity(3), dtype=tf.float32)
+    neg = -1 * pos
+    # [-1, 1]
+    filter_x = tf.expand_dims(tf.stack([neg, pos]), 0)
+    # [[1],[-1]]
+    filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)])
+    strides = [1, 1, 1, 1]  # stride of (1, 1)
+    padding = 'SAME'
+
+    gen_dx = tf.abs(tf.nn.conv2d(gen_frames, filter_x, strides, padding=padding))
+    gen_dy = tf.abs(tf.nn.conv2d(gen_frames, filter_y, strides, padding=padding))
+    gt_dx = tf.abs(tf.nn.conv2d(gt_frames, filter_x, strides, padding=padding))
+    gt_dy = tf.abs(tf.nn.conv2d(gt_frames, filter_y, strides, padding=padding))
+
+    grad_diff_x = tf.abs(gt_dx - gen_dx)
+    grad_diff_y = tf.abs(gt_dy - gen_dy)
+
+    gdl_loss = tf.reduce_mean((grad_diff_x ** alpha + grad_diff_y ** alpha))
+
+    # condense into one tensor and avg
+    return gdl_loss
+
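+# Example (illustrative sketch, TF 1.x): GDL between predicted and ground-truth RGB frames.
+# Note that the hard-coded identity(3) filters assume 3-channel inputs; alpha=2 penalizes
+# squared gradient differences.
+#
+#   gen = tf.random_uniform([8, 64, 64, 3])
+#   gt = tf.random_uniform([8, 64, 64, 3])
+#   loss = gdl(gen, gt, alpha=2)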
+
+def linear(input_, output_size, name, stddev=0.02, bias_start=0.0,
+           reuse=False, with_w=False):
+    shape = input_.get_shape().as_list()
+
+    with tf.variable_scope(name, reuse=reuse):
+        matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
+                                 tf.random_normal_initializer(stddev=stddev))
+        bias = tf.get_variable("bias", [output_size],
+            initializer=tf.constant_initializer(bias_start))
+        if with_w:
+            return tf.matmul(input_, matrix) + bias, matrix, bias
+        else:
+            return tf.matmul(input_, matrix) + bias
diff --git a/video_prediction/layers/normalization.py b/video_prediction/layers/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f2a9bb9c0cf41eb7ed10f25e606c06791e0d2a3
--- /dev/null
+++ b/video_prediction/layers/normalization.py
@@ -0,0 +1,196 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Contains the normalization layer classes and their functional aliases."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.contrib.framework.python.ops import variables
+from tensorflow.contrib.layers.python.layers import utils
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope
+
+
+DATA_FORMAT_NCHW = 'NCHW'
+DATA_FORMAT_NHWC = 'NHWC'
+
+
+def fused_instance_norm(inputs,
+                        center=True,
+                        scale=True,
+                        epsilon=1e-6,
+                        activation_fn=None,
+                        param_initializers=None,
+                        reuse=None,
+                        variables_collections=None,
+                        outputs_collections=None,
+                        trainable=True,
+                        data_format=DATA_FORMAT_NHWC,
+                        scope=None):
+  """Functional interface for the instance normalization layer.
+
+  Reference: https://arxiv.org/abs/1607.08022.
+
+    "Instance Normalization: The Missing Ingredient for Fast Stylization"
+    Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky
+
+  Args:
+    inputs: A tensor with 2 or more dimensions, where the first dimension has
+      `batch_size`. The normalization is over all but the last dimension if
+      `data_format` is `NHWC` and the second dimension if `data_format` is
+      `NCHW`.
+    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
+      is ignored.
+    scale: If True, multiply by `gamma`. If False, `gamma` is
+      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
+      disabled since the scaling can be done by the next layer.
+    epsilon: Small float added to variance to avoid dividing by zero.
+    activation_fn: Activation function, default set to None to skip it and
+      maintain a linear activation.
+    param_initializers: Optional initializers for beta, gamma, moving mean and
+      moving variance.
+    reuse: Whether or not the layer and its variables should be reused. To be
+      able to reuse the layer scope must be given.
+    variables_collections: Optional collections for the variables.
+    outputs_collections: Collections to add the outputs.
+    trainable: If `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    data_format: A string. `NHWC` (default) and `NCHW` are supported.
+    scope: Optional scope for `variable_scope`.
+
+  Returns:
+    A `Tensor` representing the output of the operation.
+
+  Raises:
+    ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
+    ValueError: If the rank of `inputs` is undefined.
+    ValueError: If rank or channels dimension of `inputs` is undefined.
+  """
+  inputs = ops.convert_to_tensor(inputs)
+  inputs_shape = inputs.shape
+  inputs_rank = inputs.shape.ndims
+
+  if inputs_rank is None:
+    raise ValueError('Inputs %s has undefined rank.' % inputs.name)
+  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
+    raise ValueError('data_format has to be either NCHW or NHWC.')
+
+  with variable_scope.variable_scope(
+      scope, 'InstanceNorm', [inputs], reuse=reuse) as sc:
+    if data_format == DATA_FORMAT_NCHW:
+      reduction_axis = 1
+      # For NCHW format, rather than relying on implicit broadcasting, we
+      # explicitly reshape the params to params_shape_broadcast when computing
+      # the moments and the batch normalization.
+      params_shape_broadcast = list(
+          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
+    else:
+      reduction_axis = inputs_rank - 1
+      params_shape_broadcast = None
+    moments_axes = list(range(inputs_rank))
+    del moments_axes[reduction_axis]
+    del moments_axes[0]
+    params_shape = inputs_shape[reduction_axis:reduction_axis + 1]
+    if not params_shape.is_fully_defined():
+      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
+          inputs.name, params_shape))
+
+    # Allocate parameters for the beta and gamma of the normalization.
+    beta, gamma = None, None
+    dtype = inputs.dtype.base_dtype
+    if param_initializers is None:
+      param_initializers = {}
+    if center:
+      beta_collections = utils.get_variable_collections(
+          variables_collections, 'beta')
+      beta_initializer = param_initializers.get(
+          'beta', init_ops.zeros_initializer())
+      beta = variables.model_variable('beta',
+                                      shape=params_shape,
+                                      dtype=dtype,
+                                      initializer=beta_initializer,
+                                      collections=beta_collections,
+                                      trainable=trainable)
+      if params_shape_broadcast:
+        beta = array_ops.reshape(beta, params_shape_broadcast)
+    if scale:
+      gamma_collections = utils.get_variable_collections(
+          variables_collections, 'gamma')
+      gamma_initializer = param_initializers.get(
+          'gamma', init_ops.ones_initializer())
+      gamma = variables.model_variable('gamma',
+                                       shape=params_shape,
+                                       dtype=dtype,
+                                       initializer=gamma_initializer,
+                                       collections=gamma_collections,
+                                       trainable=trainable)
+      if params_shape_broadcast:
+        gamma = array_ops.reshape(gamma, params_shape_broadcast)
+
+    if data_format == DATA_FORMAT_NHWC:
+      inputs = array_ops.transpose(inputs, list(range(1, reduction_axis)) + [0, reduction_axis])
+    if data_format == DATA_FORMAT_NCHW:
+      inputs = array_ops.transpose(inputs, list(range(2, inputs_rank)) + [0, reduction_axis])
+    hw, n, c = inputs.shape.as_list()[:-2], inputs.shape[-2].value, inputs.shape[-1].value
+    inputs = array_ops.reshape(inputs, [1] + hw + [n * c])
+    if inputs.shape.ndims != 4:
+        # combine all the spatial dimensions into only two, e.g. [D, H, W] -> [DH, W]
+        if inputs.shape.ndims > 4:
+            inputs_ndims4_shape = [1, hw[0], -1, n * c]
+        else:
+            inputs_ndims4_shape = [1, 1, -1, n * c]
+        inputs = array_ops.reshape(inputs, inputs_ndims4_shape)
+    beta = array_ops.reshape(array_ops.tile(beta[None, :], [n, 1]), [-1])
+    gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [n, 1]), [-1])
+
+    outputs, _, _ = nn.fused_batch_norm(
+        inputs, gamma, beta, epsilon=epsilon,
+        data_format=DATA_FORMAT_NHWC, name='instancenorm')
+
+    outputs = array_ops.reshape(outputs, hw + [n, c])
+    if data_format == DATA_FORMAT_NHWC:
+      outputs = array_ops.transpose(outputs, [inputs_rank - 2] + list(range(inputs_rank - 2)) + [inputs_rank - 1])
+    if data_format == DATA_FORMAT_NCHW:
+      outputs = array_ops.transpose(outputs, [inputs_rank - 2, inputs_rank - 1] + list(range(inputs_rank - 2)))
+
+    # if data_format == DATA_FORMAT_NHWC:
+    #   inputs = array_ops.transpose(inputs, [0, reduction_axis] + list(range(1, reduction_axis)))
+    # inputs_nchw_shape = inputs.shape
+    # inputs = array_ops.reshape(inputs, [1, -1] + inputs_nchw_shape.as_list()[2:])
+    # if inputs.shape.ndims != 4:
+    #     # combine all the spatial dimensions into only two, e.g. [D, H, W] -> [DH, W]
+    #     if inputs.shape.ndims > 4:
+    #         inputs_ndims4_shape = inputs.shape.as_list()[:2] + [-1, inputs_nchw_shape.as_list()[-1]]
+    #     else:
+    #         inputs_ndims4_shape = inputs.shape.as_list()[:2] + [1, -1]
+    #     inputs = array_ops.reshape(inputs, inputs_ndims4_shape)
+    # beta = array_ops.reshape(array_ops.tile(beta[None, :], [inputs_nchw_shape[0].value, 1]), [-1])
+    # gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [inputs_nchw_shape[0].value, 1]), [-1])
+    #
+    # outputs, _, _ = nn.fused_batch_norm(
+    #     inputs, gamma, beta, epsilon=epsilon,
+    #     data_format=DATA_FORMAT_NCHW, name='instancenorm')
+    #
+    # outputs = array_ops.reshape(outputs, inputs_nchw_shape)
+    # if data_format == DATA_FORMAT_NHWC:
+    #   outputs = array_ops.transpose(outputs, [0] + list(range(2, inputs_rank)) + [1])
+
+    if activation_fn is not None:
+      outputs = activation_fn(outputs)
+    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
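+
+
+# Example (illustrative sketch, TF 1.x): instance-normalizing an NHWC feature map. Note that
+# this implementation requires the batch and channel dimensions to be statically known.
+#
+#   x = tf.random_uniform([4, 32, 32, 64])
+#   y = fused_instance_norm(x, data_format=DATA_FORMAT_NHWC, scope='instance_norm')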
diff --git a/video_prediction/losses.py b/video_prediction/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..662da29dbadf091bc59fcd0e7ed62fd71bcf0f81
--- /dev/null
+++ b/video_prediction/losses.py
@@ -0,0 +1,67 @@
+import tensorflow as tf
+
+from video_prediction.ops import sigmoid_kl_with_logits
+
+
+def l1_loss(pred, target):
+    return tf.reduce_mean(tf.abs(target - pred))
+
+
+def l2_loss(pred, target):
+    return tf.reduce_mean(tf.square(target - pred))
+
+
+def normalize_tensor(tensor, eps=1e-10):
+    norm_factor = tf.norm(tensor, axis=-1, keepdims=True)
+    return tensor / (norm_factor + eps)
+
+
+def cosine_distance(tensor0, tensor1, keep_axis=None):
+    # half of the mean squared distance between the L2-normalized tensors, which is
+    # equivalent to 1 - cosine_similarity along the last axis (keep_axis is currently unused)
+    tensor0 = normalize_tensor(tensor0)
+    tensor1 = normalize_tensor(tensor1)
+    return tf.reduce_mean(tf.reduce_sum(tf.square(tensor0 - tensor1), axis=-1)) / 2.0
+                                
+
+def charbonnier_loss(x, epsilon=0.001):
+    return tf.reduce_mean(tf.sqrt(tf.square(x) + tf.square(epsilon)))
+
+
+def gan_loss(logits, labels, gan_loss_type):
+    # use 1.0 (or 1.0 - discrim_label_smooth) for real data and 0.0 for fake data
+    if gan_loss_type == 'GAN':
+        # discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
+        # gen_loss = tf.reduce_mean(-tf.log(predict_fake + EPS))
+        if labels in (0.0, 1.0):
+            labels = tf.constant(labels, dtype=logits.dtype, shape=logits.get_shape())
+            loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))
+        else:
+            loss = tf.reduce_mean(sigmoid_kl_with_logits(logits, labels))
+    elif gan_loss_type == 'LSGAN':
+        # discrim_loss = tf.reduce_mean((tf.square(predict_real - 1) + tf.square(predict_fake)))
+        # gen_loss = tf.reduce_mean(tf.square(predict_fake - 1))
+        loss = tf.reduce_mean(tf.square(logits - labels))
+    elif gan_loss_type == 'SNGAN':
+        # this is the form of the loss used in the official implementation of the SNGAN paper, but it leads to
+        # worse results in our video prediction experiments
+        if labels == 0.0:
+            loss = tf.reduce_mean(tf.nn.softplus(logits))
+        elif labels == 1.0:
+            loss = tf.reduce_mean(tf.nn.softplus(-logits))
+        else:
+            raise NotImplementedError
+    else:
+        raise ValueError('Unknown GAN loss type %s' % gan_loss_type)
+    return loss
+
+
+def kl_loss(mu, log_sigma_sq, mu2=None, log_sigma2_sq=None):
+    if mu2 is None and log_sigma2_sq is None:
+        sigma_sq = tf.exp(log_sigma_sq)
+        return -0.5 * tf.reduce_mean(tf.reduce_sum(1 + log_sigma_sq - tf.square(mu) - sigma_sq, axis=-1))
+    else:
+        mu1 = mu
+        log_sigma1_sq = log_sigma_sq
+        return tf.reduce_mean(tf.reduce_sum(
+            (log_sigma2_sq - log_sigma1_sq) / 2
+            + (tf.exp(log_sigma1_sq) + tf.square(mu1 - mu2)) / (2 * tf.exp(log_sigma2_sq))
+            - 1 / 2, axis=-1))
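+
+
+# Example (illustrative sketch): the single-argument form of kl_loss is the usual VAE latent
+# loss, i.e. the KL divergence of a diagonal Gaussian q(z|x) from the standard normal prior.
+# With mu = 0 and log_sigma_sq = 0 the divergence is exactly 0:
+#
+#   mu = tf.zeros([8, 16])
+#   log_sigma_sq = tf.zeros([8, 16])
+#   latent_loss = kl_loss(mu, log_sigma_sq)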
diff --git a/video_prediction/metrics.py b/video_prediction/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..19c057fc7ac59c137fcba64579e4b52a4bd0a55c
--- /dev/null
+++ b/video_prediction/metrics.py
@@ -0,0 +1,26 @@
+import tensorflow as tf
+#import lpips_tf
+
+
+def mse(a, b):
+    return tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1])
+
+
+def psnr(a, b):
+    return tf.image.psnr(a, b, 1.0)
+
+
+def ssim(a, b):
+    return tf.image.ssim(a, b, 1.0)
+
+
+# def lpips(input0, input1):
+#     if input0.shape[-1].value == 1:
+#         input0 = tf.tile(input0, [1] * (input0.shape.ndims - 1) + [3])
+#     if input1.shape[-1].value == 1:
+#         input1 = tf.tile(input1, [1] * (input1.shape.ndims - 1) + [3])
+#
+#     distance = lpips_tf.lpips(input0, input1)
+#     return -distance
+
+
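+# Example (illustrative sketch, TF 1.x): all metrics operate on image batches assumed to be
+# scaled to [0, 1] and return one value per image; shapes below are arbitrary.
+#
+#   a = tf.random_uniform([4, 64, 64, 3])
+#   b = tf.random_uniform([4, 64, 64, 3])
+#   mse_per_image = mse(a, b)     # shape [4]
+#   psnr_per_image = psnr(a, b)   # shape [4]
+#   ssim_per_image = ssim(a, b)   # shape [4]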
diff --git a/video_prediction/models/__init__.py b/video_prediction/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d7323f3750949b0ddb411d4a98934928537bc53
--- /dev/null
+++ b/video_prediction/models/__init__.py
@@ -0,0 +1,30 @@
+from .base_model import BaseVideoPredictionModel
+from .base_model import VideoPredictionModel
+from .non_trainable_model import NonTrainableVideoPredictionModel
+from .non_trainable_model import GroundTruthVideoPredictionModel
+from .non_trainable_model import RepeatVideoPredictionModel
+from .savp_model import SAVPVideoPredictionModel
+from .dna_model import DNAVideoPredictionModel
+from .sna_model import SNAVideoPredictionModel
+from .sv2p_model import SV2PVideoPredictionModel
+from .vanilla_vae_model import VanillaVAEVideoPredictionModel
+from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel
+from .mcnet_model import McNetVideoPredictionModel
+
+
+def get_model_class(model):
+    model_mappings = {
+        'ground_truth': 'GroundTruthVideoPredictionModel',
+        'repeat': 'RepeatVideoPredictionModel',
+        'savp': 'SAVPVideoPredictionModel',
+        'dna': 'DNAVideoPredictionModel',
+        'sna': 'SNAVideoPredictionModel',
+        'sv2p': 'SV2PVideoPredictionModel',
+        'vae': 'VanillaVAEVideoPredictionModel',
+        'convLSTM': 'VanillaConvLstmVideoPredictionModel',
+        'mcnet': 'McNetVideoPredictionModel',
+        }
+    model_class = model_mappings.get(model, model)
+    model_class = globals().get(model_class)
+    if model_class is None or not issubclass(model_class, BaseVideoPredictionModel):
+        raise ValueError('Invalid model %s' % model)
+    return model_class
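+
+
+# Example (illustrative sketch): a model class can be looked up either by its short alias or
+# by its full class name.
+#
+#   conv_lstm_cls = get_model_class('convLSTM')
+#   savp_cls = get_model_class('SAVPVideoPredictionModel')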
diff --git a/video_prediction/models/base_model.py b/video_prediction/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ebe228fcc9c90addf610bed44bb46f090c7e514
--- /dev/null
+++ b/video_prediction/models/base_model.py
@@ -0,0 +1,878 @@
+import functools
+import itertools
+import os
+import re
+from collections import OrderedDict
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.training import HParams
+from tensorflow.python.util import nest
+
+import video_prediction as vp
+from video_prediction.utils import tf_utils
+from video_prediction.utils.tf_utils import compute_averaged_gradients, reduce_tensors, local_device_setter, \
+    replace_read_ops, print_loss_info, transpose_batch_time, add_gif_summaries, add_scalar_summaries, \
+    add_plot_and_scalar_summaries, add_summaries
+
+
+class BaseVideoPredictionModel(object):
+    def __init__(self, mode='train', hparams_dict=None, hparams=None,
+                 num_gpus=None, eval_num_samples=100,
+                 eval_num_samples_for_diversity=10, eval_parallel_iterations=1):
+        """
+        Base video prediction model.
+
+        Trainable and non-trainable video prediction models can be derived
+        from this base class.
+
+        Args:
+            mode: `'train'` or `'test'`.
+            hparams_dict: a dict of `name=value` pairs, where `name` must be
+                defined in `self.get_default_hparams()`.
+            hparams: a string of comma separated list of `name=value` pairs,
+                where `name` must be defined in `self.get_default_hparams()`.
+                These values override any values in hparams_dict (if any).
+        """
+        if mode not in ('train', 'test'):
+            raise ValueError('mode must be train or test, but %s given' % mode)
+        self.mode = mode
+        cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0')
+        if cuda_visible_devices == '':
+            max_num_gpus = 0
+        else:
+            max_num_gpus = len(cuda_visible_devices.split(','))
+        if num_gpus is None:
+            num_gpus = max_num_gpus
+        elif num_gpus > max_num_gpus:
+            raise ValueError('num_gpus=%d is greater than the number of visible devices %d' % (num_gpus, max_num_gpus))
+        self.num_gpus = num_gpus
+        self.eval_num_samples = eval_num_samples
+        self.eval_num_samples_for_diversity = eval_num_samples_for_diversity
+        self.eval_parallel_iterations = eval_parallel_iterations
+        self.hparams = self.parse_hparams(hparams_dict, hparams)
+        if self.hparams.context_frames == -1:
+            raise ValueError('Invalid context_frames %r. It must be '
+                             'specified via the hparams.' % self.hparams.context_frames)
+        if self.hparams.sequence_length == -1:
+            raise ValueError('Invalid sequence_length %r. It must be '
+                             'specified via the hparams.' % self.hparams.sequence_length)
+
+        # should be overridden by the descendant class if the model is stochastic
+        self.deterministic = True
+
+        # member variables that should be set by `self.build_graph`
+        self.inputs = None
+        self.gen_images = None
+        self.outputs = None
+        self.metrics = None
+        self.eval_outputs = None
+        self.eval_metrics = None
+        self.accum_eval_metrics = None
+        self.saveable_variables = None
+        self.post_init_ops = None
+
+    def get_default_hparams_dict(self):
+        """
+        The keys of this dict define valid hyperparameters for instances of
+        this class. A class inheriting from this one should override this
+        method if it has a different set of hyperparameters.
+
+        Returns:
+            A dict with the following hyperparameters.
+
+            context_frames: the number of ground-truth frames to pass in at
+                start. Must be specified during instantiation.
+            sequence_length: the number of frames in the video sequence,
+                including the context frames, so this model predicts
+                `sequence_length - context_frames` future frames. Must be
+                specified during instantiation.
+            repeat: the number of repeat actions (if applicable).
+        """
+        hparams = dict(
+            context_frames=-1,
+            sequence_length=-1,
+            repeat=1,
+        )
+        return hparams
+
+    def get_default_hparams(self):
+        return HParams(**self.get_default_hparams_dict())
+
+    def parse_hparams(self, hparams_dict, hparams):
+        parsed_hparams = self.get_default_hparams().override_from_dict(hparams_dict or {})
+        if hparams:
+            if not isinstance(hparams, (list, tuple)):
+                hparams = [hparams]
+            for hparam in hparams:
+                parsed_hparams.parse(hparam)
+        return parsed_hparams
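+
+    # Example (illustrative sketch): values parsed from the `hparams` string override those in
+    # `hparams_dict`, e.g.
+    #
+    #   model.parse_hparams({'context_frames': 2, 'sequence_length': 12}, 'sequence_length=20')
+    #
+    # yields context_frames=2 and sequence_length=20.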
+
+    def build_graph(self, inputs):
+        self.inputs = inputs
+
+    def metrics_fn(self, inputs, outputs):
+        metrics = OrderedDict()
+        sequence_length = tf.shape(inputs['images'])[0]
+        context_frames = self.hparams.context_frames
+        future_length = sequence_length - context_frames
+        # target_images and pred_images include only the future frames
+        target_images = inputs['images'][-future_length:]
+        pred_images = outputs['gen_images'][-future_length:]
+        metric_fns = [
+            ('psnr', vp.metrics.psnr),
+            ('mse', vp.metrics.mse),
+            ('ssim', vp.metrics.ssim),
+            #('lpips', vp.metrics.lpips),  # bing: lpips metric removed because of the URL-fetching issue
+        ]
+        for metric_name, metric_fn in metric_fns:
+            metrics[metric_name] = tf.reduce_mean(metric_fn(target_images, pred_images))
+        return metrics
+
+    def eval_outputs_and_metrics_fn(self, inputs, outputs, num_samples=None,
+                                    num_samples_for_diversity=None, parallel_iterations=None):
+        num_samples = num_samples or self.eval_num_samples
+        num_samples_for_diversity = num_samples_for_diversity or self.eval_num_samples_for_diversity
+        parallel_iterations = parallel_iterations or self.eval_parallel_iterations
+
+        sequence_length, batch_size = inputs['images'].shape[:2].as_list()
+        if batch_size is None:
+            batch_size = tf.shape(inputs['images'])[1]
+        if sequence_length is None:
+            sequence_length = tf.shape(inputs['images'])[0]
+        context_frames = self.hparams.context_frames
+        future_length = sequence_length - context_frames
+        # the outputs include all the frames, whereas the metrics include only the future frames
+        eval_outputs = OrderedDict()
+        eval_metrics = OrderedDict()
+        metric_fns = [
+            ('psnr', vp.metrics.psnr),
+            ('mse', vp.metrics.mse),
+            ('ssim', vp.metrics.ssim),
+           # ('lpips', vp.metrics.lpips), #bing
+        ]
+        # images and gen_images include all the frames
+        images = inputs['images']
+        gen_images = outputs['gen_images']
+        # target_images and pred_images include only the future frames
+        target_images = inputs['images'][-future_length:]
+        pred_images = outputs['gen_images'][-future_length:]
+        # ground truth is the same for deterministic and stochastic models
+        eval_outputs['eval_images'] = images
+        if self.deterministic:
+            for metric_name, metric_fn in metric_fns:
+                metric = metric_fn(target_images, pred_images)
+                eval_metrics['eval_%s/min' % metric_name] = metric
+                eval_metrics['eval_%s/avg' % metric_name] = metric
+                eval_metrics['eval_%s/max' % metric_name] = metric
+            eval_outputs['eval_gen_images'] = gen_images
+        else:
+            def where_axis1(cond, x, y):
+                return transpose_batch_time(tf.where(cond, transpose_batch_time(x), transpose_batch_time(y)))
+
+            def sort_criterion(x):
+                return tf.reduce_mean(x, axis=0)
+
+            def accum_gen_images_and_metrics_fn(a, unused):
+                with tf.variable_scope(self.generator_scope, reuse=True):
+                    outputs_sample = self.generator_fn(inputs)
+                    gen_images_sample = outputs_sample['gen_images']
+                    pred_images_sample = gen_images_sample[-future_length:]
+                    # set the possibly static shape since it might not have been inferred correctly
+                    pred_images_sample = tf.reshape(pred_images_sample, tf.shape(a['eval_pred_images_last']))
+                for name, metric_fn in metric_fns:
+                    metric = metric_fn(target_images, pred_images_sample)  # time, batch_size
+                    cond_min = tf.less(sort_criterion(metric), sort_criterion(a['eval_%s/min' % name]))
+                    cond_max = tf.greater(sort_criterion(metric), sort_criterion(a['eval_%s/max' % name]))
+                    a['eval_%s/min' % name] = where_axis1(cond_min, metric, a['eval_%s/min' % name])
+                    a['eval_%s/sum' % name] = metric + a['eval_%s/sum' % name]
+                    a['eval_%s/max' % name] = where_axis1(cond_max, metric, a['eval_%s/max' % name])
+                    a['eval_gen_images_%s/min' % name] = where_axis1(cond_min, gen_images_sample, a['eval_gen_images_%s/min' % name])
+                    a['eval_gen_images_%s/sum' % name] = gen_images_sample + a['eval_gen_images_%s/sum' % name]
+                    a['eval_gen_images_%s/max' % name] = where_axis1(cond_max, gen_images_sample, a['eval_gen_images_%s/max' % name])
+                #bing
+                # a['eval_diversity'] = tf.cond(
+                #     tf.logical_and(tf.less(0, a['eval_sample_ind']),
+                #                    tf.less_equal(a['eval_sample_ind'], num_samples_for_diversity)),
+                #     lambda: -vp.metrics.lpips(a['eval_pred_images_last'], pred_images_sample) + a['eval_diversity'],
+                #     lambda: a['eval_diversity'])
+                a['eval_sample_ind'] = 1 + a['eval_sample_ind']
+                a['eval_pred_images_last'] = pred_images_sample
+                return a
+
+            initializer = {}
+            for name, _ in metric_fns:
+                initializer['eval_gen_images_%s/min' % name] = tf.zeros_like(gen_images)
+                initializer['eval_gen_images_%s/sum' % name] = tf.zeros_like(gen_images)
+                initializer['eval_gen_images_%s/max' % name] = tf.zeros_like(gen_images)
+                initializer['eval_%s/min' % name] = tf.fill([future_length, batch_size], float('inf'))
+                initializer['eval_%s/sum' % name] = tf.zeros([future_length, batch_size])
+                initializer['eval_%s/max' % name] = tf.fill([future_length, batch_size], float('-inf'))
+            #initializer['eval_diversity'] = tf.zeros([future_length, batch_size])
+            initializer['eval_sample_ind'] = tf.zeros((), dtype=tf.int32)
+            initializer['eval_pred_images_last'] = tf.zeros_like(pred_images)
+
+            eval_outputs_and_metrics = tf.foldl(
+                accum_gen_images_and_metrics_fn, tf.zeros([num_samples, 0]), initializer=initializer, back_prop=False,
+                parallel_iterations=parallel_iterations)
+
+            for name, _ in metric_fns:
+                eval_outputs['eval_gen_images_%s/min' % name] = eval_outputs_and_metrics['eval_gen_images_%s/min' % name]
+                eval_outputs['eval_gen_images_%s/avg' % name] = eval_outputs_and_metrics['eval_gen_images_%s/sum' % name] / float(num_samples)
+                eval_outputs['eval_gen_images_%s/max' % name] = eval_outputs_and_metrics['eval_gen_images_%s/max' % name]
+                eval_metrics['eval_%s/min' % name] = eval_outputs_and_metrics['eval_%s/min' % name]
+                eval_metrics['eval_%s/avg' % name] = eval_outputs_and_metrics['eval_%s/sum' % name] / float(num_samples)
+                eval_metrics['eval_%s/max' % name] = eval_outputs_and_metrics['eval_%s/max' % name]
+            #eval_metrics['eval_diversity'] = eval_outputs_and_metrics['eval_diversity'] / float(num_samples_for_diversity)
+        return eval_outputs, eval_metrics
+
+    def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
+        if checkpoints:
+            var_list = self.saveable_variables
+            # Possibly restore from multiple checkpoints. This is useful if subsets of
+            # weights (e.g. the generator or discriminator) live in different checkpoints.
+            if not isinstance(checkpoints, (list, tuple)):
+                checkpoints = [checkpoints]
+            # automatically skip global_step if more than one checkpoint is provided
+            skip_global_step = len(checkpoints) > 1
+            savers = []
+            for checkpoint in checkpoints:
+                print("creating restore saver from checkpoint %s" % checkpoint)
+                saver, _ = tf_utils.get_checkpoint_restore_saver(
+                    checkpoint, var_list, skip_global_step=skip_global_step,
+                    restore_to_checkpoint_mapping=restore_to_checkpoint_mapping)
+                savers.append(saver)
+            restore_op = [saver.saver_def.restore_op_name for saver in savers]
+            sess.run(restore_op)
+
+
+class VideoPredictionModel(BaseVideoPredictionModel):
+    def __init__(self,
+                 generator_fn,
+                 discriminator_fn=None,
+                 generator_scope='generator',
+                 discriminator_scope='discriminator',
+                 aggregate_nccl=False,
+                 mode='train',
+                 hparams_dict=None,
+                 hparams=None,
+                 **kwargs):
+        """
+        Trainable video prediction model with CPU and multi-GPU support.
+
+        If num_gpus <= 1, the devices for the ops in `self.build_graph` are
+        automatically chosen by TensorFlow (i.e. `tf.device` is not specified),
+        otherwise they are explicitly chosen.
+
+        Args:
+            generator_fn: callable that takes in inputs and returns a dict of
+                tensors.
+            discriminator_fn: callable that takes in fake/real data (and
+                optionally conditioned on inputs) and returns a dict of
+                tensors.
+            hparams_dict: a dict of `name=value` pairs, where `name` must be
+                defined in `self.get_default_hparams()`.
+            hparams: a string of comma separated list of `name=value` pairs,
+                where `name` must be defined in `self.get_default_hparams()`.
+                These values override any values in hparams_dict (if any).
+        """
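+        # Illustrative usage sketch (not from the original code; my_generator_fn is a
+        # hypothetical callable that follows the generator_fn interface described above):
+        #     model = VideoPredictionModel(
+        #         my_generator_fn,
+        #         mode='train',
+        #         hparams_dict={'context_frames': 2, 'sequence_length': 12},
+        #         hparams='lr=0.0002,batch_size=8')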
+        super(VideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
+        self.generator_fn = functools.partial(generator_fn, mode=self.mode, hparams=self.hparams)
+        self.discriminator_fn = functools.partial(discriminator_fn, mode=self.mode, hparams=self.hparams) if discriminator_fn else None
+        self.generator_scope = generator_scope
+        self.discriminator_scope = discriminator_scope
+        self.aggregate_nccl = aggregate_nccl
+
+        if any(self.hparams.lr_boundaries):
+            global_step = tf.train.get_or_create_global_step()
+            lr_values = list(self.hparams.lr * 0.1 ** np.arange(len(self.hparams.lr_boundaries) + 1))
+            self.learning_rate = tf.train.piecewise_constant(global_step, self.hparams.lr_boundaries, lr_values)
+        elif any(self.hparams.decay_steps):
+            lr, end_lr = self.hparams.lr, self.hparams.end_lr
+            start_step, end_step = self.hparams.decay_steps
+            if start_step == end_step:
+                schedule = tf.cond(tf.less(tf.train.get_or_create_global_step(), start_step),
+                                   lambda: 0.0, lambda: 1.0)
+            else:
+                step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step)
+                schedule = tf.to_float(step - start_step) / tf.to_float(end_step - start_step)
+            self.learning_rate = lr + (end_lr - lr) * schedule
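+            # e.g. (not from the original code) with the default lr=0.001, end_lr=0.0
+            # and decay_steps=(200000, 300000), the learning rate stays at 0.001 until
+            # step 200000 and then decays linearly to 0.0 by step 300000.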
+        else:
+            self.learning_rate = self.hparams.lr
+
+        if self.hparams.kl_weight:
+            if self.hparams.kl_anneal == 'none':
+                self.kl_weight = tf.constant(self.hparams.kl_weight, tf.float32)
+            elif self.hparams.kl_anneal == 'sigmoid':
+                k = self.hparams.kl_anneal_k
+                if k == -1.0:
+                    raise ValueError('Invalid kl_anneal_k %f when kl_anneal is sigmoid.' % k)
+                iter_num = tf.train.get_or_create_global_step()
+                self.kl_weight = self.hparams.kl_weight / (1 + k * tf.exp(-tf.to_float(iter_num) / k))
+            elif self.hparams.kl_anneal == 'linear':
+                start_step, end_step = self.hparams.kl_anneal_steps
+                step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step)
+                self.kl_weight = self.hparams.kl_weight * tf.to_float(step - start_step) / tf.to_float(end_step - start_step)
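+                # e.g. (not from the original code) with kl_weight=1.0 and the default
+                # kl_anneal_steps=(50000, 100000): the weight is 0.0 up to step 50000,
+                # 0.5 at step 75000, and 1.0 from step 100000 onwards.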
+            else:
+                raise NotImplementedError
+        else:
+            self.kl_weight = None
+
+        # member variables that should be set by `self.build_graph`
+        # (in addition to the ones in the base class)
+        self.gen_images_enc = None
+        self.g_losses = None
+        self.d_losses = None
+        self.g_loss = None
+        self.d_loss = None
+        self.g_vars = None
+        self.d_vars = None
+        self.train_op = None
+        self.summary_op = None
+        self.image_summary_op = None
+        self.eval_summary_op = None
+        self.accum_eval_summary_op = None
+        self.accum_eval_metrics_reset_op = None
+
+    def get_default_hparams_dict(self):
+        """
+        The keys of this dict define valid hyperparameters for instances of
+        this class. A class inheriting from this one should override this
+        method if it has a different set of hyperparameters.
+
+        Returns:
+            A dict with the following hyperparameters.
+
+            batch_size: batch size for training.
+            lr: learning rate. If decay_steps is non-zero, this is the
+                learning rate for steps <= decay_step.
+            end_lr: learning rate for steps >= end_decay_step if decay_steps
+                is non-zero, ignored otherwise.
+            decay_steps: (decay_step, end_decay_step) tuple.
+            max_steps: number of training steps.
+            beta1: momentum term of Adam.
+            beta2: momentum term of Adam.
+            context_frames: the number of ground-truth frames to pass in at
+                start. Must be specified during instantiation.
+            sequence_length: the number of frames in the video sequence,
+                including the context frames, so this model predicts
+                `sequence_length - context_frames` future frames. Must be
+                specified during instantiation.
+        """
+        default_hparams = super(VideoPredictionModel, self).get_default_hparams_dict()
+        hparams = dict(
+            batch_size=16,
+            lr=0.001,
+            end_lr=0.0,
+            decay_steps=(200000, 300000),
+            lr_boundaries=(0,),
+            max_steps=350000,
+            beta1=0.9,
+            beta2=0.999,
+            context_frames=-1,
+            sequence_length=-1,
+            clip_length=10,  # TODO (Bing): clarify what clip_length controls; the original default is 10
+            l1_weight=0.0,
+            l2_weight=1.0,
+            vgg_cdist_weight=0.0,
+            feature_l2_weight=0.0,
+            ae_l2_weight=0.0,
+            state_weight=0.0,
+            tv_weight=0.0,
+            image_sn_gan_weight=0.0,
+            image_sn_vae_gan_weight=0.0,
+            images_sn_gan_weight=0.0,
+            images_sn_vae_gan_weight=0.0,
+            video_sn_gan_weight=0.0,
+            video_sn_vae_gan_weight=0.0,
+            gan_feature_l2_weight=0.0,
+            gan_feature_cdist_weight=0.0,
+            vae_gan_feature_l2_weight=0.0,
+            vae_gan_feature_cdist_weight=0.0,
+            gan_loss_type='LSGAN',
+            joint_gan_optimization=False,
+            kl_weight=0.0,
+            kl_anneal='linear',
+            kl_anneal_k=-1.0,
+            kl_anneal_steps=(50000, 100000),
+            z_l1_weight=0.0,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def tower_fn(self, inputs):
+        """
+        This method doesn't have side-effects. `inputs` and the returned
+        `outputs` are batch-major, but internal calculations use time-major
+        tensors.
+        """
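+        # e.g. an images tensor of shape (batch_size, sequence_length, height, width,
+        # channels) becomes (sequence_length, batch_size, height, width, channels)
+        # after transpose_batch_time below.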
+        # batch-major to time-major
+        inputs = nest.map_structure(transpose_batch_time, inputs)
+
+        with tf.variable_scope(self.generator_scope):
+            gen_outputs = self.generator_fn(inputs)
+
+        if self.discriminator_fn:
+            with tf.variable_scope(self.discriminator_scope) as discrim_scope:
+                discrim_outputs = self.discriminator_fn(inputs, gen_outputs)
+            # post-update discriminator tensors (i.e. after the discriminator weights have been updated)
+            with tf.variable_scope(discrim_scope, reuse=True):
+                discrim_outputs_post = self.discriminator_fn(inputs, gen_outputs)
+        else:
+            discrim_outputs = {}
+            discrim_outputs_post = {}
+
+        outputs = [gen_outputs, discrim_outputs]
+        total_num_outputs = sum([len(output) for output in outputs])
+        outputs = OrderedDict(itertools.chain(*[output.items() for output in outputs]))
+        assert len(outputs) == total_num_outputs  # ensure no output is lost because of repeated keys
+
+        if isinstance(self.learning_rate, tf.Tensor):
+            outputs['learning_rate'] = self.learning_rate
+        if isinstance(self.kl_weight, tf.Tensor):
+            outputs['kl_weight'] = self.kl_weight
+
+        if self.mode == 'train':
+            with tf.name_scope("discriminator_loss"):
+                d_losses = self.discriminator_loss_fn(inputs, outputs)
+                print_loss_info(d_losses, inputs, outputs)
+            with tf.name_scope("generator_loss"):
+                g_losses = self.generator_loss_fn(inputs, outputs)
+                print_loss_info(g_losses, inputs, outputs)
+                if discrim_outputs_post:
+                    outputs_post = OrderedDict(itertools.chain(gen_outputs.items(), discrim_outputs_post.items()))
+                    # generator losses after the discriminator weights have been updated
+                    g_losses_post = self.generator_loss_fn(inputs, outputs_post)
+                else:
+                    g_losses_post = g_losses
+        else:
+            d_losses = {}
+            g_losses = {}
+            g_losses_post = {}
+        with tf.name_scope("metrics"):
+            metrics = self.metrics_fn(inputs, outputs)
+        with tf.name_scope("eval_outputs_and_metrics"):
+            eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs)
+
+        # time-major to batch-major
+        outputs_tuple = (outputs, eval_outputs)
+        outputs_tuple = nest.map_structure(transpose_batch_time, outputs_tuple)
+        losses_tuple = (d_losses, g_losses, g_losses_post)
+        losses_tuple = nest.map_structure(tf.convert_to_tensor, losses_tuple)
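+        # Each losses dict maps a loss name to a (tensor, weight) pair, so the scalar
+        # optimized per tower is the weighted sum of its entries (or 0 when the dict
+        # is empty).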
+        loss_tuple = tuple(tf.accumulate_n([loss * weight for loss, weight in losses.values()])
+                           if losses else tf.zeros(()) for losses in losses_tuple)
+        metrics_tuple = (metrics, eval_metrics)
+        metrics_tuple = nest.map_structure(transpose_batch_time, metrics_tuple)
+        return outputs_tuple, losses_tuple, loss_tuple, metrics_tuple
+
+    def build_graph(self, inputs, finetune=False):
+        BaseVideoPredictionModel.build_graph(self, inputs)
+
+        global_step = tf.train.get_or_create_global_step()
+        # Capture the variables created from here until the train_op for the
+        # saveable_variables. Note that if variables are being reused (e.g.
+        # they were created by a previously built model), those variables won't
+        # be captured here.
+        original_global_variables = tf.global_variables()
+
+
+        # ########Bing: fine-tune#######
+        # variables_to_restore = tf.contrib.framework.get_variables_to_restore(
+        #     exclude = ["discriminator/video/sn_fc4/dense/bias"])
+        # init_fn = tf.contrib.framework.assign_from_checkpoint_fn(checkpoint)
+        # restore_variables = tf.contrib.framework.get_variables("discriminator/video/sn_fc4/dense/bias")
+        # restore_init = tf.variables_initializer(restore_variables)
+        # restore_optimizer = tf.train.GradientDescentOptimizer(
+        #     learning_rate = 0.001)  # TODO: need to change the learning rate
+        # ###Bing: fine-tune#######
+        # skip_vars = {" discriminator_encoder/video_sn_fc4/dense/bias"}
+
+        if self.num_gpus <= 1:  # cpu or 1 gpu
+            outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(self.inputs)
+            self.outputs, self.eval_outputs = outputs_tuple
+            self.d_losses, self.g_losses, g_losses_post = losses_tuple
+            self.d_loss, self.g_loss, g_loss_post = loss_tuple
+            self.metrics, self.eval_metrics = metrics_tuple
+
+            self.d_vars = tf.trainable_variables(self.discriminator_scope)
+            self.g_vars = tf.trainable_variables(self.generator_scope)
+            g_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2)
+            d_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2)
+
+            if finetune:
+                ##Bing: fine-tune
+                #self.g_vars = tf.contrib.framework.get_variables("discriminator/video/sn_fc4/dense/bias")#generator/encoder/layer_3/conv2d/kernel/Adam_1
+                self.g_vars = tf.contrib.framework.get_variables("discriminator/encoder/video/sn_conv3_0/conv3d/kernel")
+                self.g_vars_init = tf.variables_initializer(self.g_vars)
+                g_optimizer = tf.train.AdamOptimizer(0.00001)
+                print("############Bing: Fine Tune here##########")
+
+            if self.mode == 'train' and (self.d_losses or self.g_losses):
+                with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+                    if self.d_losses:
+                        with tf.name_scope('d_compute_gradients'):
+                            d_gradvars = d_optimizer.compute_gradients(self.d_loss, var_list=self.d_vars)
+                        with tf.name_scope('d_apply_gradients'):
+                            d_train_op = d_optimizer.apply_gradients(d_gradvars)
+
+                    else:
+                        d_train_op = tf.no_op()
+                with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []):
+                    if g_losses_post:
+                        if not self.hparams.joint_gan_optimization:
+                            replace_read_ops(g_loss_post, self.d_vars)
+                        with tf.name_scope('g_compute_gradients'):
+                            g_gradvars = g_optimizer.compute_gradients(g_loss_post, var_list=self.g_vars)
+                        with tf.name_scope('g_apply_gradients'):
+                            g_train_op = g_optimizer.apply_gradients(g_gradvars)
+                        # #######Bing; finetune########
+                        # with tf.name_scope("finetune_gradients"):
+                        #     finetune_grads_vars = finetune_vars_optimizer.compute_gradients(self.d_loss, var_list = self.finetune_vars)
+                        # with tf.name_scope("finetune_apply_gradients"):
+                        #     finetune_train_op = finetune_vars_optimizer.apply_gradients(finetune_grads_vars)
+                    else:
+                        g_train_op = tf.no_op()
+                with tf.control_dependencies([g_train_op]):
+                    train_op = tf.assign_add(global_step, 1)
+                self.train_op = train_op
+            else:
+                self.train_op = None
+
+            global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+            self.saveable_variables = [global_step] + global_variables
+            self.post_init_ops = []
+        else:
+            if tf.get_variable_scope().name:
+                # This is because of how variable scopes handle empty scope names when not
+                # at the root scope, which would cause repeated forward slashes in variable names.
+                raise NotImplementedError('Unable to handle multi-gpu model created within a non-root variable scope.')
+
+            tower_inputs = [OrderedDict() for _ in range(self.num_gpus)]
+            for name, input in self.inputs.items():
+                input_splits = tf.split(input, self.num_gpus)  # assumes batch_size is divisible by num_gpus
+                for i in range(self.num_gpus):
+                    tower_inputs[i][name] = input_splits[i]
+
+            tower_outputs_tuple = []
+            tower_d_losses = []
+            tower_g_losses = []
+            tower_g_losses_post = []
+            tower_d_loss = []
+            tower_g_loss = []
+            tower_g_loss_post = []
+            tower_metrics_tuple = []
+            for i in range(self.num_gpus):
+                worker_device = '/gpu:%d' % i
+                if self.aggregate_nccl:
+                    scope_name = '' if i == 0 else 'v%d' % i
+                    scope_reuse = False
+                    device_setter = worker_device
+                else:
+                    scope_name = ''
+                    scope_reuse = i > 0
+                    device_setter = local_device_setter(worker_device=worker_device)
+                with tf.variable_scope(scope_name, reuse=scope_reuse):
+                    with tf.device(device_setter):
+                        outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(tower_inputs[i])
+                        tower_outputs_tuple.append(outputs_tuple)
+                        d_losses, g_losses, g_losses_post = losses_tuple
+                        tower_d_losses.append(d_losses)
+                        tower_g_losses.append(g_losses)
+                        tower_g_losses_post.append(g_losses_post)
+                        d_loss, g_loss, g_loss_post = loss_tuple
+                        tower_d_loss.append(d_loss)
+                        tower_g_loss.append(g_loss)
+                        tower_g_loss_post.append(g_loss_post)
+                        tower_metrics_tuple.append(metrics_tuple)
+            self.d_vars = tf.trainable_variables(self.discriminator_scope)
+            self.g_vars = tf.trainable_variables(self.generator_scope)
+
+            if self.aggregate_nccl:
+                scope_replica = lambda scope, i: ('' if i == 0 else 'v%d/' % i) + scope
+                tower_d_vars = [tf.trainable_variables(
+                    scope_replica(self.discriminator_scope, i)) for i in range(self.num_gpus)]
+                tower_g_vars = [tf.trainable_variables(
+                    scope_replica(self.generator_scope, i)) for i in range(self.num_gpus)]
+                assert self.d_vars == tower_d_vars[0]
+                assert self.g_vars == tower_g_vars[0]
+                tower_d_optimizer = [tf.train.AdamOptimizer(
+                    self.learning_rate, self.hparams.beta1, self.hparams.beta2) for _ in range(self.num_gpus)]
+                tower_g_optimizer = [tf.train.AdamOptimizer(
+                    self.learning_rate, self.hparams.beta1, self.hparams.beta2) for _ in range(self.num_gpus)]
+
+                if self.mode == 'train' and (any(tower_d_losses) or any(tower_g_losses)):
+                    tower_d_gradvars = []
+                    tower_g_gradvars = []
+                    tower_d_train_op = []
+                    tower_g_train_op = []
+                    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+                        if any(tower_d_losses):
+                            for i in range(self.num_gpus):
+                                with tf.device('/gpu:%d' % i):
+                                    with tf.name_scope(scope_replica('d_compute_gradients', i)):
+                                        d_gradvars = tower_d_optimizer[i].compute_gradients(
+                                            tower_d_loss[i], var_list=tower_d_vars[i])
+                                        tower_d_gradvars.append(d_gradvars)
+
+                            all_d_grads, all_d_vars = tf_utils.split_grad_list(tower_d_gradvars)
+                            all_d_grads = tf_utils.allreduce_grads(all_d_grads, average=True)
+                            tower_d_gradvars = tf_utils.merge_grad_list(all_d_grads, all_d_vars)
+
+                            for i in range(self.num_gpus):
+                                with tf.device('/gpu:%d' % i):
+                                    with tf.name_scope(scope_replica('d_apply_gradients', i)):
+                                        d_train_op = tower_d_optimizer[i].apply_gradients(tower_d_gradvars[i])
+                                        tower_d_train_op.append(d_train_op)
+                            d_train_op = tf.group(*tower_d_train_op)
+                        else:
+                            d_train_op = tf.no_op()
+                    with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []):
+                        if any(tower_g_losses_post):
+                            for i in range(self.num_gpus):
+                                with tf.device('/gpu:%d' % i):
+                                    if not self.hparams.joint_gan_optimization:
+                                        replace_read_ops(tower_g_loss_post[i], tower_d_vars[i])
+
+                                    with tf.name_scope(scope_replica('g_compute_gradients', i)):
+                                        g_gradvars = tower_g_optimizer[i].compute_gradients(
+                                            tower_g_loss_post[i], var_list=tower_g_vars[i])
+                                        tower_g_gradvars.append(g_gradvars)
+
+                            all_g_grads, all_g_vars = tf_utils.split_grad_list(tower_g_gradvars)
+                            all_g_grads = tf_utils.allreduce_grads(all_g_grads, average=True)
+                            tower_g_gradvars = tf_utils.merge_grad_list(all_g_grads, all_g_vars)
+
+                            for i, g_gradvars in enumerate(tower_g_gradvars):
+                                with tf.device('/gpu:%d' % i):
+                                    with tf.name_scope(scope_replica('g_apply_gradients', i)):
+                                        g_train_op = tower_g_optimizer[i].apply_gradients(g_gradvars)
+                                        tower_g_train_op.append(g_train_op)
+                            g_train_op = tf.group(*tower_g_train_op)
+                        else:
+                            g_train_op = tf.no_op()
+                    with tf.control_dependencies([g_train_op]):
+                        train_op = tf.assign_add(global_step, 1)
+                    self.train_op = train_op
+                else:
+                    self.train_op = None
+
+                global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+                tower_saveable_vars = [[] for _ in range(self.num_gpus)]
+                for var in global_variables:
+                    m = re.match(r'v(\d+)/.*', var.name)
+                    i = int(m.group(1)) if m else 0
+                    tower_saveable_vars[i].append(var)
+                self.saveable_variables = [global_step] + tower_saveable_vars[0]
+
+                post_init_ops = []
+                for i, saveable_vars in enumerate(tower_saveable_vars[1:], 1):
+                    assert len(saveable_vars) == len(tower_saveable_vars[0])
+                    for var, var0 in zip(saveable_vars, tower_saveable_vars[0]):
+                        assert var.name == 'v%d/%s' % (i, var0.name)
+                        post_init_ops.append(var.assign(var0.read_value()))
+                self.post_init_ops = post_init_ops
+            else:  # not self.aggregate_nccl (i.e. aggregation on the CPU)
+                g_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2)
+                d_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2)
+
+                if self.mode == 'train' and (any(tower_d_losses) or any(tower_g_losses)):
+                    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+                        if any(tower_d_losses):
+                            with tf.name_scope('d_compute_gradients'):
+                                d_gradvars = compute_averaged_gradients(
+                                    d_optimizer, tower_d_loss, var_list=self.d_vars)
+                            with tf.name_scope('d_apply_gradients'):
+                                d_train_op = d_optimizer.apply_gradients(d_gradvars)
+                        else:
+                            d_train_op = tf.no_op()
+                    with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []):
+                        if any(tower_g_losses_post):
+                            for g_loss_post in tower_g_loss_post:
+                                if not self.hparams.joint_gan_optimization:
+                                    replace_read_ops(g_loss_post, self.d_vars)
+                            with tf.name_scope('g_compute_gradients'):
+                                g_gradvars = compute_averaged_gradients(
+                                    g_optimizer, tower_g_loss_post, var_list=self.g_vars)
+                            with tf.name_scope('g_apply_gradients'):
+                                g_train_op = g_optimizer.apply_gradients(g_gradvars)
+                        else:
+                            g_train_op = tf.no_op()
+                    with tf.control_dependencies([g_train_op]):
+                        train_op = tf.assign_add(global_step, 1)
+                    self.train_op = train_op
+                else:
+                    self.train_op = None
+
+                global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+                self.saveable_variables = [global_step] + global_variables
+                self.post_init_ops = []
+
+            # Device that runs the ops to apply global gradient updates.
+            consolidation_device = '/cpu:0'
+            with tf.device(consolidation_device):
+                with tf.name_scope('consolidation'):
+                    self.outputs, self.eval_outputs = reduce_tensors(tower_outputs_tuple)
+                    self.d_losses = reduce_tensors(tower_d_losses, shallow=True)
+                    self.g_losses = reduce_tensors(tower_g_losses, shallow=True)
+                    self.metrics, self.eval_metrics = reduce_tensors(tower_metrics_tuple)
+                    self.d_loss = reduce_tensors(tower_d_loss)
+                    self.g_loss = reduce_tensors(tower_g_loss)
+
+        original_local_variables = set(tf.local_variables())
+        self.accum_eval_metrics = OrderedDict()
+        for name, eval_metric in self.eval_metrics.items():
+            _, self.accum_eval_metrics['accum_' + name] = tf.metrics.mean_tensor(eval_metric)
+        local_variables = set(tf.local_variables()) - original_local_variables
+        self.accum_eval_metrics_reset_op = tf.group([tf.assign(v, tf.zeros_like(v)) for v in local_variables])
+
+        original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
+        add_summaries(self.inputs)
+        add_summaries(self.outputs)
+        add_scalar_summaries(self.d_losses)
+        add_scalar_summaries(self.g_losses)
+        add_scalar_summaries(self.metrics)
+        if self.d_losses:
+            add_scalar_summaries({'d_loss': self.d_loss})
+        if self.g_losses:
+            add_scalar_summaries({'g_loss': self.g_loss})
+        if self.d_losses and self.g_losses:
+            add_scalar_summaries({'loss': self.d_loss + self.g_loss})
+        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries
+        # split summaries into non-image summaries and image summaries
+        self.summary_op = tf.summary.merge(list(summaries - set(tf.get_collection(tf_utils.IMAGE_SUMMARIES))))
+        self.image_summary_op = tf.summary.merge(list(summaries & set(tf.get_collection(tf_utils.IMAGE_SUMMARIES))))
+
+        original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
+        add_gif_summaries(self.eval_outputs)
+        add_plot_and_scalar_summaries(
+            {name: tf.reduce_mean(metric, axis=0) for name, metric in self.eval_metrics.items()},
+            x_offset=self.hparams.context_frames + 1)
+        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries
+        self.eval_summary_op = tf.summary.merge(list(summaries))
+
+        original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
+        add_plot_and_scalar_summaries(
+            {name: tf.reduce_mean(metric, axis=0) for name, metric in self.accum_eval_metrics.items()},
+            x_offset=self.hparams.context_frames + 1)
+        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries
+        self.accum_eval_summary_op = tf.summary.merge(list(summaries))
+
+    def generator_loss_fn(self, inputs, outputs):
+        hparams = self.hparams
+        gen_losses = OrderedDict()
+        if hparams.l1_weight or hparams.l2_weight or hparams.vgg_cdist_weight:
+            gen_images = outputs.get('gen_images_enc', outputs['gen_images'])
+            target_images = inputs['images'][1:]
+        if hparams.l1_weight:
+            gen_l1_loss = vp.losses.l1_loss(gen_images, target_images)
+            gen_losses["gen_l1_loss"] = (gen_l1_loss, hparams.l1_weight)
+        if hparams.l2_weight:
+            gen_l2_loss = vp.losses.l2_loss(gen_images, target_images)
+            gen_losses["gen_l2_loss"] = (gen_l2_loss, hparams.l2_weight)
+        if hparams.vgg_cdist_weight:
+            gen_vgg_cdist_loss = vp.metrics.vgg_cosine_distance(gen_images, target_images)
+            gen_losses['gen_vgg_cdist_loss'] = (gen_vgg_cdist_loss, hparams.vgg_cdist_weight)
+        if hparams.feature_l2_weight:
+            gen_features = outputs.get('gen_features_enc', outputs['gen_features'])
+            target_features = outputs['features'][1:]
+            gen_feature_l2_loss = vp.losses.l2_loss(gen_features, target_features)
+            gen_losses["gen_feature_l2_loss"] = (gen_feature_l2_loss, hparams.feature_l2_weight)
+        if hparams.ae_l2_weight:
+            gen_images_dec = outputs.get('gen_images_dec_enc', outputs['gen_images_dec'])  # they both should be the same
+            target_images = inputs['images']
+            gen_ae_l2_loss = vp.losses.l2_loss(gen_images_dec, target_images)
+            gen_losses["gen_ae_l2_loss"] = (gen_ae_l2_loss, hparams.ae_l2_weight)
+        if hparams.state_weight:
+            gen_states = outputs.get('gen_states_enc', outputs['gen_states'])
+            target_states = inputs['states'][1:]
+            gen_state_loss = vp.losses.l2_loss(gen_states, target_states)
+            gen_losses["gen_state_loss"] = (gen_state_loss, hparams.state_weight)
+        if hparams.tv_weight:
+            gen_flows = outputs.get('gen_flows_enc', outputs['gen_flows'])
+            flow_diff1 = gen_flows[..., 1:, :, :, :] - gen_flows[..., :-1, :, :, :]
+            flow_diff2 = gen_flows[..., :, 1:, :, :] - gen_flows[..., :, :-1, :, :]
+            # sum over the multiple transformations but take the mean for the other dimensions
+            gen_tv_loss = (tf.reduce_mean(tf.reduce_sum(tf.abs(flow_diff1), axis=(-2, -1))) +
+                           tf.reduce_mean(tf.reduce_sum(tf.abs(flow_diff2), axis=(-2, -1))))
+            gen_losses['gen_tv_loss'] = (gen_tv_loss, hparams.tv_weight)
+        gan_weights = {'_image_sn': hparams.image_sn_gan_weight,
+                       '_images_sn': hparams.images_sn_gan_weight,
+                       '_video_sn': hparams.video_sn_gan_weight}
+        for infix, gan_weight in gan_weights.items():
+            if gan_weight:
+                gen_gan_loss = vp.losses.gan_loss(outputs['discrim%s_logits_fake' % infix], 1.0, hparams.gan_loss_type)
+                gen_losses["gen%s_gan_loss" % infix] = (gen_gan_loss, gan_weight)
+            if gan_weight and (hparams.gan_feature_l2_weight or hparams.gan_feature_cdist_weight):
+                i_feature = 0
+                discrim_features_fake = []
+                discrim_features_real = []
+                while True:
+                    discrim_feature_fake = outputs.get('discrim%s_feature%d_fake' % (infix, i_feature))
+                    discrim_feature_real = outputs.get('discrim%s_feature%d_real' % (infix, i_feature))
+                    if discrim_feature_fake is None or discrim_feature_real is None:
+                        break
+                    discrim_features_fake.append(discrim_feature_fake)
+                    discrim_features_real.append(discrim_feature_real)
+                    i_feature += 1
+                if hparams.gan_feature_l2_weight:
+                    gen_gan_feature_l2_loss = sum([vp.losses.l2_loss(discrim_feature_fake, discrim_feature_real)
+                                                   for discrim_feature_fake, discrim_feature_real in zip(discrim_features_fake, discrim_features_real)])
+                    gen_losses["gen%s_gan_feature_l2_loss" % infix] = (gen_gan_feature_l2_loss, hparams.gan_feature_l2_weight)
+                if hparams.gan_feature_cdist_weight:
+                    gen_gan_feature_cdist_loss = sum([vp.losses.cosine_distance(discrim_feature_fake, discrim_feature_real)
+                                                      for discrim_feature_fake, discrim_feature_real in zip(discrim_features_fake, discrim_features_real)])
+                    gen_losses["gen%s_gan_feature_cdist_loss" % infix] = (gen_gan_feature_cdist_loss, hparams.gan_feature_cdist_weight)
+        vae_gan_weights = {'_image_sn': hparams.image_sn_vae_gan_weight,
+                           '_images_sn': hparams.images_sn_vae_gan_weight,
+                           '_video_sn': hparams.video_sn_vae_gan_weight}
+        for infix, vae_gan_weight in vae_gan_weights.items():
+            if vae_gan_weight:
+                gen_vae_gan_loss = vp.losses.gan_loss(outputs['discrim%s_logits_enc_fake' % infix], 1.0, hparams.gan_loss_type)
+                gen_losses["gen%s_vae_gan_loss" % infix] = (gen_vae_gan_loss, vae_gan_weight)
+            if vae_gan_weight and (hparams.vae_gan_feature_l2_weight or hparams.vae_gan_feature_cdist_weight):
+                i_feature = 0
+                discrim_features_enc_fake = []
+                discrim_features_enc_real = []
+                while True:
+                    discrim_feature_enc_fake = outputs.get('discrim%s_feature%d_enc_fake' % (infix, i_feature))
+                    discrim_feature_enc_real = outputs.get('discrim%s_feature%d_enc_real' % (infix, i_feature))
+                    if discrim_feature_enc_fake is None or discrim_feature_enc_real is None:
+                        break
+                    discrim_features_enc_fake.append(discrim_feature_enc_fake)
+                    discrim_features_enc_real.append(discrim_feature_enc_real)
+                    i_feature += 1
+                if hparams.vae_gan_feature_l2_weight:
+                    gen_vae_gan_feature_l2_loss = sum([vp.losses.l2_loss(discrim_feature_enc_fake, discrim_feature_enc_real)
+                                                       for discrim_feature_enc_fake, discrim_feature_enc_real in zip(discrim_features_enc_fake, discrim_features_enc_real)])
+                    gen_losses["gen%s_vae_gan_feature_l2_loss" % infix] = (gen_vae_gan_feature_l2_loss, hparams.vae_gan_feature_l2_weight)
+                if hparams.vae_gan_feature_cdist_weight:
+                    gen_vae_gan_feature_cdist_loss = sum([vp.losses.cosine_distance(discrim_feature_enc_fake, discrim_feature_enc_real)
+                                                          for discrim_feature_enc_fake, discrim_feature_enc_real in zip(discrim_features_enc_fake, discrim_features_enc_real)])
+                    gen_losses["gen%s_vae_gan_feature_cdist_loss" % infix] = (gen_vae_gan_feature_cdist_loss, hparams.vae_gan_feature_cdist_weight)
+        if hparams.kl_weight:
+            gen_kl_loss = vp.losses.kl_loss(outputs['zs_mu_enc'], outputs['zs_log_sigma_sq_enc'],
+                                            outputs.get('zs_mu_prior'), outputs.get('zs_log_sigma_sq_prior'))
+            gen_losses["gen_kl_loss"] = (gen_kl_loss, self.kl_weight)  # possibly annealed kl_weight
+        return gen_losses
+
+    def discriminator_loss_fn(self, inputs, outputs):
+        hparams = self.hparams
+        discrim_losses = OrderedDict()
+        gan_weights = {'_image_sn': hparams.image_sn_gan_weight,
+                       '_images_sn': hparams.images_sn_gan_weight,
+                       '_video_sn': hparams.video_sn_gan_weight}
+        for infix, gan_weight in gan_weights.items():
+            if gan_weight:
+                discrim_gan_loss_real = vp.losses.gan_loss(outputs['discrim%s_logits_real' % infix], 1.0, hparams.gan_loss_type)
+                discrim_gan_loss_fake = vp.losses.gan_loss(outputs['discrim%s_logits_fake' % infix], 0.0, hparams.gan_loss_type)
+                discrim_gan_loss = discrim_gan_loss_real + discrim_gan_loss_fake
+                discrim_losses["discrim%s_gan_loss" % infix] = (discrim_gan_loss, gan_weight)
+        vae_gan_weights = {'_image_sn': hparams.image_sn_vae_gan_weight,
+                           '_images_sn': hparams.images_sn_vae_gan_weight,
+                           '_video_sn': hparams.video_sn_vae_gan_weight}
+        for infix, vae_gan_weight in vae_gan_weights.items():
+            if vae_gan_weight:
+                discrim_vae_gan_loss_real = vp.losses.gan_loss(outputs['discrim%s_logits_enc_real' % infix], 1.0, hparams.gan_loss_type)
+                discrim_vae_gan_loss_fake = vp.losses.gan_loss(outputs['discrim%s_logits_enc_fake' % infix], 0.0, hparams.gan_loss_type)
+                discrim_vae_gan_loss = discrim_vae_gan_loss_real + discrim_vae_gan_loss_fake
+                discrim_losses["discrim%s_vae_gan_loss" % infix] = (discrim_vae_gan_loss, vae_gan_weight)
+        return discrim_losses
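+
+
+# Illustrative sketch (not part of the original code): the loss dicts returned by
+# `generator_loss_fn` and `discriminator_loss_fn` map a loss name to a
+# (loss_tensor, weight) pair, and `tower_fn` reduces each dict to a scalar with
+# tf.accumulate_n. The hypothetical helper below mirrors that reduction.
+def _example_weighted_total_loss(losses):
+    """Return the weighted sum of a {name: (loss, weight)} dict, or 0 if empty."""
+    if not losses:
+        return tf.zeros(())
+    return tf.add_n([loss * weight for loss, weight in losses.values()])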
diff --git a/video_prediction/models/dna_model.py b/video_prediction/models/dna_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4fa8b97bc523382adfa14f564aa30193920ed48
--- /dev/null
+++ b/video_prediction/models/dna_model.py
@@ -0,0 +1,476 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Model architecture for predictive model, including CDNA, DNA, and STP."""
+
+import itertools
+
+import numpy as np
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+from tensorflow.contrib.layers.python import layers as tf_layers
+
+from video_prediction.models import VideoPredictionModel
+from .sna_model import basic_conv_lstm_cell
+
+
+# Amount to use when lower bounding tensors
+RELU_SHIFT = 1e-12
+
+
+def construct_model(images,
+                    actions=None,
+                    states=None,
+                    iter_num=-1.0,
+                    kernel_size=(5, 5),
+                    k=-1,
+                    use_state=True,
+                    num_masks=10,
+                    stp=False,
+                    cdna=True,
+                    dna=False,
+                    context_frames=2,
+                    pix_distributions=None):
+    """Build convolutional lstm video predictor using STP, CDNA, or DNA.
+
+    Args:
+        images: tensor of ground truth image sequences
+        actions: tensor of action sequences
+        states: tensor of ground truth state sequences
+        iter_num: tensor of the current training iteration (for scheduled sampling)
+        kernel_size: (height, width) of the DNA/CDNA kernels.
+        k: constant used for scheduled sampling. -1 to feed in own prediction.
+        use_state: True to include state and action in prediction
+        num_masks: the number of different pixel motion predictions (and
+                   the number of masks for each of those predictions)
+        stp: True to use Spatial Transformer Predictor (STP)
+        cdna: True to use Convolutional Dynamic Neural Advection (CDNA)
+        dna: True to use Dynamic Neural Advection (DNA)
+        context_frames: number of ground truth frames to pass in before
+                        feeding in own predictions
+        pix_distributions: optional tensor of pixel distributions to be advected
+                           with the same predicted transformations as the images.
+    Returns:
+        gen_images: predicted future image frames
+        gen_states: predicted future states
+
+    Raises:
+        ValueError: if more than one network option is specified, or if more than
+        one mask is specified for the DNA model.
+    """
+    DNA_KERN_SIZE = kernel_size[0]
+
+    if stp + cdna + dna != 1:
+        raise ValueError('More than one, or no network option specified.')
+    batch_size, img_height, img_width, color_channels = images[0].get_shape()[0:4]
+    lstm_func = basic_conv_lstm_cell
+
+    # Generated robot states and images.
+    gen_states, gen_images = [], []
+    gen_pix_distrib = []
+    gen_masks = []
+    current_state = states[0]
+
+    if k == -1:
+        feedself = True
+    else:
+        # Scheduled sampling:
+        # Calculate number of ground-truth frames to pass in.
+        num_ground_truth = tf.to_int32(
+            tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k)))))
+        feedself = False
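+        # e.g. (not from the original code) with k=900 and batch_size=16, almost the
+        # whole batch is ground truth at iter_num=0, and the ground-truth fraction
+        # decays towards 0 as iter_num grows (inverse-sigmoid decay).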
+
+    # LSTM state sizes and states.
+    lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32]))
+    lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None
+    lstm_state5, lstm_state6, lstm_state7 = None, None, None
+
+    for t, action in enumerate(actions):
+        # Reuse variables after the first timestep.
+        reuse = bool(gen_images)
+
+        done_warm_start = len(gen_images) > context_frames - 1
+        with slim.arg_scope(
+                [lstm_func, slim.layers.conv2d, slim.layers.fully_connected,
+                 tf_layers.layer_norm, slim.layers.conv2d_transpose],
+                reuse=reuse):
+
+            if feedself and done_warm_start:
+                # Feed in generated image.
+                prev_image = gen_images[-1]
+                if pix_distributions is not None:
+                    prev_pix_distrib = gen_pix_distrib[-1]
+            elif done_warm_start:
+                # Scheduled sampling
+                prev_image = scheduled_sample(images[t], gen_images[-1], batch_size,
+                                              num_ground_truth)
+            else:
+                # Always feed in ground_truth
+                prev_image = images[t]
+                if pix_distributions is not None:
+                    prev_pix_distrib = pix_distributions[t]
+                    # prev_pix_distrib = tf.expand_dims(prev_pix_distrib, -1)
+
+            # Predicted state is always fed back in
+            state_action = tf.concat(axis=1, values=[action, current_state])
+
+            enc0 = slim.layers.conv2d(
+                prev_image,
+                32, [5, 5],
+                stride=2,
+                scope='scale1_conv1',
+                normalizer_fn=tf_layers.layer_norm,
+                normalizer_params={'scope': 'layer_norm1'})
+
+            hidden1, lstm_state1 = lstm_func(
+                enc0, lstm_state1, lstm_size[0], scope='state1')
+            hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2')
+            hidden2, lstm_state2 = lstm_func(
+                hidden1, lstm_state2, lstm_size[1], scope='state2')
+            hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3')
+            enc1 = slim.layers.conv2d(
+                hidden2, hidden2.get_shape()[3], [3, 3], stride=2, scope='conv2')
+
+            hidden3, lstm_state3 = lstm_func(
+                enc1, lstm_state3, lstm_size[2], scope='state3')
+            hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4')
+            hidden4, lstm_state4 = lstm_func(
+                hidden3, lstm_state4, lstm_size[3], scope='state4')
+            hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5')
+            enc2 = slim.layers.conv2d(
+                hidden4, hidden4.get_shape()[3], [3, 3], stride=2, scope='conv3')
+
+            # Pass in state and action.
+            smear = tf.reshape(
+                state_action,
+                [int(batch_size), 1, 1, int(state_action.get_shape()[1])])
+            smear = tf.tile(
+                smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1])
+            if use_state:
+                enc2 = tf.concat(axis=3, values=[enc2, smear])
+            enc3 = slim.layers.conv2d(
+                enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4')
+
+            hidden5, lstm_state5 = lstm_func(
+                enc3, lstm_state5, lstm_size[4], scope='state5')  # last 8x8
+            hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6')
+            enc4 = slim.layers.conv2d_transpose(
+                hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1')
+
+            hidden6, lstm_state6 = lstm_func(
+                enc4, lstm_state6, lstm_size[5], scope='state6')  # 16x16
+            hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7')
+            # Skip connection.
+            hidden6 = tf.concat(axis=3, values=[hidden6, enc1])  # both 16x16
+
+            enc5 = slim.layers.conv2d_transpose(
+                hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2')
+            hidden7, lstm_state7 = lstm_func(
+                enc5, lstm_state7, lstm_size[6], scope='state7')  # 32x32
+            hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8')
+
+            # Skip connection.
+            hidden7 = tf.concat(axis=3, values=[hidden7, enc0])  # both 32x32
+
+            enc6 = slim.layers.conv2d_transpose(
+                hidden7,
+                hidden7.get_shape()[3], 3, stride=2, scope='convt3',
+                normalizer_fn=tf_layers.layer_norm,
+                normalizer_params={'scope': 'layer_norm9'})
+
+            if dna:
+                # Using largest hidden state for predicting untied conv kernels.
+                enc7 = slim.layers.conv2d_transpose(
+                    enc6, DNA_KERN_SIZE ** 2, 1, stride=1, scope='convt4')
+            else:
+                # Using largest hidden state for predicting a new image layer.
+                enc7 = slim.layers.conv2d_transpose(
+                    enc6, color_channels, 1, stride=1, scope='convt4')
+                # This allows the network to also generate one image from scratch,
+                # which is useful when regions of the image become unoccluded.
+                transformed = [tf.nn.sigmoid(enc7)]
+
+            if stp:
+                stp_input0 = tf.reshape(hidden5, [int(batch_size), -1])
+                stp_input1 = slim.layers.fully_connected(
+                    stp_input0, 100, scope='fc_stp')
+
+                # disabling capability to generate pixels
+                reuse_stp = None
+                if reuse:
+                    reuse_stp = reuse
+                transformed = stp_transformation(prev_image, stp_input1, num_masks, reuse_stp)
+                # transformed += stp_transformation(prev_image, stp_input1, num_masks)
+
+                if pix_distributions is not None:
+                    transf_distrib = stp_transformation(prev_pix_distrib, stp_input1, num_masks, reuse=True)
+
+            elif cdna:
+                cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
+
+                new_transformed, cdna_kerns = cdna_transformation(prev_image,
+                                                                  cdna_input,
+                                                                  num_masks,
+                                                                  int(color_channels),
+                                                                  kernel_size,
+                                                                  reuse_sc=reuse)
+                transformed += new_transformed
+
+                if pix_distributions is not None:
+                    if not dna:
+                        transf_distrib = [prev_pix_distrib]
+                    new_transf_distrib, _ = cdna_transformation(prev_pix_distrib,
+                                                                cdna_input,
+                                                                num_masks,
+                                                                prev_pix_distrib.shape[-1].value,
+                                                                kernel_size,
+                                                                reuse_sc=True)
+                    transf_distrib += new_transf_distrib
+
+            elif dna:
+                # Only one mask is supported (more should be unnecessary).
+                if num_masks != 1:
+                    raise ValueError('Only one mask is supported for DNA model.')
+                transformed = [dna_transformation(prev_image, enc7, DNA_KERN_SIZE)]
+
+            masks = slim.layers.conv2d_transpose(
+                enc6, num_masks + 1, 1, stride=1, scope='convt7')
+            masks = tf.reshape(
+                tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
+                [int(batch_size), int(img_height), int(img_width), num_masks + 1])
+            mask_list = tf.split(masks, num_masks + 1, axis=3)
+            output = mask_list[0] * prev_image
+            for layer, mask in zip(transformed, mask_list[1:]):
+                output += layer * mask
+            gen_images.append(output)
+            gen_masks.append(mask_list)
+
+            if dna and pix_distributions is not None:
+                transf_distrib = [dna_transformation(prev_pix_distrib, enc7, DNA_KERN_SIZE)]
+
+            if pix_distributions is not None:
+                pix_distrib_output = mask_list[0] * prev_pix_distrib
+                for layer, mask in zip(transf_distrib, mask_list[1:]):
+                    pix_distrib_output += layer * mask
+                pix_distrib_output /= tf.reduce_sum(pix_distrib_output, axis=(1, 2), keepdims=True)
+                gen_pix_distrib.append(pix_distrib_output)
+
+            if int(current_state.get_shape()[1]) == 0:
+                current_state = tf.zeros_like(state_action)
+            else:
+                current_state = slim.layers.fully_connected(
+                    state_action,
+                    int(current_state.get_shape()[1]),
+                    scope='state_pred',
+                    activation_fn=None)
+            gen_states.append(current_state)
+
+    return gen_images, gen_states, gen_masks, gen_pix_distrib
+
+
+## Utility functions
+def stp_transformation(prev_image, stp_input, num_masks, reuse=None):
+    """Apply spatial transformer predictor (STP) to previous image.
+
+    Args:
+        prev_image: previous image to be transformed.
+        stp_input: hidden layer to be used for computing STN parameters.
+        num_masks: number of masks and hence the number of STP transformations.
+        reuse: whether to reuse the variables of the parameter-predicting layer.
+    Returns:
+        List of images transformed by the predicted STP parameters.
+    """
+    # Only import spatial transformer if needed.
+    from spatial_transformer import transformer
+
+    identity_params = tf.convert_to_tensor(
+        np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32))
+    transformed = []
+    for i in range(num_masks - 1):
+        params = slim.layers.fully_connected(
+            stp_input, 6, scope='stp_params' + str(i),
+            activation_fn=None, reuse=reuse) + identity_params
+        transformed.append(transformer(prev_image, params))
+
+    return transformed
+
+
+def cdna_transformation(prev_image, cdna_input, num_masks, color_channels, kernel_size, reuse_sc=None):
+    """Apply convolutional dynamic neural advection to previous image.
+
+    Args:
+        prev_image: previous image to be transformed.
+        cdna_input: hidden layer to be used for computing CDNA kernels.
+        num_masks: the number of masks and hence the number of CDNA transformations.
+        color_channels: the number of color channels in the images.
+        kernel_size: (height, width) of the CDNA kernels.
+        reuse_sc: whether to reuse the variables of the kernel-predicting layer.
+    Returns:
+        List of images transformed by the predicted CDNA kernels.
+    """
+    batch_size = int(cdna_input.get_shape()[0])
+    height = int(prev_image.get_shape()[1])
+    width = int(prev_image.get_shape()[2])
+
+    # Predict kernels using linear function of last hidden layer.
+    cdna_kerns = slim.layers.fully_connected(
+        cdna_input,
+        kernel_size[0] * kernel_size[1] * num_masks,
+        scope='cdna_params',
+        activation_fn=None,
+        reuse=reuse_sc)
+
+    # Reshape and normalize.
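+    # The shifted ReLU keeps every kernel entry at least RELU_SHIFT (strictly positive),
+    # and dividing by the per-kernel sum normalizes each kernel to sum to one.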
+    cdna_kerns = tf.reshape(
+        cdna_kerns, [batch_size, kernel_size[0], kernel_size[1], 1, num_masks])
+    cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
+    norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True)
+    cdna_kerns /= norm_factor
+
+    # Treat the color channel dimension as the batch dimension since the same
+    # transformation is applied to each color channel.
+    # Treat the batch dimension as the channel dimension so that
+    # depthwise_conv2d can apply a different transformation to each sample.
+    cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3])
+    cdna_kerns = tf.reshape(cdna_kerns, [kernel_size[0], kernel_size[1], batch_size, num_masks])
+    # Swap the batch and channel dimensions.
+    prev_image = tf.transpose(prev_image, [3, 1, 2, 0])
+
+    # Transform image.
+    transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME')
+
+    # Transpose the dimensions to where they belong.
+    transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks])
+    transformed = tf.transpose(transformed, [3, 1, 2, 0, 4])
+    transformed = tf.unstack(transformed, axis=-1)
+    return transformed, cdna_kerns
+
+
+def dna_transformation(prev_image, dna_input, kernel_size):
+    """Apply dynamic neural advection to previous image.
+
+    Args:
+        prev_image: previous image to be transformed.
+        dna_input: hidden layer to be used for computing the DNA transformation.
+        kernel_size: (height, width) of the DNA kernels.
+    Returns:
+        Image transformed by the predicted DNA kernels.
+    """
+    # Construct translated images.
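+    # Each (xkern, ykern) offset below extracts a shifted copy of the padded image;
+    # the predicted per-pixel kernel then forms a weighted sum over these shifted copies.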
+    pad_along_height = (kernel_size[0] - 1)
+    pad_along_width = (kernel_size[1] - 1)
+    pad_top = pad_along_height // 2
+    pad_bottom = pad_along_height - pad_top
+    pad_left = pad_along_width // 2
+    pad_right = pad_along_width - pad_left
+    prev_image_pad = tf.pad(prev_image, [[0, 0],
+                                         [pad_top, pad_bottom],
+                                         [pad_left, pad_right],
+                                         [0, 0]])
+    image_height = int(prev_image.get_shape()[1])
+    image_width = int(prev_image.get_shape()[2])
+
+    inputs = []
+    for xkern in range(kernel_size[0]):
+        for ykern in range(kernel_size[1]):
+            inputs.append(
+                tf.expand_dims(
+                    tf.slice(prev_image_pad, [0, xkern, ykern, 0],
+                             [-1, image_height, image_width, -1]), [3]))
+    inputs = tf.concat(axis=3, values=inputs)
+
+    # Normalize channels to 1.
+    kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
+    kernel = tf.expand_dims(
+        kernel / tf.reduce_sum(
+            kernel, [3], keepdims=True), [4])
+    return tf.reduce_sum(kernel * inputs, [3], keepdims=False)
+
+
+def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
+    """Sample batch with specified mix of ground truth and generated data points.
+
+    Args:
+        ground_truth_x: tensor of ground-truth data points.
+        generated_x: tensor of generated data points.
+        batch_size: batch size
+        num_ground_truth: number of ground-truth examples to include in batch.
+    Returns:
+        New batch with num_ground_truth sampled from ground_truth_x and the rest
+        from generated_x.
+    """
+    idx = tf.random_shuffle(tf.range(int(batch_size)))
+    ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
+    generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size)))
+
+    ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
+    generated_examps = tf.gather(generated_x, generated_idx)
+    return tf.dynamic_stitch([ground_truth_idx, generated_idx],
+                             [ground_truth_examps, generated_examps])
+
+
+def generator_fn(inputs, hparams=None):
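+    # Inputs are time-major; unstacking along axis 0 yields one tensor per time step.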
+    images = tf.unstack(inputs['images'], axis=0)
+    actions = tf.unstack(inputs['actions'], axis=0)
+    states = tf.unstack(inputs['states'], axis=0)
+    pix_distributions = tf.unstack(inputs['pix_distribs'], axis=0) if 'pix_distribs' in inputs else None
+    iter_num = tf.to_float(tf.train.get_or_create_global_step())
+
+    gen_images, gen_states, gen_masks, gen_pix_distrib = \
+        construct_model(images,
+                        actions,
+                        states,
+                        iter_num=iter_num,
+                        kernel_size=hparams.kernel_size,
+                        k=hparams.schedule_sampling_k,
+                        num_masks=hparams.num_masks,
+                        cdna=hparams.transformation == 'cdna',
+                        dna=hparams.transformation == 'dna',
+                        stp=hparams.transformation == 'stp',
+                        context_frames=hparams.context_frames,
+                        pix_distributions=pix_distributions)
+    outputs = {
+        'gen_images': tf.stack(gen_images, axis=0),
+        'gen_states': tf.stack(gen_states, axis=0),
+        'masks': tf.stack([tf.stack(gen_mask_list, axis=-1) for gen_mask_list in gen_masks], axis=0),
+    }
+    if 'pix_distribs' in inputs:
+        outputs['gen_pix_distribs'] = tf.stack(gen_pix_distrib, axis=0)
+    gen_images = outputs['gen_images'][hparams.context_frames - 1:]
+    return gen_images, outputs
+
+
+class DNAVideoPredictionModel(VideoPredictionModel):
+    def __init__(self, *args, **kwargs):
+        super(DNAVideoPredictionModel, self).__init__(
+            generator_fn, *args, **kwargs)
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(DNAVideoPredictionModel, self).get_default_hparams_dict()
+        hparams = dict(
+            batch_size=32,
+            l1_weight=0.0,
+            l2_weight=1.0,
+            transformation='cdna',
+            kernel_size=(9, 9),
+            num_masks=10,
+            schedule_sampling_k=900.0,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def parse_hparams(self, hparams_dict, hparams):
+        hparams = super(DNAVideoPredictionModel, self).parse_hparams(hparams_dict, hparams)
+        if self.mode == 'test':
+            def override_hparams_maybe(name, value):
+                orig_value = hparams.values()[name]
+                if orig_value != value:
+                    print('Overriding hparams from %s=%r to %r for mode=%s.' %
+                          (name, orig_value, value, self.mode))
+                    hparams.set_hparam(name, value)
+            override_hparams_maybe('schedule_sampling_k', -1)
+        return hparams
diff --git a/video_prediction/models/mcnet_model.py b/video_prediction/models/mcnet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..725ce4f46a301b6aa07f3d50ef811584d5b502db
--- /dev/null
+++ b/video_prediction/models/mcnet_model.py
@@ -0,0 +1,467 @@
+import collections
+import functools
+import itertools
+from collections import OrderedDict
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.util import nest
+from video_prediction import ops, flow_ops
+from video_prediction.models import BaseVideoPredictionModel
+from video_prediction.models import networks
+from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from video_prediction.utils import tf_utils
+from datetime import datetime
+from pathlib import Path
+from video_prediction.layers import layer_def as ld
+from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
+from video_prediction.layers.mcnet_ops import *
+from video_prediction.utils.mcnet_utils import *
+import os
+
+class McNetVideoPredictionModel(BaseVideoPredictionModel):
+    def __init__(self, mode='train', hparams_dict=None,
+                 hparams=None, **kwargs):
+        super(McNetVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
+        self.mode = mode
+        self.lr = self.hparams.lr
+        self.context_frames = self.hparams.context_frames
+        self.sequence_length = self.hparams.sequence_length
+        self.predict_frames = self.sequence_length - self.context_frames
+        self.df_dim = self.hparams.df_dim
+        self.gf_dim = self.hparams.gf_dim
+        self.alpha = self.hparams.alpha
+        self.beta = self.hparams.beta
+        self.gen_images_enc = None
+        self.recon_loss = None
+        self.latent_loss = None
+        self.total_loss = None
+
+    def get_default_hparams_dict(self):
+        """
+        The keys of this dict define valid hyperparameters for instances of
+        this class. A class inheriting from this one should override this
+        method if it has a different set of hyperparameters.
+
+        Returns:
+            A dict with the following hyperparameters.
+
+            batch_size: batch size for training.
+            lr: learning rate. if decay steps is non-zero, this is the
+                learning rate for steps <= decay_step.
+            max_steps: number of training steps.
+            context_frames: the number of ground-truth frames to pass in at
+                start. Must be specified during instantiation.
+            sequence_length: the number of frames in the video sequence,
+                including the context frames, so this model predicts
+                `sequence_length - context_frames` future frames. Must be
+                specified during instantiation.
+            df_dim: number of discriminator filters in the first conv layer (MCnet-specific).
+            gf_dim: number of generator filters in the first conv layer (MCnet-specific).
+            alpha: weight of the image (reconstruction + GDL) loss in the total loss (MCnet-specific).
+            beta: weight of the adversarial (GAN) loss in the total loss (MCnet-specific).
+
+        """
+        default_hparams = super(McNetVideoPredictionModel, self).get_default_hparams_dict()
+        hparams = dict(
+            batch_size=16,
+            lr=0.001,
+            max_steps=350000,
+            context_frames = 10,
+            sequence_length = 20,
+            nz = 16,
+            gf_dim = 64,
+            df_dim = 64,
+            alpha = 1,
+            beta = 0.0
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def build_graph(self, x):
+
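+        # x["images"] has shape [batch, time, height, width, channels].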
+        self.x = x["images"]
+        self.x_shape = self.x.get_shape().as_list()
+        self.batch_size = self.x_shape[0]
+        self.image_size = [self.x_shape[2],self.x_shape[3]]
+        self.c_dim = self.x_shape[4]
+        self.diff_shape = [self.batch_size, self.context_frames-1, self.image_size[0],
+                           self.image_size[1], self.c_dim]
+        self.xt_shape = [self.batch_size, self.image_size[0], self.image_size[1],self.c_dim]
+        self.is_train = True
+       
+
+        self.global_step = tf.Variable(0, name='global_step', trainable=False)
+        original_global_variables = tf.global_variables()
+
+        # self.xt = tf.placeholder(tf.float32, self.xt_shape, name='xt')
+        self.xt = self.x[:, self.context_frames - 1, :, :, :]
+
+        # Differences between consecutive context frames form the motion-encoder input.
+        diff_in_all = []
+        for t in range(1, self.context_frames):
+            prev_frame = self.x[:, t-1:t, :, :, :]
+            next_frame = self.x[:, t:t+1, :, :, :]  # avoid shadowing the built-in next()
+            diff_in = tf.subtract(next_frame, prev_frame)
+            diff_in_all.append(diff_in)
+
+        self.diff_in = tf.concat(axis = 1, values = diff_in_all)
+
+        cell = BasicConvLSTMCell([self.image_size[0] // 8, self.image_size[1] // 8], [3, 3], 256)
+
+        pred = self.forward(self.diff_in, self.xt, cell)
+
+
+        self.G = tf.concat(axis=1, values=pred)  # [batch_size, predict_frames, height, width, channels]
+        print ("1:self.G:",self.G)
+        if self.is_train:
+
+            true_sim = self.x[:, self.context_frames:, :, :, :]
+
+            # Bing: make sure the channel dimension is 3; a single-channel input is tiled to three channels.
+            if self.c_dim == 1: true_sim = tf.tile(true_sim, [1, 1, 1, 1, 3])
+
+            # Bing: the raw input shape is [batch_size, image_size[0], image_size[1], num_seq, channels];
+            # tf.transpose would reshape it into [batch_size * num_seq, image_size[0], image_size[1], channels].
+            # For our ERA5 case the transpose is not needed.
+            # true_sim = tf.reshape(tf.transpose(true_sim,[0,3,1,2,4]),
+            #                             [-1, self.image_size[0],
+            #                              self.image_size[1], 3])
+            true_sim = tf.reshape(true_sim, [-1, self.image_size[0], self.image_size[1], 3])
+
+
+
+
+        gen_sim = self.G
+        
+        # Combine ground-truth and predicted frames.
+        self.x_hat = tf.concat([self.x[:, :self.context_frames, :, :, :], self.G], 1)
+        print ("self.x_hat:",self.x_hat)
+        if self.c_dim == 1: gen_sim = tf.tile(gen_sim, [1, 1, 1, 1, 3])
+        # gen_sim = tf.reshape(tf.transpose(gen_sim,[0,3,1,2,4]),
+        #                                [-1, self.image_size[0],
+        #                                self.image_size[1], 3])
+
+        gen_sim = tf.reshape(gen_sim, [-1, self.image_size[0], self.image_size[1], 3])
+
+
+        binput = tf.reshape(tf.transpose(self.x[:, :self.context_frames, :, :, :], [0, 1, 2, 3, 4]),
+                            [self.batch_size, self.image_size[0],
+                             self.image_size[1], -1])
+
+        btarget = tf.reshape(tf.transpose(self.x[:, self.context_frames:, :, :, :], [0, 1, 2, 3, 4]),
+                             [self.batch_size, self.image_size[0],
+                              self.image_size[1], -1])
+        bgen = tf.reshape(self.G, [self.batch_size,
+                                   self.image_size[0],
+                                   self.image_size[1], -1])
+
+        print ("binput:",binput)
+        print("btarget:",btarget)
+        print("bgen:",bgen)
+
+        good_data = tf.concat(axis=3, values=[binput, btarget])
+        gen_data = tf.concat(axis=3, values=[binput, bgen])
+        self.gen_data = gen_data
+        print ("2:self.gen_data:", self.gen_data)
+        with tf.variable_scope("DIS", reuse=False):
+            self.D, self.D_logits = self.discriminator(good_data)
+
+        with tf.variable_scope("DIS", reuse=True):
+            self.D_, self.D_logits_ = self.discriminator(gen_data)
+
+        self.L_p = tf.reduce_mean(
+            tf.square(self.G - self.x[:, self.context_frames:, :, :, :]))
+
+        self.L_gdl = gdl(gen_sim, true_sim, 1.)
+        self.L_img = self.L_p + self.L_gdl
+
+        self.d_loss_real = tf.reduce_mean(
+            tf.nn.sigmoid_cross_entropy_with_logits(
+                logits = self.D_logits, labels = tf.ones_like(self.D)
+            ))
+        self.d_loss_fake = tf.reduce_mean(
+            tf.nn.sigmoid_cross_entropy_with_logits(
+                logits = self.D_logits_, labels = tf.zeros_like(self.D_)
+            ))
+        self.d_loss = self.d_loss_real + self.d_loss_fake
+        self.L_GAN = tf.reduce_mean(
+            tf.nn.sigmoid_cross_entropy_with_logits(
+                logits = self.D_logits_, labels = tf.ones_like(self.D_)
+            ))
+
+        self.loss_sum = tf.summary.scalar("L_img", self.L_img)
+        self.L_p_sum = tf.summary.scalar("L_p", self.L_p)
+        self.L_gdl_sum = tf.summary.scalar("L_gdl", self.L_gdl)
+        self.L_GAN_sum = tf.summary.scalar("L_GAN", self.L_GAN)
+        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
+        self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
+        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
+
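+        # Total generator objective: the image loss (L2 + gradient difference) is weighted
+        # by alpha and the adversarial loss by beta.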
+        self.total_loss = self.alpha * self.L_img + self.beta * self.L_GAN
+        self._loss_sum = tf.summary.scalar("total_loss", self.total_loss)
+        self.g_sum = tf.summary.merge([self.L_p_sum,
+                                       self.L_gdl_sum, self.loss_sum,
+                                       self.L_GAN_sum])
+        self.d_sum = tf.summary.merge([self.d_loss_real_sum, self.d_loss_sum,
+                                       self.d_loss_fake_sum])
+
+
+        self.t_vars = tf.trainable_variables()
+        self.g_vars = [var for var in self.t_vars if 'DIS' not in var.name]
+        self.d_vars = [var for var in self.t_vars if 'DIS' in var.name]
+        num_param = 0.0
+        for var in self.g_vars:
+            num_param += int(np.prod(var.get_shape()))
+        print("Number of parameters: %d" % num_param)
+
+        # Training
+        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5).minimize(
+            self.d_loss, var_list = self.d_vars)
+        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5).minimize(
+            self.alpha * self.L_img + self.beta * self.L_GAN, var_list = self.g_vars, global_step=self.global_step)
+       
+        self.train_op = [self.d_optim,self.g_optim]
+        self.outputs = {}
+        self.outputs["gen_images"] = self.x_hat
+        
+
+        self.summary_op = tf.summary.merge_all()
+        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+        self.saveable_variables = [self.global_step] + global_variables
+        return 
+
+
+    def forward(self, diff_in, xt, cell):
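+        # Encode motion from the context-frame differences with the ConvLSTM, then decode
+        # future frames autoregressively, feeding each prediction back as the next input.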
+        # Initial state
+        state = tf.zeros([self.batch_size, self.image_size[0] // 8,
+                          self.image_size[1] // 8, 512])
+        reuse = False
+        # Encoder
+        for t in range(self.context_frames - 1):
+            enc_h, res_m = self.motion_enc(diff_in[:, t, :, :, :], reuse = reuse)
+            h_dyn, state = cell(enc_h, state, scope = 'lstm', reuse = reuse)
+            reuse = True
+        pred = []
+        # Decoder
+        for t in range(self.predict_frames):
+            if t == 0:
+                h_cont, res_c = self.content_enc(xt, reuse = False)
+                h_tp1 = self.comb_layers(h_dyn, h_cont, reuse = False)
+                res_connect = self.residual(res_m, res_c, reuse = False)
+                x_hat = self.dec_cnn(h_tp1, res_connect, reuse = False)
+
+            else:
+
+                enc_h, res_m = self.motion_enc(diff_in, reuse = True)
+                h_dyn, state = cell(enc_h, state, scope = 'lstm', reuse = True)
+                h_cont, res_c = self.content_enc(xt, reuse = reuse)
+                h_tp1 = self.comb_layers(h_dyn, h_cont, reuse = True)
+                res_connect = self.residual(res_m, res_c, reuse = True)
+                x_hat = self.dec_cnn(h_tp1, res_connect, reuse = True)
+                print ("x_hat :",x_hat)
+            if self.c_dim == 3:
+                # Network outputs are BGR so they need to be reversed to use
+                # rgb_to_grayscale
+                #x_hat_gray = tf.concat(axis=3,values=[x_hat[:,:,:,2:3], x_hat[:,:,:,1:2],x_hat[:,:,:,0:1]])
+                #xt_gray = tf.concat(axis=3,values=[xt[:,:,:,2:3], xt[:,:,:,1:2],xt[:,:,:,0:1]])
+
+                #                 x_hat_gray = 1./255.*tf.image.rgb_to_grayscale(
+                #                     inverse_transform(x_hat_rgb)*255.
+                #                 )
+                #                 xt_gray = 1./255.*tf.image.rgb_to_grayscale(
+                #                     inverse_transform(xt_rgb)*255.
+                #                 )
+
+                x_hat_gray = x_hat
+                xt_gray = xt
+            else:
+                x_hat_gray = inverse_transform(x_hat)
+                xt_gray = inverse_transform(xt)
+
+            diff_in = x_hat_gray - xt_gray
+            xt = x_hat
+
+
+            pred.append(tf.reshape(x_hat, [self.batch_size, 1, self.image_size[0],
+                                           self.image_size[1], self.c_dim]))
+
+        return pred
+
+    def motion_enc(self, diff_in, reuse):
+        res_in = []
+
+        conv1 = relu(conv2d(diff_in, output_dim = self.gf_dim, k_h = 5, k_w = 5,
+                            d_h = 1, d_w = 1, name = 'dyn1_conv1', reuse = reuse))
+        res_in.append(conv1)
+        pool1 = MaxPooling(conv1, [2, 2])
+
+        conv2 = relu(conv2d(pool1, output_dim = self.gf_dim * 2, k_h = 5, k_w = 5,
+                            d_h = 1, d_w = 1, name = 'dyn_conv2', reuse = reuse))
+        res_in.append(conv2)
+        pool2 = MaxPooling(conv2, [2, 2])
+
+        conv3 = relu(conv2d(pool2, output_dim = self.gf_dim * 4, k_h = 7, k_w = 7,
+                            d_h = 1, d_w = 1, name = 'dyn_conv3', reuse = reuse))
+        res_in.append(conv3)
+        pool3 = MaxPooling(conv3, [2, 2])
+        return pool3, res_in
+
+    def content_enc(self, xt, reuse):
+        res_in = []
+        conv1_1 = relu(conv2d(xt, output_dim = self.gf_dim, k_h = 3, k_w = 3,
+                              d_h = 1, d_w = 1, name = 'cont_conv1_1', reuse = reuse))
+        conv1_2 = relu(conv2d(conv1_1, output_dim = self.gf_dim, k_h = 3, k_w = 3,
+                              d_h = 1, d_w = 1, name = 'cont_conv1_2', reuse = reuse))
+        res_in.append(conv1_2)
+        pool1 = MaxPooling(conv1_2, [2, 2])
+
+        conv2_1 = relu(conv2d(pool1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3,
+                              d_h = 1, d_w = 1, name = 'cont_conv2_1', reuse = reuse))
+        conv2_2 = relu(conv2d(conv2_1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3,
+                              d_h = 1, d_w = 1, name = 'cont_conv2_2', reuse = reuse))
+        res_in.append(conv2_2)
+        pool2 = MaxPooling(conv2_2, [2, 2])
+
+        conv3_1 = relu(conv2d(pool2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3,
+                              d_h = 1, d_w = 1, name = 'cont_conv3_1', reuse = reuse))
+        conv3_2 = relu(conv2d(conv3_1, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3,
+                              d_h = 1, d_w = 1, name = 'cont_conv3_2', reuse = reuse))
+        conv3_3 = relu(conv2d(conv3_2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3,
+                              d_h = 1, d_w = 1, name = 'cont_conv3_3', reuse = reuse))
+        res_in.append(conv3_3)
+        pool3 = MaxPooling(conv3_3, [2, 2])
+        return pool3, res_in
+
+    def comb_layers(self, h_dyn, h_cont, reuse=False):
+        comb1 = relu(conv2d(tf.concat(axis = 3, values = [h_dyn, h_cont]),
+                            output_dim = self.gf_dim * 4, k_h = 3, k_w = 3,
+                            d_h = 1, d_w = 1, name = 'comb1', reuse = reuse))
+        comb2 = relu(conv2d(comb1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3,
+                            d_h = 1, d_w = 1, name = 'comb2', reuse = reuse))
+        h_comb = relu(conv2d(comb2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3,
+                             d_h = 1, d_w = 1, name = 'h_comb', reuse = reuse))
+        return h_comb
+
+    def residual(self, input_dyn, input_cont, reuse=False):
+        n_layers = len(input_dyn)
+        res_out = []
+        for l in range(n_layers):
+            input_ = tf.concat(axis = 3, values = [input_dyn[l], input_cont[l]])
+            out_dim = input_cont[l].get_shape()[3]
+            res1 = relu(conv2d(input_, output_dim = out_dim,
+                               k_h = 3, k_w = 3, d_h = 1, d_w = 1,
+                               name = 'res' + str(l) + '_1', reuse = reuse))
+            res2 = conv2d(res1, output_dim = out_dim, k_h = 3, k_w = 3,
+                          d_h = 1, d_w = 1, name = 'res' + str(l) + '_2', reuse = reuse)
+            res_out.append(res2)
+        return res_out
+
+    def dec_cnn(self, h_comb, res_connect, reuse=False):
+
+        shapel3 = [self.batch_size, int(self.image_size[0] / 4),
+                   int(self.image_size[1] / 4), self.gf_dim * 4]
+        shapeout3 = [self.batch_size, int(self.image_size[0] / 4),
+                     int(self.image_size[1] / 4), self.gf_dim * 2]
+        depool3 = FixedUnPooling(h_comb, [2, 2])
+        deconv3_3 = relu(deconv2d(relu(tf.add(depool3, res_connect[2])),
+                                  output_shape = shapel3, k_h = 3, k_w = 3,
+                                  d_h = 1, d_w = 1, name = 'dec_deconv3_3', reuse = reuse))
+        deconv3_2 = relu(deconv2d(deconv3_3, output_shape = shapel3, k_h = 3, k_w = 3,
+                                  d_h = 1, d_w = 1, name = 'dec_deconv3_2', reuse = reuse))
+        deconv3_1 = relu(deconv2d(deconv3_2, output_shape = shapeout3, k_h = 3, k_w = 3,
+                                  d_h = 1, d_w = 1, name = 'dec_deconv3_1', reuse = reuse))
+
+        shapel2 = [self.batch_size, int(self.image_size[0] / 2),
+                   int(self.image_size[1] / 2), self.gf_dim * 2]
+        shapeout3 = [self.batch_size, int(self.image_size[0] / 2),
+                     int(self.image_size[1] / 2), self.gf_dim]
+        depool2 = FixedUnPooling(deconv3_1, [2, 2])
+        deconv2_2 = relu(deconv2d(relu(tf.add(depool2, res_connect[1])),
+                                  output_shape = shapel2, k_h = 3, k_w = 3,
+                                  d_h = 1, d_w = 1, name = 'dec_deconv2_2', reuse = reuse))
+        deconv2_1 = relu(deconv2d(deconv2_2, output_shape = shapeout3, k_h = 3, k_w = 3,
+                                  d_h = 1, d_w = 1, name = 'dec_deconv2_1', reuse = reuse))
+
+        shapel1 = [self.batch_size, self.image_size[0],
+                   self.image_size[1], self.gf_dim]
+        shapeout1 = [self.batch_size, self.image_size[0],
+                     self.image_size[1], self.c_dim]
+        depool1 = FixedUnPooling(deconv2_1, [2, 2])
+        deconv1_2 = relu(deconv2d(relu(tf.add(depool1, res_connect[0])),
+                                  output_shape = shapel1, k_h = 3, k_w = 3, d_h = 1, d_w = 1,
+                                  name = 'dec_deconv1_2', reuse = reuse))
+        xtp1 = tanh(deconv2d(deconv1_2, output_shape = shapeout1, k_h = 3, k_w = 3,
+                             d_h = 1, d_w = 1, name = 'dec_deconv1_1', reuse = reuse))
+        return xtp1
+
+    def discriminator(self, image):
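+        # Convolutional discriminator: four conv blocks with leaky ReLU (batch norm from the
+        # second block on), flattened into a single real/fake logit.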
+        h0 = lrelu(conv2d(image, self.df_dim, name = 'dis_h0_conv'))
+        h1 = lrelu(batch_norm(conv2d(h0, self.df_dim * 2, name = 'dis_h1_conv'),
+                              "bn1"))
+        h2 = lrelu(batch_norm(conv2d(h1, self.df_dim * 4, name = 'dis_h2_conv'),
+                              "bn2"))
+        h3 = lrelu(batch_norm(conv2d(h2, self.df_dim * 8, name = 'dis_h3_conv'),
+                              "bn3"))
+        h = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'dis_h3_lin')
+
+        return tf.nn.sigmoid(h), h
+
+    def save(self, sess, checkpoint_dir, step):
+        model_name = "MCNET.model"
+
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir)
+
+        self.saver.save(sess,
+                        os.path.join(checkpoint_dir, model_name),
+                        global_step = step)
+
+    def load(self, sess, checkpoint_dir, model_name=None):
+        print(" [*] Reading checkpoints...")
+        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
+        if ckpt and ckpt.model_checkpoint_path:
+            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
+            if model_name is None: model_name = ckpt_name
+            self.saver.restore(sess, os.path.join(checkpoint_dir, model_name))
+            print(" Loaded model: " + str(model_name))
+            return True, model_name
+        else:
+            return False, None
+
+    # Execute the forward and the backward pass.
+
+    def run_single_step(self, global_step):
+        print("global_step:", global_step)
+        try:
+            train_batch = self.sess.run(self.train_iterator.get_next())
+            # z=np.random.uniform(-1,1,size=(self.batch_size,self.nz))
+            x = self.sess.run([self.x], feed_dict = {self.x: train_batch["images"]})
+            _, g_sum = self.sess.run([self.g_optim, self.g_sum], feed_dict = {self.x: train_batch["images"]})
+            _, d_sum = self.sess.run([self.d_optim, self.d_sum], feed_dict = {self.x: train_batch["images"]})
+
+            gen_data, train_loss = self.sess.run([self.gen_data, self.total_loss],
+                                                       feed_dict = {self.x: train_batch["images"]})
+
+        except tf.errors.OutOfRangeError:
+            print("train out of range error")
+
+        try:
+            val_batch = self.sess.run(self.val_iterator.get_next())
+            val_loss = self.sess.run([self.total_loss], feed_dict = {self.x: val_batch["images"]})
+            # self.val_writer.add_summary(val_summary, global_step)
+        except tf.errors.OutOfRangeError:
+            print("train out of range error")
+
+        return train_loss, val_loss
+
+
+
diff --git a/video_prediction/models/networks.py b/video_prediction/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..844c28295cfea3597bf6a1ce52c9b0f3891370a9
--- /dev/null
+++ b/video_prediction/models/networks.py
@@ -0,0 +1,110 @@
+import tensorflow as tf
+from tensorflow.python.util import nest
+
+from video_prediction import ops
+from video_prediction.ops import conv2d
+from video_prediction.ops import dense
+from video_prediction.ops import lrelu
+from video_prediction.ops import pool2d
+from video_prediction.utils import tf_utils
+
+
+def encoder(inputs, nef=64, n_layers=3, norm_layer='instance'):
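+    # Strided 4x4 convolutions downsample the input n_layers times; a global average
+    # pool then reduces the final feature map to a single feature vector per image.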
+    print("********inputs*******",inputs.get_shape())
+    print("*********nef********", nef)
+    norm_layer = ops.get_norm_layer(norm_layer)
+    layers = []
+    paddings = [[0, 0], [1, 1], [1, 1], [0, 0]]
+
+    with tf.variable_scope("layer_1"):
+        convolved = conv2d(tf.pad(inputs, paddings), nef, kernel_size=4, strides=2, padding='VALID')
+        rectified = lrelu(convolved, 0.2)
+        layers.append(rectified)
+
+    for i in range(1, n_layers):
+        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
+            out_channels = nef * min(2**i, 4)
+            convolved = conv2d(tf.pad(layers[-1], paddings), out_channels, kernel_size=4, strides=2, padding='VALID')
+            normalized = norm_layer(convolved)
+            rectified = lrelu(normalized, 0.2)
+            layers.append(rectified)
+
+    pooled = pool2d(rectified, rectified.shape.as_list()[1:3], padding='VALID', pool_mode='avg')
+    squeezed = tf.squeeze(pooled, [1, 2])
+    return squeezed
+
+
+def image_sn_discriminator(images, ndf=64):
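+    # SNGAN-style image discriminator: spectrally normalized convolutions with leaky ReLU,
+    # returning the intermediate feature maps followed by the final logits as the last element.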
+    batch_size = images.shape[0].value
+    layers = []
+    paddings = [[0, 0], [1, 1], [1, 1], [0, 0]]
+
+    def conv2d(inputs, *args, **kwargs):
+        kwargs.setdefault('padding', 'VALID')
+        kwargs.setdefault('use_spectral_norm', True)
+        return ops.conv2d(tf.pad(inputs, paddings), *args, **kwargs)
+
+    with tf.variable_scope("sn_conv0_0"):
+        layers.append(lrelu(conv2d(images, ndf, kernel_size=3, strides=1), 0.1))
+
+    with tf.variable_scope("sn_conv0_1"):
+        layers.append(lrelu(conv2d(layers[-1], ndf * 2, kernel_size=4, strides=2), 0.1))
+
+    with tf.variable_scope("sn_conv1_0"):
+        layers.append(lrelu(conv2d(layers[-1], ndf * 2, kernel_size=3, strides=1), 0.1))
+
+    with tf.variable_scope("sn_conv1_1"):
+        layers.append(lrelu(conv2d(layers[-1], ndf * 4, kernel_size=4, strides=2), 0.1))
+
+    with tf.variable_scope("sn_conv2_0"):
+        layers.append(lrelu(conv2d(layers[-1], ndf * 4, kernel_size=3, strides=1), 0.1))
+
+    with tf.variable_scope("sn_conv2_1"):
+        layers.append(lrelu(conv2d(layers[-1], ndf * 8, kernel_size=4, strides=2), 0.1))
+
+    with tf.variable_scope("sn_conv3_0"):
+        layers.append(lrelu(conv2d(layers[-1], ndf * 8, kernel_size=3, strides=1), 0.1))
+
+    with tf.variable_scope("sn_fc4"):
+        logits = dense(tf.reshape(layers[-1], [batch_size, -1]), 1, use_spectral_norm=True)
+        layers.append(logits)
+    return layers
+
+
+def video_sn_discriminator(clips, ndf=64):
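+    # Video discriminator: same layout as the image discriminator, but with spectrally
+    # normalized 3D convolutions applied over (time, height, width) of each clip.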
+    clips = tf_utils.transpose_batch_time(clips)
+    batch_size = clips.shape[0].value
+    layers = []
+    paddings = [[0, 0], [1, 1], [1, 1], [1, 1], [0, 0]]
+
+    def conv3d(inputs, *args, **kwargs):
+        kwargs.setdefault('padding', 'VALID')
+        kwargs.setdefault('use_spectral_norm', True)
+        return ops.conv3d(tf.pad(inputs, paddings), *args, **kwargs)
+
+    with tf.variable_scope("sn_conv0_0"):
+        layers.append(lrelu(conv3d(clips, ndf, kernel_size=3, strides=1), 0.1))
+
+    with tf.variable_scope("sn_conv0_1"):
+        layers.append(lrelu(conv3d(layers[-1], ndf * 2, kernel_size=4, strides=(1, 2, 2)), 0.1))
+
+    with tf.variable_scope("sn_conv1_0"):
+        layers.append(lrelu(conv3d(layers[-1], ndf * 2, kernel_size=3, strides=1), 0.1))
+
+    with tf.variable_scope("sn_conv1_1"):
+        layers.append(lrelu(conv3d(layers[-1], ndf * 4, kernel_size=4, strides=(1, 2, 2)), 0.1))
+
+    with tf.variable_scope("sn_conv2_0"):
+        layers.append(lrelu(conv3d(layers[-1], ndf * 4, kernel_size=3, strides=1), 0.1))
+
+    with tf.variable_scope("sn_conv2_1"):
+        layers.append(lrelu(conv3d(layers[-1], ndf * 8, kernel_size=4, strides=2), 0.1))
+
+    with tf.variable_scope("sn_conv3_0"):
+        layers.append(lrelu(conv3d(layers[-1], ndf * 8, kernel_size=3, strides=1), 0.1))
+
+    with tf.variable_scope("sn_fc4"):
+        logits = dense(tf.reshape(layers[-1], [batch_size, -1]), 1, use_spectral_norm=True)
+        layers.append(logits)
+    layers = nest.map_structure(tf_utils.transpose_batch_time, layers)
+    return layers
diff --git a/video_prediction/models/non_trainable_model.py b/video_prediction/models/non_trainable_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5082b5c0357d37262b3236185be89e146d92a84
--- /dev/null
+++ b/video_prediction/models/non_trainable_model.py
@@ -0,0 +1,54 @@
+from collections import OrderedDict
+from tensorflow.python.util import nest
+from video_prediction.utils.tf_utils import transpose_batch_time
+
+import tensorflow as tf
+
+from .base_model import BaseVideoPredictionModel
+
+
+class NonTrainableVideoPredictionModel(BaseVideoPredictionModel):
+    pass
+
+
+class GroundTruthVideoPredictionModel(NonTrainableVideoPredictionModel):
+    def build_graph(self, inputs):
+        super(GroundTruthVideoPredictionModel, self).build_graph(inputs)
+
+        self.outputs = OrderedDict()
+        self.outputs['gen_images'] = self.inputs['images'][:, 1:]
+        if 'pix_distribs' in self.inputs:
+            self.outputs['gen_pix_distribs'] = self.inputs['pix_distribs'][:, 1:]
+
+        inputs, outputs = nest.map_structure(transpose_batch_time, (self.inputs, self.outputs))
+        with tf.name_scope("metrics"):
+            metrics = self.metrics_fn(inputs, outputs)
+        with tf.name_scope("eval_outputs_and_metrics"):
+            eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs)
+        self.metrics, self.eval_outputs, self.eval_metrics = nest.map_structure(
+            transpose_batch_time, (metrics, eval_outputs, eval_metrics))
+
+
+class RepeatVideoPredictionModel(NonTrainableVideoPredictionModel):
+    def build_graph(self, inputs):
+        super(RepeatVideoPredictionModel, self).build_graph(inputs)
+
+        self.outputs = OrderedDict()
+        tile_pattern = [1, self.hparams.sequence_length - self.hparams.context_frames, 1, 1, 1]
+        last_context_images = self.inputs['images'][:, self.hparams.context_frames - 1]
+        # The repeated prediction covers frames 1..sequence_length-1: ground-truth frames up to
+        # the last context frame, followed by that frame repeated for all future steps.
+        self.outputs['gen_images'] = tf.concat([
+            self.inputs['images'][:, 1:self.hparams.context_frames],
+            tf.tile(last_context_images[:, None], tile_pattern)], axis=1)
+        if 'pix_distribs' in self.inputs:
+            last_context_pix_distrib = self.inputs['pix_distribs'][:, self.hparams.context_frames - 1]
+            self.outputs['gen_pix_distribs'] = tf.concat([
+                self.inputs['pix_distribs'][:, 1:self.hparams.context_frames],
+                tf.tile(last_context_pix_distrib[:, None], tile_pattern)], axis=1)
+
+        inputs, outputs = nest.map_structure(transpose_batch_time, (self.inputs, self.outputs))
+        with tf.name_scope("metrics"):
+            metrics = self.metrics_fn(inputs, outputs)
+        with tf.name_scope("eval_outputs_and_metrics"):
+            eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs)
+        self.metrics, self.eval_outputs, self.eval_metrics = nest.map_structure(
+            transpose_batch_time, (metrics, eval_outputs, eval_metrics))
diff --git a/video_prediction/models/savp_model.py b/video_prediction/models/savp_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca8acd3f32a5ea1772c9fbf36003149acfdcb950
--- /dev/null
+++ b/video_prediction/models/savp_model.py
@@ -0,0 +1,993 @@
+import collections
+import functools
+import itertools
+from collections import OrderedDict
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.util import nest
+
+from video_prediction import ops, flow_ops
+from video_prediction.models import VideoPredictionModel
+from video_prediction.models import networks
+from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from video_prediction.utils import tf_utils
+
+# Amount to use when lower bounding tensors
+RELU_SHIFT = 1e-12
+
+
+def posterior_fn(inputs, hparams):
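+    # Approximate posterior: consecutive frame pairs (optionally concatenated with actions)
+    # are encoded and mapped to the mean and log-variance of a per-step Gaussian latent.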
+    images = inputs['images']
+    image_pairs = tf.concat([images[:-1], images[1:]], axis=-1)
+    if 'actions' in inputs:
+        image_pairs = tile_concat(
+            [image_pairs, inputs['actions'][..., None, None, :]], axis=-1)
+
+    h = tf_utils.with_flat_batch(networks.encoder)(
+        image_pairs, nef=hparams.nef, n_layers=hparams.n_layers, norm_layer=hparams.norm_layer)
+
+    if hparams.use_e_rnn:
+        with tf.variable_scope('layer_%d' % (hparams.n_layers + 1)):
+            h = tf_utils.with_flat_batch(dense, 2)(h, hparams.nef * 4)
+
+        if hparams.rnn == 'lstm':
+            RNNCell = tf.contrib.rnn.BasicLSTMCell
+        elif hparams.rnn == 'gru':
+            RNNCell = tf.contrib.rnn.GRUCell
+        else:
+            raise NotImplementedError
+        with tf.variable_scope('%s' % hparams.rnn):
+            rnn_cell = RNNCell(hparams.nef * 4)
+            h, _ = tf_utils.unroll_rnn(rnn_cell, h)
+
+    with tf.variable_scope('z_mu'):
+        z_mu = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz)
+    with tf.variable_scope('z_log_sigma_sq'):
+        z_log_sigma_sq = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz)
+        z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10)
+    outputs = {'zs_mu': z_mu, 'zs_log_sigma_sq': z_log_sigma_sq}
+    return outputs
+
+
+def prior_fn(inputs, hparams):
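+    # Learned prior: only the context-frame pairs are encoded (zeros are appended for the
+    # future steps) before an RNN predicts the latent mean and log-variance for every step.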
+    images = inputs['images']
+    image_pairs = tf.concat([images[:hparams.context_frames - 1], images[1:hparams.context_frames]], axis=-1)
+    if 'actions' in inputs:
+        image_pairs = tile_concat(
+            [image_pairs, inputs['actions'][..., None, None, :]], axis=-1)
+
+    h = tf_utils.with_flat_batch(networks.encoder)(
+        image_pairs, nef=hparams.nef, n_layers=hparams.n_layers, norm_layer=hparams.norm_layer)
+    h_zeros = tf.zeros(tf.concat([[hparams.sequence_length - hparams.context_frames], tf.shape(h)[1:]], axis=0))
+    h = tf.concat([h, h_zeros], axis=0)
+
+    with tf.variable_scope('layer_%d' % (hparams.n_layers + 1)):
+        h = tf_utils.with_flat_batch(dense, 2)(h, hparams.nef * 4)
+
+    if hparams.rnn == 'lstm':
+        RNNCell = tf.contrib.rnn.BasicLSTMCell
+    elif hparams.rnn == 'gru':
+        RNNCell = tf.contrib.rnn.GRUCell
+    else:
+        raise NotImplementedError
+    with tf.variable_scope('%s' % hparams.rnn):
+        rnn_cell = RNNCell(hparams.nef * 4)
+        h, _ = tf_utils.unroll_rnn(rnn_cell, h)
+
+    with tf.variable_scope('z_mu'):
+        z_mu = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz)
+    with tf.variable_scope('z_log_sigma_sq'):
+        z_log_sigma_sq = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz)
+        z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10)
+    outputs = {'zs_mu': z_mu, 'zs_log_sigma_sq': z_log_sigma_sq}
+    return outputs
+
+
+def discriminator_given_video_fn(targets, hparams):
+    sequence_length, batch_size = targets.shape.as_list()[:2]
+    clip_length = hparams.clip_length
+
+    # sample an image and apply the image discriminator on that frame
+    t_sample = tf.random_uniform([batch_size], minval=0, maxval=sequence_length, dtype=tf.int32)
+    image_sample = tf.gather_nd(targets, tf.stack([t_sample, tf.range(batch_size)], axis=1))
+
+    # sample a subsequence of length clip_length and apply the images/video discriminators on those frames
+    t_start = tf.random_uniform([batch_size], minval=0, maxval=sequence_length - clip_length + 1, dtype=tf.int32)
+    t_start_indices = tf.stack([t_start, tf.range(batch_size)], axis=1)
+    t_offset_indices = tf.stack([tf.range(clip_length), tf.zeros(clip_length, dtype=tf.int32)], axis=1)
+    indices = t_start_indices[None] + t_offset_indices[:, None]
+    clip_sample = tf.gather_nd(targets, flatten(indices, 0, 1))
+    clip_sample = tf.reshape(clip_sample, [clip_length] + targets.shape.as_list()[1:])
+
+    outputs = {}
+    if hparams.image_sn_gan_weight or hparams.image_sn_vae_gan_weight:
+        with tf.variable_scope('image'):
+            image_features = networks.image_sn_discriminator(image_sample, ndf=hparams.ndf)
+            image_features, image_logits = image_features[:-1], image_features[-1]
+            outputs['discrim_image_sn_logits'] = image_logits
+            for i, image_feature in enumerate(image_features):
+                outputs['discrim_image_sn_feature%d' % i] = image_feature
+    if hparams.video_sn_gan_weight or hparams.video_sn_vae_gan_weight:
+        with tf.variable_scope('video'):
+            video_features = networks.video_sn_discriminator(clip_sample, ndf=hparams.ndf)
+            video_features, video_logits = video_features[:-1], video_features[-1]
+            outputs['discrim_video_sn_logits'] = video_logits
+            for i, video_feature in enumerate(video_features):
+                outputs['discrim_video_sn_feature%d' % i] = video_feature
+    if hparams.images_sn_gan_weight or hparams.images_sn_vae_gan_weight:
+        with tf.variable_scope('images'):
+            images_features = tf_utils.with_flat_batch(networks.image_sn_discriminator)(clip_sample, ndf=hparams.ndf)
+            images_features, images_logits = images_features[:-1], images_features[-1]
+            outputs['discrim_images_sn_logits'] = images_logits
+            for i, images_feature in enumerate(images_features):
+                outputs['discrim_images_sn_feature%d' % i] = images_feature
+    return outputs
+
+
+def discriminator_fn(inputs, outputs, mode, hparams):
+    # do the encoder version first so that it isn't affected by the reuse_variables() call
+    if hparams.nz == 0:
+        discrim_outputs_enc_real = collections.OrderedDict()
+        discrim_outputs_enc_fake = collections.OrderedDict()
+    else:
+        images_enc_real = inputs['images'][1:]
+        images_enc_fake = outputs['gen_images_enc']
+        if hparams.use_same_discriminator:
+            with tf.name_scope("real"):
+                discrim_outputs_enc_real = discriminator_given_video_fn(images_enc_real, hparams)
+            tf.get_variable_scope().reuse_variables()
+            with tf.name_scope("fake"):
+                discrim_outputs_enc_fake = discriminator_given_video_fn(images_enc_fake, hparams)
+        else:
+            with tf.variable_scope('encoder'), tf.name_scope("real"):
+                discrim_outputs_enc_real = discriminator_given_video_fn(images_enc_real, hparams)
+            with tf.variable_scope('encoder', reuse=True), tf.name_scope("fake"):
+                discrim_outputs_enc_fake = discriminator_given_video_fn(images_enc_fake, hparams)
+
+    images_real = inputs['images'][1:]
+    images_fake = outputs['gen_images']
+    with tf.name_scope("real"):
+        discrim_outputs_real = discriminator_given_video_fn(images_real, hparams)
+    tf.get_variable_scope().reuse_variables()
+    with tf.name_scope("fake"):
+        discrim_outputs_fake = discriminator_given_video_fn(images_fake, hparams)
+
+    discrim_outputs_real = OrderedDict([(k + '_real', v) for k, v in discrim_outputs_real.items()])
+    discrim_outputs_fake = OrderedDict([(k + '_fake', v) for k, v in discrim_outputs_fake.items()])
+    discrim_outputs_enc_real = OrderedDict([(k + '_enc_real', v) for k, v in discrim_outputs_enc_real.items()])
+    discrim_outputs_enc_fake = OrderedDict([(k + '_enc_fake', v) for k, v in discrim_outputs_enc_fake.items()])
+    outputs = [discrim_outputs_real, discrim_outputs_fake,
+               discrim_outputs_enc_real, discrim_outputs_enc_fake]
+    total_num_outputs = sum([len(output) for output in outputs])
+    outputs = collections.OrderedDict(itertools.chain(*[output.items() for output in outputs]))
+    assert len(outputs) == total_num_outputs  # ensure no output is lost because of repeated keys
+    return outputs
+
+
+class SAVPCell(tf.nn.rnn_cell.RNNCell):
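+    # Recurrent cell implementing one prediction step of the SAVP generator: an
+    # encoder-decoder CNN with optional convolutional RNN layers, latent-code injection,
+    # and scheduled sampling between ground-truth and generated frames.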
+    def __init__(self, inputs, mode, hparams, reuse=None):
+        super(SAVPCell, self).__init__(_reuse=reuse)
+        self.inputs = inputs
+        self.mode = mode
+        self.hparams = hparams
+
+        if self.hparams.where_add not in ('input', 'all', 'middle'):
+            raise ValueError('Invalid where_add %s' % self.hparams.where_add)
+
+        batch_size = inputs['images'].shape[1].value
+        image_shape = inputs['images'].shape.as_list()[2:]
+        height, width, _ = image_shape
+        scale_size = min(height, width)
+        if scale_size >= 256:
+            self.encoder_layer_specs = [
+                (self.hparams.ngf, False),
+                (self.hparams.ngf * 2, False),
+                (self.hparams.ngf * 4, True),
+                (self.hparams.ngf * 8, True),
+                (self.hparams.ngf * 8, True),
+            ]
+            self.decoder_layer_specs = [
+                (self.hparams.ngf * 8, True),
+                (self.hparams.ngf * 4, True),
+                (self.hparams.ngf * 2, False),
+                (self.hparams.ngf, False),
+                (self.hparams.ngf, False),
+            ]
+        elif scale_size >= 128:
+            self.encoder_layer_specs = [
+                (self.hparams.ngf, False),
+                (self.hparams.ngf * 2, True),
+                (self.hparams.ngf * 4, True),
+                (self.hparams.ngf * 8, True),
+            ]
+            self.decoder_layer_specs = [
+                (self.hparams.ngf * 8, True),
+                (self.hparams.ngf * 4, True),
+                (self.hparams.ngf * 2, False),
+                (self.hparams.ngf, False),
+            ]
+        elif scale_size >= 64:
+            self.encoder_layer_specs = [
+                (self.hparams.ngf, True),
+                (self.hparams.ngf * 2, True),
+                (self.hparams.ngf * 4, True),
+            ]
+            self.decoder_layer_specs = [
+                (self.hparams.ngf * 2, True),
+                (self.hparams.ngf, True),
+                (self.hparams.ngf, False),
+            ]
+        elif scale_size >= 32:
+            self.encoder_layer_specs = [
+                (self.hparams.ngf, True),
+                (self.hparams.ngf * 2, True),
+            ]
+            self.decoder_layer_specs = [
+                (self.hparams.ngf, True),
+                (self.hparams.ngf, False),
+            ]
+        else:
+            raise NotImplementedError
+        assert len(self.encoder_layer_specs) == len(self.decoder_layer_specs)
+        total_stride = 2 ** len(self.encoder_layer_specs)
+        if (height % total_stride) or (width % total_stride):
+            raise ValueError("The image has dimension (%d, %d), but it should be divisible "
+                             "by the total stride, which is %d." % (height, width, total_stride))
+
+        # output_size
+        num_masks = self.hparams.last_frames * self.hparams.num_transformed_images + \
+            int(bool(self.hparams.prev_image_background)) + \
+            int(bool(self.hparams.first_image_background and not self.hparams.context_images_background)) + \
+            int(bool(self.hparams.last_image_background and not self.hparams.context_images_background)) + \
+            int(bool(self.hparams.last_context_image_background and not self.hparams.context_images_background)) + \
+            (self.hparams.context_frames if self.hparams.context_images_background else 0) + \
+            int(bool(self.hparams.generate_scratch_image))
+        output_size = {
+            'gen_images': tf.TensorShape(image_shape),
+            'transformed_images': tf.TensorShape(image_shape + [num_masks]),
+            'masks': tf.TensorShape([height, width, 1, num_masks]),
+        }
+        if 'pix_distribs' in inputs:
+            num_motions = inputs['pix_distribs'].shape[-1].value
+            output_size['gen_pix_distribs'] = tf.TensorShape([height, width, num_motions])
+            output_size['transformed_pix_distribs'] = tf.TensorShape([height, width, num_motions, num_masks])
+        if 'states' in inputs:
+            output_size['gen_states'] = inputs['states'].shape[2:]
+        if self.hparams.transformation == 'flow':
+            output_size['gen_flows'] = tf.TensorShape([height, width, 2, self.hparams.last_frames * self.hparams.num_transformed_images])
+            output_size['gen_flows_rgb'] = tf.TensorShape([height, width, 3, self.hparams.last_frames * self.hparams.num_transformed_images])
+        self._output_size = output_size
+
+        # state_size
+        conv_rnn_state_sizes = []
+        conv_rnn_height, conv_rnn_width = height, width
+        for out_channels, use_conv_rnn in self.encoder_layer_specs:
+            conv_rnn_height //= 2
+            conv_rnn_width //= 2
+            if use_conv_rnn and not self.hparams.ablation_rnn:
+                conv_rnn_state_sizes.append(tf.TensorShape([conv_rnn_height, conv_rnn_width, out_channels]))
+        for out_channels, use_conv_rnn in self.decoder_layer_specs:
+            conv_rnn_height *= 2
+            conv_rnn_width *= 2
+            if use_conv_rnn and not self.hparams.ablation_rnn:
+                conv_rnn_state_sizes.append(tf.TensorShape([conv_rnn_height, conv_rnn_width, out_channels]))
+        if self.hparams.conv_rnn == 'lstm':
+            conv_rnn_state_sizes = [tf.nn.rnn_cell.LSTMStateTuple(conv_rnn_state_size, conv_rnn_state_size)
+                                    for conv_rnn_state_size in conv_rnn_state_sizes]
+        state_size = {'time': tf.TensorShape([]),
+                      'gen_image': tf.TensorShape(image_shape),
+                      'last_images': [tf.TensorShape(image_shape)] * self.hparams.last_frames,
+                      'conv_rnn_states': conv_rnn_state_sizes}
+        if 'zs' in inputs and self.hparams.use_rnn_z and not self.hparams.ablation_rnn:
+            rnn_z_state_size = tf.TensorShape([self.hparams.nz])
+            if self.hparams.rnn == 'lstm':
+                rnn_z_state_size = tf.nn.rnn_cell.LSTMStateTuple(rnn_z_state_size, rnn_z_state_size)
+            state_size['rnn_z_state'] = rnn_z_state_size
+        if 'pix_distribs' in inputs:
+            state_size['gen_pix_distrib'] = tf.TensorShape([height, width, num_motions])
+            state_size['last_pix_distribs'] = [tf.TensorShape([height, width, num_motions])] * self.hparams.last_frames
+        if 'states' in inputs:
+            state_size['gen_state'] = inputs['states'].shape[2:]
+        self._state_size = state_size
+
+        if self.hparams.learn_initial_state:
+            learnable_initial_state_size = {k: v for k, v in state_size.items()
+                                            if k in ('conv_rnn_states', 'rnn_z_state')}
+        else:
+            learnable_initial_state_size = {}
+        learnable_initial_state_flat = []
+        for i, size in enumerate(nest.flatten(learnable_initial_state_size)):
+            with tf.variable_scope('initial_state_%d' % i):
+                state = tf.get_variable('initial_state', size,
+                                        dtype=tf.float32, initializer=tf.zeros_initializer())
+                learnable_initial_state_flat.append(state)
+        self._learnable_initial_state = nest.pack_sequence_as(
+            learnable_initial_state_size, learnable_initial_state_flat)
+
+        ground_truth_sampling_shape = [self.hparams.sequence_length - 1 - self.hparams.context_frames, batch_size]
+        if self.hparams.schedule_sampling == 'none' or self.mode != 'train':
+            ground_truth_sampling = tf.constant(False, dtype=tf.bool, shape=ground_truth_sampling_shape)
+        elif self.hparams.schedule_sampling in ('inverse_sigmoid', 'linear'):
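+            # The probability of feeding in the ground-truth frame decays over training,
+            # either with an inverse-sigmoid schedule in k or linearly between the given steps.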
+            if self.hparams.schedule_sampling == 'inverse_sigmoid':
+                k = self.hparams.schedule_sampling_k
+                start_step = self.hparams.schedule_sampling_steps[0]
+                iter_num = tf.to_float(tf.train.get_or_create_global_step())
+                prob = (k / (k + tf.exp((iter_num - start_step) / k)))
+                prob = tf.cond(tf.less(iter_num, start_step), lambda: 1.0, lambda: prob)
+            elif self.hparams.schedule_sampling == 'linear':
+                start_step, end_step = self.hparams.schedule_sampling_steps
+                step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step)
+                prob = 1.0 - tf.to_float(step - start_step) / tf.to_float(end_step - start_step)
+            log_probs = tf.log([1 - prob, prob])
+            ground_truth_sampling = tf.multinomial([log_probs] * batch_size, ground_truth_sampling_shape[0])
+            ground_truth_sampling = tf.cast(tf.transpose(ground_truth_sampling, [1, 0]), dtype=tf.bool)
+            # Ensure that eventually, the model is deterministically
+            # autoregressive (as opposed to autoregressive with very high probability).
+            ground_truth_sampling = tf.cond(tf.less(prob, 0.001),
+                                            lambda: tf.constant(False, dtype=tf.bool, shape=ground_truth_sampling_shape),
+                                            lambda: ground_truth_sampling)
+        else:
+            raise NotImplementedError
+        ground_truth_context = tf.constant(True, dtype=tf.bool, shape=[self.hparams.context_frames, batch_size])
+        self.ground_truth = tf.concat([ground_truth_context, ground_truth_sampling], axis=0)
+
+    @property
+    def output_size(self):
+        return self._output_size
+
+    @property
+    def state_size(self):
+        return self._state_size
+
+    def zero_state(self, batch_size, dtype):
+        init_state = super(SAVPCell, self).zero_state(batch_size, dtype)
+        learnable_init_state = nest.map_structure(
+            lambda x: tf.tile(x[None], [batch_size] + [1] * x.shape.ndims), self._learnable_initial_state)
+        init_state.update(learnable_init_state)
+        init_state['last_images'] = [self.inputs['images'][0]] * self.hparams.last_frames
+        if 'pix_distribs' in self.inputs:
+            init_state['last_pix_distribs'] = [self.inputs['pix_distribs'][0]] * self.hparams.last_frames
+        return init_state
+
+    def _rnn_func(self, inputs, state, num_units):
+        if self.hparams.rnn == 'lstm':
+            RNNCell = functools.partial(tf.nn.rnn_cell.LSTMCell, name='basic_lstm_cell')
+        elif self.hparams.rnn == 'gru':
+            RNNCell = tf.contrib.rnn.GRUCell
+        else:
+            raise NotImplementedError
+        rnn_cell = RNNCell(num_units, reuse=tf.get_variable_scope().reuse)
+        return rnn_cell(inputs, state)
+
+    def _conv_rnn_func(self, inputs, state, filters):
+        if isinstance(inputs, (list, tuple)):
+            inputs_shape = inputs[0].shape.as_list()
+        else:
+            inputs_shape = inputs.shape.as_list()
+        input_shape = inputs_shape[1:]
+        if self.hparams.conv_rnn_norm_layer == 'none':
+            normalizer_fn = None
+        else:
+            normalizer_fn = ops.get_norm_layer(self.hparams.conv_rnn_norm_layer)
+        if self.hparams.conv_rnn == 'lstm':
+            Conv2DRNNCell = BasicConv2DLSTMCell
+        elif self.hparams.conv_rnn == 'gru':
+            Conv2DRNNCell = Conv2DGRUCell
+        else:
+            raise NotImplementedError
+        if self.hparams.ablation_conv_rnn_norm:
+            conv_rnn_cell = Conv2DRNNCell(input_shape, filters, kernel_size=(5, 5),
+                                          reuse=tf.get_variable_scope().reuse)
+            h, state = conv_rnn_cell(inputs, state)
+            outputs = (normalizer_fn(h), state)
+        else:
+            conv_rnn_cell = Conv2DRNNCell(input_shape, filters, kernel_size=(5, 5),
+                                          normalizer_fn=normalizer_fn,
+                                          separate_norms=self.hparams.conv_rnn_norm_layer == 'layer',
+                                          reuse=tf.get_variable_scope().reuse)
+            outputs = conv_rnn_cell(inputs, state)
+        return outputs
+
+    def call(self, inputs, states):
+        norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
+        downsample_layer = ops.get_downsample_layer(self.hparams.downsample_layer)
+        upsample_layer = ops.get_upsample_layer(self.hparams.upsample_layer)
+        activation_layer = ops.get_activation_layer(self.hparams.activation_layer)
+        image_shape = inputs['images'].get_shape().as_list()
+        batch_size, height, width, color_channels = image_shape
+        conv_rnn_states = states['conv_rnn_states']
+
+        time = states['time']
+        with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
+            t = tf.to_int32(tf.identity(time[0]))
+
+        image = tf.where(self.ground_truth[t], inputs['images'], states['gen_image'])  # schedule sampling (if any)
+        last_images = states['last_images'][1:] + [image]
+        if 'pix_distribs' in inputs:
+            pix_distrib = tf.where(self.ground_truth[t], inputs['pix_distribs'], states['gen_pix_distrib'])
+            last_pix_distribs = states['last_pix_distribs'][1:] + [pix_distrib]
+        if 'states' in inputs:
+            state = tf.where(self.ground_truth[t], inputs['states'], states['gen_state'])
+
+        state_action = []
+        state_action_z = []
+        if 'actions' in inputs:
+            state_action.append(inputs['actions'])
+            state_action_z.append(inputs['actions'])
+        if 'states' in inputs:
+            state_action.append(state)
+            # don't backpropagate the convnet through the state dynamics
+            state_action_z.append(tf.stop_gradient(state))
+
+        if 'zs' in inputs:
+            if self.hparams.use_rnn_z:
+                with tf.variable_scope('%s_z' % ('fc' if self.hparams.ablation_rnn else self.hparams.rnn)):
+                    if self.hparams.ablation_rnn:
+                        rnn_z = dense(inputs['zs'], self.hparams.nz)
+                        rnn_z = tf.nn.tanh(rnn_z)
+                    else:
+                        rnn_z, rnn_z_state = self._rnn_func(inputs['zs'], states['rnn_z_state'], self.hparams.nz)
+                state_action_z.append(rnn_z)
+            else:
+                state_action_z.append(inputs['zs'])
+
+        def concat(tensors, axis):
+            if len(tensors) == 0:
+                return tf.zeros([batch_size, 0])
+            elif len(tensors) == 1:
+                return tensors[0]
+            else:
+                return tf.concat(tensors, axis=axis)
+        state_action = concat(state_action, axis=-1)
+        state_action_z = concat(state_action_z, axis=-1)
+
+        layers = []
+        new_conv_rnn_states = []
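+        # Encoder: each layer downsamples by a factor of 2 and optionally applies a convolutional RNN;
+        # the state/action/latent vector is injected where dictated by `where_add`.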
+        for i, (out_channels, use_conv_rnn) in enumerate(self.encoder_layer_specs):
+            with tf.variable_scope('h%d' % i):
+                if i == 0:
+                    h = tf.concat([image, self.inputs['images'][0]], axis=-1)
+                    kernel_size = (5, 5)
+                else:
+                    h = layers[-1][-1]
+                    kernel_size = (3, 3)
+                if self.hparams.where_add == 'all' or (self.hparams.where_add == 'input' and i == 0):
+                    if self.hparams.use_tile_concat:
+                        h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                    else:
+                        h = [h, state_action_z]
+                h = _maybe_tile_concat_layer(downsample_layer)(
+                    h, out_channels, kernel_size=kernel_size, strides=(2, 2))
+                h = norm_layer(h)
+                h = activation_layer(h)
+            if use_conv_rnn:
+                with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, i)):
+                    if self.hparams.where_add == 'all':
+                        if self.hparams.use_tile_concat:
+                            conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                        else:
+                            conv_rnn_h = [h, state_action_z]
+                    else:
+                        conv_rnn_h = h
+                    if self.hparams.ablation_rnn:
+                        conv_rnn_h = _maybe_tile_concat_layer(conv2d)(
+                            conv_rnn_h, out_channels, kernel_size=(5, 5))
+                        conv_rnn_h = norm_layer(conv_rnn_h)
+                        conv_rnn_h = activation_layer(conv_rnn_h)
+                    else:
+                        conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
+                        conv_rnn_h, conv_rnn_state = self._conv_rnn_func(conv_rnn_h, conv_rnn_state, out_channels)
+                        new_conv_rnn_states.append(conv_rnn_state)
+            layers.append((h, conv_rnn_h) if use_conv_rnn else (h,))
+
+        num_encoder_layers = len(layers)
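+        # Decoder: each layer upsamples by a factor of 2 and concatenates the output of the
+        # mirrored encoder layer (skip connections).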
+        for i, (out_channels, use_conv_rnn) in enumerate(self.decoder_layer_specs):
+            with tf.variable_scope('h%d' % len(layers)):
+                if i == 0:
+                    h = layers[-1][-1]
+                else:
+                    h = tf.concat([layers[-1][-1], layers[num_encoder_layers - i - 1][-1]], axis=-1)
+                if self.hparams.where_add == 'all' or (self.hparams.where_add == 'middle' and i == 0):
+                    if self.hparams.use_tile_concat:
+                        h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                    else:
+                        h = [h, state_action_z]
+                h = _maybe_tile_concat_layer(upsample_layer)(
+                    h, out_channels, kernel_size=(3, 3), strides=(2, 2))
+                h = norm_layer(h)
+                h = activation_layer(h)
+            if use_conv_rnn:
+                with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, len(layers))):
+                    if self.hparams.where_add == 'all':
+                        if self.hparams.use_tile_concat:
+                            conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                        else:
+                            conv_rnn_h = [h, state_action_z]
+                    else:
+                        conv_rnn_h = h
+                    if self.hparams.ablation_rnn:
+                        conv_rnn_h = _maybe_tile_concat_layer(conv2d)(conv_rnn_h, out_channels, kernel_size=(5, 5))
+                        conv_rnn_h = norm_layer(conv_rnn_h)
+                        conv_rnn_h = activation_layer(conv_rnn_h)
+                    else:
+                        conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
+                        conv_rnn_h, conv_rnn_state = self._conv_rnn_func(conv_rnn_h, conv_rnn_state, out_channels)
+                        new_conv_rnn_states.append(conv_rnn_state)
+            layers.append((h, conv_rnn_h) if use_conv_rnn else (h,))
+        assert len(new_conv_rnn_states) == len(conv_rnn_states)
+
+        if self.hparams.last_frames and self.hparams.num_transformed_images:
+            if self.hparams.transformation == 'flow':
+                with tf.variable_scope('h%d_flow' % len(layers)):
+                    h_flow = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1))
+                    h_flow = norm_layer(h_flow)
+                    h_flow = activation_layer(h_flow)
+
+                with tf.variable_scope('flows'):
+                    flows = conv2d(h_flow, 2 * self.hparams.last_frames * self.hparams.num_transformed_images, kernel_size=(3, 3), strides=(1, 1))
+                    flows = tf.reshape(flows, [batch_size, height, width, 2, self.hparams.last_frames * self.hparams.num_transformed_images])
+            else:
+                assert len(self.hparams.kernel_size) == 2
+                kernel_shape = list(self.hparams.kernel_size) + [self.hparams.last_frames * self.hparams.num_transformed_images]
+                if self.hparams.transformation == 'dna':
+                    with tf.variable_scope('h%d_dna_kernel' % len(layers)):
+                        h_dna_kernel = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1))
+                        h_dna_kernel = norm_layer(h_dna_kernel)
+                        h_dna_kernel = activation_layer(h_dna_kernel)
+
+                    # Using largest hidden state for predicting untied conv kernels.
+                    with tf.variable_scope('dna_kernels'):
+                        kernels = conv2d(h_dna_kernel, np.prod(kernel_shape), kernel_size=(3, 3), strides=(1, 1))
+                        kernels = tf.reshape(kernels, [batch_size, height, width] + kernel_shape)
+                        kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, None, None, :, :, None]
+                    kernel_spatial_axes = [3, 4]
+                elif self.hparams.transformation == 'cdna':
+                    with tf.variable_scope('cdna_kernels'):
+                        smallest_layer = layers[num_encoder_layers - 1][-1]
+                        kernels = dense(flatten(smallest_layer), np.prod(kernel_shape))
+                        kernels = tf.reshape(kernels, [batch_size] + kernel_shape)
+                        kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, :, :, None]
+                    kernel_spatial_axes = [1, 2]
+                else:
+                    raise ValueError('Invalid transformation %s' % self.hparams.transformation)
+
+            if self.hparams.transformation != 'flow':
+                with tf.name_scope('kernel_normalization'):
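+                    # Shift the kernels to be strictly positive and normalize over the spatial
+                    # axes so that each predicted kernel sums to one.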
+                    kernels = tf.nn.relu(kernels - RELU_SHIFT) + RELU_SHIFT
+                    kernels /= tf.reduce_sum(kernels, axis=kernel_spatial_axes, keepdims=True)
+
+        if self.hparams.generate_scratch_image:
+            with tf.variable_scope('h%d_scratch' % len(layers)):
+                h_scratch = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1))
+                h_scratch = norm_layer(h_scratch)
+                h_scratch = activation_layer(h_scratch)
+
+            # Using largest hidden state for predicting a new image layer.
+            # This allows the network to also generate one image from scratch,
+            # which is useful when regions of the image become unoccluded.
+            with tf.variable_scope('scratch_image'):
+                scratch_image = conv2d(h_scratch, color_channels, kernel_size=(3, 3), strides=(1, 1))
+                scratch_image = tf.nn.sigmoid(scratch_image)
+
+        with tf.name_scope('transformed_images'):
+            transformed_images = []
+            if self.hparams.last_frames and self.hparams.num_transformed_images:
+                if self.hparams.transformation == 'flow':
+                    transformed_images.extend(apply_flows(last_images, flows))
+                else:
+                    transformed_images.extend(apply_kernels(last_images, kernels, self.hparams.dilation_rate))
+            if self.hparams.prev_image_background:
+                transformed_images.append(image)
+            if self.hparams.first_image_background and not self.hparams.context_images_background:
+                transformed_images.append(self.inputs['images'][0])
+            if self.hparams.last_image_background and not self.hparams.context_images_background:
+                transformed_images.append(self.inputs['images'][self.hparams.context_frames - 1])
+            if self.hparams.last_context_image_background and not self.hparams.context_images_background:
+                last_context_image = tf.cond(
+                    tf.less(t, self.hparams.context_frames),
+                    lambda: self.inputs['images'][t],
+                    lambda: self.inputs['images'][self.hparams.context_frames - 1])
+                transformed_images.append(last_context_image)
+            if self.hparams.context_images_background:
+                transformed_images.extend(tf.unstack(self.inputs['images'][:self.hparams.context_frames]))
+            if self.hparams.generate_scratch_image:
+                transformed_images.append(scratch_image)
+
+        if 'pix_distribs' in inputs:
+            with tf.name_scope('transformed_pix_distribs'):
+                transformed_pix_distribs = []
+                if self.hparams.last_frames and self.hparams.num_transformed_images:
+                    if self.hparams.transformation == 'flow':
+                        transformed_pix_distribs.extend(apply_flows(last_pix_distribs, flows))
+                    else:
+                        transformed_pix_distribs.extend(apply_kernels(last_pix_distribs, kernels, self.hparams.dilation_rate))
+                if self.hparams.prev_image_background:
+                    transformed_pix_distribs.append(pix_distrib)
+                if self.hparams.first_image_background and not self.hparams.context_images_background:
+                    transformed_pix_distribs.append(self.inputs['pix_distribs'][0])
+                if self.hparams.last_image_background and not self.hparams.context_images_background:
+                    transformed_pix_distribs.append(self.inputs['pix_distribs'][self.hparams.context_frames - 1])
+                if self.hparams.last_context_image_background and not self.hparams.context_images_background:
+                    last_context_pix_distrib = tf.cond(
+                        tf.less(t, self.hparams.context_frames),
+                        lambda: self.inputs['pix_distribs'][t],
+                        lambda: self.inputs['pix_distribs'][self.hparams.context_frames - 1])
+                    transformed_pix_distribs.append(last_context_pix_distrib)
+                if self.hparams.context_images_background:
+                    transformed_pix_distribs.extend(tf.unstack(self.inputs['pix_distribs'][:self.hparams.context_frames]))
+                if self.hparams.generate_scratch_image:
+                    transformed_pix_distribs.append(pix_distrib)
+
+        with tf.name_scope('masks'):
+            if len(transformed_images) > 1:
+                with tf.variable_scope('h%d_masks' % len(layers)):
+                    h_masks = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1))
+                    h_masks = norm_layer(h_masks)
+                    h_masks = activation_layer(h_masks)
+
+                with tf.variable_scope('masks'):
+                    if self.hparams.dependent_mask:
+                        h_masks = tf.concat([h_masks] + transformed_images, axis=-1)
+                    masks = conv2d(h_masks, len(transformed_images), kernel_size=(3, 3), strides=(1, 1))
+                    masks = tf.nn.softmax(masks)
+                    masks = tf.split(masks, len(transformed_images), axis=-1)
+            elif len(transformed_images) == 1:
+                masks = [tf.ones([batch_size, height, width, 1])]
+            else:
+                raise ValueError("Either one of the following should be true: "
+                                 "last_frames and num_transformed_images, first_image_background, "
+                                 "prev_image_background, generate_scratch_image")
+
+        with tf.name_scope('gen_images'):
+            assert len(transformed_images) == len(masks)
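+            # Composite the candidate images with the predicted masks; the masks are
+            # non-negative and sum to one at every pixel.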
+            gen_image = tf.add_n([transformed_image * mask
+                                  for transformed_image, mask in zip(transformed_images, masks)])
+
+        if 'pix_distribs' in inputs:
+            with tf.name_scope('gen_pix_distribs'):
+                assert len(transformed_pix_distribs) == len(masks)
+                gen_pix_distrib = tf.add_n([transformed_pix_distrib * mask
+                                            for transformed_pix_distrib, mask in zip(transformed_pix_distribs, masks)])
+                gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib, axis=(1, 2), keepdims=True)
+
+        if 'states' in inputs:
+            with tf.name_scope('gen_states'):
+                with tf.variable_scope('state_pred'):
+                    gen_state = dense(state_action, inputs['states'].shape[-1].value)
+
+        outputs = {'gen_images': gen_image,
+                   'transformed_images': tf.stack(transformed_images, axis=-1),
+                   'masks': tf.stack(masks, axis=-1)}
+        if 'pix_distribs' in inputs:
+            outputs['gen_pix_distribs'] = gen_pix_distrib
+            outputs['transformed_pix_distribs'] = tf.stack(transformed_pix_distribs, axis=-1)
+        if 'states' in inputs:
+            outputs['gen_states'] = gen_state
+        if self.hparams.transformation == 'flow':
+            outputs['gen_flows'] = flows
+            flows_transposed = tf.transpose(flows, [0, 1, 2, 4, 3])
+            flows_rgb_transposed = tf_utils.flow_to_rgb(flows_transposed)
+            flows_rgb = tf.transpose(flows_rgb_transposed, [0, 1, 2, 4, 3])
+            outputs['gen_flows_rgb'] = flows_rgb
+
+        new_states = {'time': time + 1,
+                      'gen_image': gen_image,
+                      'last_images': last_images,
+                      'conv_rnn_states': new_conv_rnn_states}
+        if 'zs' in inputs and self.hparams.use_rnn_z and not self.hparams.ablation_rnn:
+            new_states['rnn_z_state'] = rnn_z_state
+        if 'pix_distribs' in inputs:
+            new_states['gen_pix_distrib'] = gen_pix_distrib
+            new_states['last_pix_distribs'] = last_pix_distribs
+        if 'states' in inputs:
+            new_states['gen_state'] = gen_state
+        return outputs, new_states
+
+
+def generator_given_z_fn(inputs, mode, hparams):
+    # all the inputs need to have the same length for unrolling the RNN
+    inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
+              for name, input in inputs.items()}
+    cell = SAVPCell(inputs, mode, hparams)
+    outputs, _ = tf_utils.unroll_rnn(cell, inputs)
+    outputs['ground_truth_sampling_mean'] = tf.reduce_mean(tf.to_float(cell.ground_truth[hparams.context_frames:]))
+    return outputs
+
+
+def generator_fn(inputs, mode, hparams):
+    batch_size = tf.shape(inputs['images'])[1]
+    print("2********mode*****",mode)
+    print("3*******hparams",hparams)
+
+
+    if hparams.nz == 0:
+        # no zs is given in inputs
+        outputs = generator_given_z_fn(inputs, mode, hparams)
+    else:
+        zs_shape = [hparams.sequence_length - 1, batch_size, hparams.nz]
+
+        # posterior
+        with tf.variable_scope('encoder'):
+            outputs_posterior = posterior_fn(inputs, hparams)
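+            # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
+            # where sigma = sqrt(exp(log_sigma_sq)).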
+            eps = tf.random_normal(zs_shape, 0, 1)
+            zs_posterior = outputs_posterior['zs_mu'] + tf.sqrt(tf.exp(outputs_posterior['zs_log_sigma_sq'])) * eps
+        inputs_posterior = dict(inputs)
+        inputs_posterior['zs'] = zs_posterior
+
+        # prior
+        if hparams.learn_prior:
+            with tf.variable_scope('prior'):
+                outputs_prior = prior_fn(inputs, hparams)
+            eps = tf.random_normal(zs_shape, 0, 1)
+            zs_prior = outputs_prior['zs_mu'] + tf.sqrt(tf.exp(outputs_prior['zs_log_sigma_sq'])) * eps
+        else:
+            outputs_prior = {}
+            zs_prior = tf.random_normal([hparams.sequence_length - hparams.context_frames] + zs_shape[1:], 0, 1)
+            zs_prior = tf.concat([zs_posterior[:hparams.context_frames - 1], zs_prior], axis=0)
+        inputs_prior = dict(inputs)
+        inputs_prior['zs'] = zs_prior
+
+        # generate
+        gen_outputs_posterior = generator_given_z_fn(inputs_posterior, mode, hparams)
+        tf.get_variable_scope().reuse_variables()
+        gen_outputs = generator_given_z_fn(inputs_prior, mode, hparams)
+
+        # rename tensors to avoid name collisions
+        output_prior = collections.OrderedDict([(k + '_prior', v) for k, v in outputs_prior.items()])
+        outputs_posterior = collections.OrderedDict([(k + '_enc', v) for k, v in outputs_posterior.items()])
+        gen_outputs_posterior = collections.OrderedDict([(k + '_enc', v) for k, v in gen_outputs_posterior.items()])
+
+        outputs = [output_prior, gen_outputs, outputs_posterior, gen_outputs_posterior]
+        total_num_outputs = sum([len(output) for output in outputs])
+        outputs = collections.OrderedDict(itertools.chain(*[output.items() for output in outputs]))
+        assert len(outputs) == total_num_outputs  # ensure no output is lost because of repeated keys
+
+        # generate multiple samples from the prior for visualization purposes
+        inputs_samples = {
+            name: tf.tile(input[:, None], [1, hparams.num_samples] + [1] * (input.shape.ndims - 1))
+            for name, input in inputs.items()}
+        zs_samples_shape = [hparams.sequence_length - 1, hparams.num_samples, batch_size, hparams.nz]
+        if hparams.learn_prior:
+            eps = tf.random_normal(zs_samples_shape, 0, 1)
+            zs_prior_samples = (outputs_prior['zs_mu'][:, None] +
+                                tf.sqrt(tf.exp(outputs_prior['zs_log_sigma_sq']))[:, None] * eps)
+        else:
+            zs_prior_samples = tf.random_normal(
+                [hparams.sequence_length - hparams.context_frames] + zs_samples_shape[1:], 0, 1)
+            zs_prior_samples = tf.concat(
+                [tf.tile(zs_posterior[:hparams.context_frames - 1][:, None], [1, hparams.num_samples, 1, 1]),
+                 zs_prior_samples], axis=0)
+        inputs_prior_samples = dict(inputs_samples)
+        inputs_prior_samples['zs'] = zs_prior_samples
+        inputs_prior_samples = {name: flatten(input, 1, 2) for name, input in inputs_prior_samples.items()}
+        gen_outputs_samples = generator_given_z_fn(inputs_prior_samples, mode, hparams)
+        gen_images_samples = gen_outputs_samples['gen_images']
+        gen_images_samples = tf.stack(tf.split(gen_images_samples, hparams.num_samples, axis=1), axis=-1)
+        gen_images_samples_avg = tf.reduce_mean(gen_images_samples, axis=-1)
+        outputs['gen_images_samples'] = gen_images_samples
+        outputs['gen_images_samples_avg'] = gen_images_samples_avg
+    return outputs
+
+
+class SAVPVideoPredictionModel(VideoPredictionModel):
+    def __init__(self, *args, **kwargs):
+        super(SAVPVideoPredictionModel, self).__init__(
+            generator_fn, discriminator_fn, *args, **kwargs)
+        if self.mode != 'train':
+            self.discriminator_fn = None
+        self.deterministic = not self.hparams.nz
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(SAVPVideoPredictionModel, self).get_default_hparams_dict()
+        hparams = dict(
+            l1_weight=1.0,
+            l2_weight=0.0,
+            n_layers=3,
+            ndf=32,
+            norm_layer='instance',
+            use_same_discriminator=False,
+            ngf=32,
+            downsample_layer='conv_pool2d',
+            upsample_layer='upsample_conv2d',
+            activation_layer='relu',  # for generator only
+            transformation='cdna',
+            kernel_size=(5, 5),
+            dilation_rate=(1, 1),
+            where_add='all',
+            use_tile_concat=True,
+            learn_initial_state=False,
+            rnn='lstm',
+            conv_rnn='lstm',
+            conv_rnn_norm_layer='instance',
+            num_transformed_images=4,
+            last_frames=1,
+            prev_image_background=True,
+            first_image_background=True,
+            last_image_background=False,
+            last_context_image_background=False,
+            context_images_background=False,
+            generate_scratch_image=True,
+            dependent_mask=True,
+            schedule_sampling='inverse_sigmoid',
+            schedule_sampling_k=900.0,
+            schedule_sampling_steps=(0, 100000),
+            use_e_rnn=False,
+            learn_prior=False,
+            nz=8,
+            num_samples=8,
+            nef=64,
+            use_rnn_z=True,
+            ablation_conv_rnn_norm=False,
+            ablation_rnn=False,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def parse_hparams(self, hparams_dict, hparams):
+        # backwards compatibility
+        deprecated_hparams_keys = [
+            'num_gpus',
+            'e_net',
+            'd_conditional',
+            'd_downsample_layer',
+            'd_net',
+            'd_use_gt_inputs',
+            'acvideo_gan_weight',
+            'acvideo_vae_gan_weight',
+            'image_gan_weight',
+            'image_vae_gan_weight',
+            'tuple_gan_weight',
+            'tuple_vae_gan_weight',
+            'gan_weight',
+            'vae_gan_weight',
+            'video_gan_weight',
+            'video_vae_gan_weight',
+        ]
+        for deprecated_hparams_key in deprecated_hparams_keys:
+            hparams_dict.pop(deprecated_hparams_key, None)
+        return super(SAVPVideoPredictionModel, self).parse_hparams(hparams_dict, hparams)
+
+    def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
+        def restore_to_checkpoint_mapping(restore_name, checkpoint_var_names):
+            restore_name = restore_name.split(':')[0]
+            if restore_name not in checkpoint_var_names:
+                restore_name = restore_name.replace('savp_cell', 'dna_cell')
+            return restore_name
+
+        super(SAVPVideoPredictionModel, self).restore(sess, checkpoints, restore_to_checkpoint_mapping)
+
+
+def apply_dna_kernels(image, kernels, dilation_rate=(1, 1)):
+    """
+    Args:
+        image: A 4-D tensor of shape
+            `[batch, in_height, in_width, in_channels]`.
+        kernels: A 6-D of shape
+            `[batch, in_height, in_width, kernel_size[0], kernel_size[1], num_transformed_images]`.
+    Returns:
+        A list of `num_transformed_images` 4-D tensors, each of shape
+            `[batch, in_height, in_width, in_channels]`.
+    """
+    dilation_rate = list(dilation_rate) if isinstance(dilation_rate, (tuple, list)) else [dilation_rate] * 2
+    batch_size, height, width, color_channels = image.get_shape().as_list()
+    batch_size, height, width, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list()
+    kernel_size = [kernel_height, kernel_width]
+
+    # Flatten the spatial dimensions.
+    kernels_reshaped = tf.reshape(kernels, [batch_size, height, width,
+                                            kernel_size[0] * kernel_size[1], num_transformed_images])
+    image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC')
+    # Combine channel and batch dimensions into the first dimension.
+    image_transposed = tf.transpose(image_padded, [3, 0, 1, 2])
+    image_reshaped = flatten(image_transposed, 0, 1)[..., None]
+    patches_reshaped = tf.extract_image_patches(image_reshaped, ksizes=[1] + kernel_size + [1],
+                                                strides=[1] * 4, rates=[1] + dilation_rate + [1], padding='VALID')
+    # Separate channel and batch dimensions, and move channel dimension.
+    patches_transposed = tf.reshape(patches_reshaped, [color_channels, batch_size, height, width, kernel_size[0] * kernel_size[1]])
+    patches = tf.transpose(patches_transposed, [1, 2, 3, 0, 4])
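+    # patches: [batch, height, width, channels, k*k]; kernels_reshaped: [batch, height, width, k*k, N]
+    # (N = num_transformed_images), so the matmul below yields [batch, height, width, channels, N].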
+    # Reduce along the spatial dimensions of the kernel.
+    outputs = tf.matmul(patches, kernels_reshaped)
+    outputs = tf.unstack(outputs, axis=-1)
+    return outputs
+
+
+def apply_cdna_kernels(image, kernels, dilation_rate=(1, 1)):
+    """
+    Args:
+        image: A 4-D tensor of shape
+            `[batch, in_height, in_width, in_channels]`.
+        kernels: A 4-D of shape
+            `[batch, kernel_size[0], kernel_size[1], num_transformed_images]`.
+    Returns:
+        A list of `num_transformed_images` 4-D tensors, each of shape
+            `[batch, in_height, in_width, in_channels]`.
+    """
+    batch_size, height, width, color_channels = image.get_shape().as_list()
+    batch_size, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list()
+    kernel_size = [kernel_height, kernel_width]
+    image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC')
+    # Treat the color channel dimension as the batch dimension since the same
+    # transformation is applied to each color channel.
+    # Treat the batch dimension as the channel dimension so that
+    # depthwise_conv2d can apply a different transformation to each sample.
+    kernels = tf.transpose(kernels, [1, 2, 0, 3])
+    kernels = tf.reshape(kernels, [kernel_size[0], kernel_size[1], batch_size, num_transformed_images])
+    # Swap the batch and channel dimensions.
+    image_transposed = tf.transpose(image_padded, [3, 1, 2, 0])
+    # Transform image.
+    outputs = tf.nn.depthwise_conv2d(image_transposed, kernels, [1, 1, 1, 1], padding='VALID', rate=dilation_rate)
+    # Transpose the dimensions to where they belong.
+    outputs = tf.reshape(outputs, [color_channels, height, width, batch_size, num_transformed_images])
+    outputs = tf.transpose(outputs, [4, 3, 1, 2, 0])
+    outputs = tf.unstack(outputs, axis=0)
+    return outputs
+
+
+def apply_kernels(image, kernels, dilation_rate=(1, 1)):
+    """
+    Args:
+        image: A 4-D tensor of shape
+            `[batch, in_height, in_width, in_channels]`.
+        kernels: A 4-D or 6-D tensor of shape
+            `[batch, kernel_size[0], kernel_size[1], num_transformed_images]` or
+            `[batch, in_height, in_width, kernel_size[0], kernel_size[1], num_transformed_images]`.
+    Returns:
+        A list of `num_transformed_images` 4-D tensors, each of shape
+            `[batch, in_height, in_width, in_channels]`.
+    """
+    if isinstance(image, list):
+        image_list = image
+        kernels_list = tf.split(kernels, len(image_list), axis=-1)
+        outputs = []
+        for image, kernels in zip(image_list, kernels_list):
+            outputs.extend(apply_kernels(image, kernels))
+    else:
+        if len(kernels.get_shape()) == 4:
+            outputs = apply_cdna_kernels(image, kernels, dilation_rate=dilation_rate)
+        elif len(kernels.get_shape()) == 6:
+            outputs = apply_dna_kernels(image, kernels, dilation_rate=dilation_rate)
+        else:
+            raise ValueError
+    return outputs
+
+
+def apply_flows(image, flows):
+    if isinstance(image, list):
+        image_list = image
+        flows_list = tf.split(flows, len(image_list), axis=-1)
+        outputs = []
+        for image, flows in zip(image_list, flows_list):
+            outputs.extend(apply_flows(image, flows))
+    else:
+        flows = tf.unstack(flows, axis=-1)
+        outputs = [flow_ops.image_warp(image, flow) for flow in flows]
+    return outputs
+
+
+def identity_kernel(kernel_size):
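+    """Return a kernel of the given size whose mass is concentrated at the center
+    (spread over the 2x2 center for even sizes). For example, identity_kernel((5, 5))
+    has a single 1.0 at position (2, 2), so applying it leaves an image unchanged.
+    """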
+    kh, kw = kernel_size
+    kernel = np.zeros(kernel_size)
+
+    def center_slice(k):
+        if k % 2 == 0:
+            return slice(k // 2 - 1, k // 2 + 1)
+        else:
+            return slice(k // 2, k // 2 + 1)
+
+    kernel[center_slice(kh), center_slice(kw)] = 1.0
+    kernel /= np.sum(kernel)
+    return kernel
+
+
+def _maybe_tile_concat_layer(conv2d_layer):
+    def layer(inputs, out_channels, *args, **kwargs):
+        if isinstance(inputs, (list, tuple)):
+            inputs_spatial, inputs_non_spatial = inputs
+            outputs = (conv2d_layer(inputs_spatial, out_channels, *args, **kwargs) +
+                       dense(inputs_non_spatial, out_channels, use_bias=False)[:, None, None, :])
+        else:
+            outputs = conv2d_layer(inputs, out_channels, *args, **kwargs)
+        return outputs
+
+    return layer
diff --git a/video_prediction/models/sna_model.py b/video_prediction/models/sna_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb04deafc73f49d0466acd74ac4a43d94ac72f0
--- /dev/null
+++ b/video_prediction/models/sna_model.py
@@ -0,0 +1,667 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Model architecture for predictive model, including CDNA, DNA, and STP."""
+
+import itertools
+
+import numpy as np
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+from tensorflow.contrib.layers.python import layers as tf_layers
+from tensorflow.contrib.slim import add_arg_scope
+from tensorflow.contrib.slim import layers
+
+from video_prediction.models import VideoPredictionModel
+
+
+# Amount to use when lower bounding tensors
+RELU_SHIFT = 1e-12
+
+
+@add_arg_scope
+def basic_conv_lstm_cell(inputs,
+                         state,
+                         num_channels,
+                         filter_size=5,
+                         forget_bias=1.0,
+                         scope=None,
+                         reuse=None,
+                         ):
+    """Basic LSTM recurrent network cell, with 2D convolution connctions.
+    We add forget_bias (default: 1) to the biases of the forget gate in order to
+    reduce the scale of forgetting in the beginning of the training.
+    It does not allow cell clipping, a projection layer, and does not
+    use peep-hole connections: it is the basic baseline.
+    Args:
+        inputs: input Tensor, 4D, batch x height x width x channels.
+        state: state Tensor, 4D, batch x height x width x channels.
+        num_channels: the number of output channels in the layer.
+        filter_size: the shape of each convolution filter.
+        forget_bias: the initial value of the forget biases.
+        scope: Optional scope for variable_scope.
+        reuse: whether or not the layer and the variables should be reused.
+    Returns:
+        a tuple of tensors representing output and the new state.
+    """
+    if state is None:
+        state = tf.zeros(inputs.get_shape().as_list()[:3] + [2 * num_channels], name='init_state')
+
+    with tf.variable_scope(scope,
+                           'BasicConvLstmCell',
+                           [inputs, state],
+                           reuse=reuse):
+
+        inputs.get_shape().assert_has_rank(4)
+        state.get_shape().assert_has_rank(4)
+        c, h = tf.split(axis=3, num_or_size_splits=2, value=state)
+        inputs_h = tf.concat(values=[inputs, h], axis=3)
+        # Parameters of gates are concatenated into one conv for efficiency.
+        i_j_f_o = layers.conv2d(inputs_h,
+                                4 * num_channels, [filter_size, filter_size],
+                                stride=1,
+                                activation_fn=None,
+                                scope='Gates',
+                                )
+
+        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+        i, j, f, o = tf.split(value=i_j_f_o, num_or_size_splits=4, axis=3)
+
+        new_c = c * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.tanh(j)
+        new_h = tf.tanh(new_c) * tf.sigmoid(o)
+
+        return new_h, tf.concat(values=[new_c, new_h], axis=3)
+
+
+class Prediction_Model(object):
+
+    def __init__(self,
+                 images,
+                 actions=None,
+                 states=None,
+                 iter_num=-1.0,
+                 pix_distributions1=None,
+                 pix_distributions2=None,
+                 conf=None):
+
+        self.pix_distributions1 = pix_distributions1
+        self.pix_distributions2 = pix_distributions2
+        self.actions = actions
+        self.iter_num = iter_num
+        self.conf = conf
+        self.images = images
+
+        self.cdna, self.stp, self.dna = False, False, False
+        if self.conf['model'] == 'CDNA':
+            self.cdna = True
+        elif self.conf['model'] == 'DNA':
+            self.dna = True
+        elif self.conf['model'] == 'STP':
+            self.stp = True
+        if self.stp + self.cdna + self.dna != 1:
+            raise ValueError("Exactly one of the STP, CDNA and DNA models must be selected!")
+
+        self.k = conf['schedsamp_k']
+        self.use_state = conf['use_state']
+        self.num_masks = conf['num_masks']
+        self.context_frames = conf['context_frames']
+
+        self.batch_size, self.img_height, self.img_width, self.color_channels = [int(i) for i in
+                                                                                 images[0].get_shape()[0:4]]
+        self.lstm_func = basic_conv_lstm_cell
+
+        # Generated robot states and images.
+        self.gen_states = []
+        self.gen_images = []
+        self.gen_masks = []
+
+        self.moved_images = []
+
+        self.moved_pix_distrib1 = []
+        self.moved_pix_distrib2 = []
+
+        self.states = states
+        self.gen_distrib1 = []
+        self.gen_distrib2 = []
+
+        self.trafos = []
+
+    def build(self):
+
+        if 'kern_size' in self.conf.keys():
+            KERN_SIZE = self.conf['kern_size']
+        else:
+            KERN_SIZE = 5
+
+        batch_size, img_height, img_width, color_channels = self.images[0].get_shape()[0:4]
+        lstm_func = basic_conv_lstm_cell
+
+
+        if self.states is not None:
+            current_state = self.states[0]
+        else:
+            current_state = None
+
+        if self.actions is None:
+            self.actions = [None for _ in self.images]
+
+        if self.k == -1:
+            feedself = True
+        else:
+            # Scheduled sampling:
+            # Calculate number of ground-truth frames to pass in.
+            num_ground_truth = tf.to_int32(
+                tf.round(tf.to_float(batch_size) * (self.k / (self.k + tf.exp(self.iter_num / self.k)))))
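+            # Inverse-sigmoid schedule: the fraction of ground-truth frames decays from
+            # 1 towards 0 as iter_num grows, at a rate controlled by k.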
+            feedself = False
+
+        # LSTM state sizes and states.
+
+        if 'lstm_size' in self.conf:
+            lstm_size = self.conf['lstm_size']
+            print('using lstm size', lstm_size)
+        else:
+            ngf = self.conf['ngf']
+            lstm_size = np.int32(np.array([ngf, ngf * 2, ngf * 4, ngf * 2, ngf]))
+
+
+        lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None
+        lstm_state5, lstm_state6, lstm_state7 = None, None, None
+
+        for t, action in enumerate(self.actions):
+            print('building timestep %d' % t)
+            # Reuse variables after the first timestep.
+            reuse = bool(self.gen_images)
+
+            done_warm_start = len(self.gen_images) > self.context_frames - 1
+            with slim.arg_scope(
+                    [lstm_func, slim.layers.conv2d, slim.layers.fully_connected,
+                     tf_layers.layer_norm, slim.layers.conv2d_transpose],
+                    reuse=reuse):
+
+                if feedself and done_warm_start:
+                    # Feed in generated image.
+                    prev_image = self.gen_images[-1]             # 64x64x6
+                    if self.pix_distributions1 is not None:
+                        prev_pix_distrib1 = self.gen_distrib1[-1]
+                        if 'ndesig' in self.conf:
+                            prev_pix_distrib2 = self.gen_distrib2[-1]
+                elif done_warm_start:
+                    # Scheduled sampling
+                    prev_image = scheduled_sample(self.images[t], self.gen_images[-1], batch_size,
+                                                  num_ground_truth)
+                else:
+                    # Always feed in ground_truth
+                    prev_image = self.images[t]
+                    if self.pix_distributions1 is not None:
+                        prev_pix_distrib1 = self.pix_distributions1[t]
+                        if 'ndesig' in self.conf:
+                            prev_pix_distrib2 = self.pix_distributions2[t]
+                        if len(prev_pix_distrib1.get_shape()) == 3:
+                            prev_pix_distrib1 = tf.expand_dims(prev_pix_distrib1, -1)
+                            if 'ndesig' in self.conf:
+                                prev_pix_distrib2 = tf.expand_dims(prev_pix_distrib2, -1)
+
+                if 'refeed_firstimage' in self.conf:
+                    assert self.conf['model'] == 'STP'
+                    if t > 1:
+                        input_image = self.images[1]
+                        print('refeed with image 1')
+                    else:
+                        input_image = prev_image
+                else:
+                    input_image = prev_image
+
+                # Predicted state is always fed back in
+                if 'ignore_state_action' not in self.conf:
+                    state_action = tf.concat(axis=1, values=[action, current_state])
+
+                enc0 = slim.layers.conv2d(    #32x32x32
+                    input_image,
+                    32, [5, 5],
+                    stride=2,
+                    scope='scale1_conv1',
+                    normalizer_fn=tf_layers.layer_norm,
+                    normalizer_params={'scope': 'layer_norm1'})
+
+                hidden1, lstm_state1 = lstm_func(       # 32x32x16
+                    enc0, lstm_state1, lstm_size[0], scope='state1')
+                hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2')
+
+                enc1 = slim.layers.conv2d(     # 16x16x16
+                    hidden1, hidden1.get_shape()[3], [3, 3], stride=2, scope='conv2')
+
+                hidden3, lstm_state3 = lstm_func(   #16x16x32
+                    enc1, lstm_state3, lstm_size[1], scope='state3')
+                hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4')
+
+                enc2 = slim.layers.conv2d(  # 8x8x32
+                    hidden3, hidden3.get_shape()[3], [3, 3], stride=2, scope='conv3')
+
+                if 'ignore_state_action' not in self.conf:
+                    # Pass in state and action.
+                    if 'ignore_state' in self.conf:
+                        lowdim = action
+                        print('ignoring state')
+                    else:
+                        lowdim = state_action
+
+                    smear = tf.reshape(
+                        lowdim,
+                        [int(batch_size), 1, 1, int(lowdim.get_shape()[1])])
+                    smear = tf.tile(
+                        smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1])
+
+                    enc2 = tf.concat(axis=3, values=[enc2, smear])
+                else:
+                    print('ignoring states and actions')
+
+                enc3 = slim.layers.conv2d(   #8x8x32
+                    enc2, hidden3.get_shape()[3], [1, 1], stride=1, scope='conv4')
+
+                hidden5, lstm_state5 = lstm_func(  #8x8x64
+                    enc3, lstm_state5, lstm_size[2], scope='state5')
+                hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6')
+                enc4 = slim.layers.conv2d_transpose(  #16x16x64
+                    hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1')
+
+                hidden6, lstm_state6 = lstm_func(  #16x16x32
+                    enc4, lstm_state6, lstm_size[3], scope='state6')
+                hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7')
+
+                if 'noskip' not in self.conf:
+                    # Skip connection.
+                    hidden6 = tf.concat(axis=3, values=[hidden6, enc1])  # both 16x16
+
+                enc5 = slim.layers.conv2d_transpose(  #32x32x32
+                    hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2')
+                hidden7, lstm_state7 = lstm_func( # 32x32x16
+                    enc5, lstm_state7, lstm_size[4], scope='state7')
+                hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8')
+
+                if 'noskip' not in self.conf:
+                    # Skip connection.
+                    hidden7 = tf.concat(axis=3, values=[hidden7, enc0])  # both 32x32
+
+                enc6 = slim.layers.conv2d_transpose(   # 64x64x16
+                    hidden7,
+                    hidden7.get_shape()[3], 3, stride=2, scope='convt3',
+                    normalizer_fn=tf_layers.layer_norm,
+                    normalizer_params={'scope': 'layer_norm9'})
+
+                if 'transform_from_firstimage' in self.conf:
+                    prev_image = self.images[1]
+                    if self.pix_distributions1 is not None:
+                        prev_pix_distrib1 = self.pix_distributions1[1]
+                        prev_pix_distrib1 = tf.expand_dims(prev_pix_distrib1, -1)
+                    print('transform from image 1')
+
+                if self.conf['model'] == 'DNA':
+                    # Using largest hidden state for predicting untied conv kernels.
+                    trafo_input = slim.layers.conv2d_transpose(
+                        enc6, KERN_SIZE ** 2, 1, stride=1, scope='convt4_cam2')
+
+                    transformed_l = [self.dna_transformation(prev_image, trafo_input, KERN_SIZE)]
+                    if self.pix_distributions1 is not None:
+                        transf_distrib_ndesig1 = [self.dna_transformation(prev_pix_distrib1, trafo_input, KERN_SIZE)]
+                        if 'ndesig' in self.conf:
+                            transf_distrib_ndesig2 = [
+                                self.dna_transformation(prev_pix_distrib2, trafo_input, KERN_SIZE)]
+
+
+                    extra_masks = 1  # note: extra_masks = 2 is needed when running singleview_shifted
+
+                if self.conf['model'] == 'CDNA':
+                    if 'gen_pix' in self.conf:
+                        # Using largest hidden state for predicting a new image layer.
+                        enc7 = slim.layers.conv2d_transpose(
+                            enc6, color_channels, 1, stride=1, scope='convt4', activation_fn=None)
+                        # This allows the network to also generate one image from scratch,
+                        # which is useful when regions of the image become unoccluded.
+                        transformed_l = [tf.nn.sigmoid(enc7)]
+                        extra_masks = 2
+                    else:
+                        transformed_l = []
+                        extra_masks = 1
+
+                    cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
+                    new_transformed, _ = self.cdna_transformation(
+                        prev_image, cdna_input, reuse_sc=reuse)
+                    transformed_l += new_transformed
+                    self.moved_images.append(transformed_l)
+
+                    if self.pix_distributions1 is not None:
+                        transf_distrib_ndesig1, _ = self.cdna_transformation(
+                            prev_pix_distrib1, cdna_input, reuse_sc=True)
+                        self.moved_pix_distrib1.append(transf_distrib_ndesig1)
+                        if 'ndesig' in self.conf:
+                            transf_distrib_ndesig2, _ = self.cdna_transformation(
+                                prev_pix_distrib2, cdna_input, reuse_sc=True)
+
+                            self.moved_pix_distrib2.append(transf_distrib_ndesig2)
+
+                if self.conf['model'] == 'STP':
+                    enc7 = slim.layers.conv2d_transpose(enc6, color_channels, 1, stride=1, scope='convt5', activation_fn= None)
+                    # This allows the network to also generate one image from scratch,
+                    # which is useful when regions of the image become unoccluded.
+                    if 'gen_pix' in self.conf:
+                        transformed_l = [tf.nn.sigmoid(enc7)]
+                        extra_masks = 2
+                    else:
+                        transformed_l = []
+                        extra_masks = 1
+
+                    enc_stp = tf.reshape(hidden5, [int(batch_size), -1])
+                    stp_input = slim.layers.fully_connected(
+                        enc_stp, 200, scope='fc_stp_cam2')
+
+                    reuse_stp = reuse if reuse else None
+
+                    transformed, trafo = self.stp_transformation(prev_image, stp_input, self.num_masks, reuse_stp, suffix='cam2')
+                    transformed_l += transformed
+
+                    self.trafos.append(trafo)
+                    self.moved_images.append(transformed_l)
+
+                    if self.pix_distributions1 is not None:
+                        transf_distrib_ndesig1, _ = self.stp_transformation(
+                            prev_pix_distrib1, stp_input, self.num_masks, reuse=True, suffix='cam2')
+                        self.moved_pix_distrib1.append(transf_distrib_ndesig1)
+
+                if '1stimg_bckgd' in self.conf:
+                    background = self.images[0]
+                    print('using background from first image..')
+                else:
+                    background = prev_image
+                output, mask_list = self.fuse_trafos(enc6, background,
+                                                     transformed_l,
+                                                     scope='convt7_cam2',
+                                                     extra_masks= extra_masks)
+                self.gen_images.append(output)
+                self.gen_masks.append(mask_list)
+
+                if self.pix_distributions1 is not None:
+                    pix_distrib_output = self.fuse_pix_distrib(extra_masks,
+                                                                mask_list,
+                                                                self.pix_distributions1,
+                                                                prev_pix_distrib1,
+                                                                transf_distrib_ndesig1)
+
+                    self.gen_distrib1.append(pix_distrib_output)
+                    if 'ndesig' in self.conf:
+                        pix_distrib_output = self.fuse_pix_distrib(extra_masks,
+                                                                    mask_list,
+                                                                    self.pix_distributions2,
+                                                                    prev_pix_distrib2,
+                                                                    transf_distrib_ndesig2)
+
+                        self.gen_distrib2.append(pix_distrib_output)
+
+                if int(current_state.get_shape()[1]) == 0:
+                    current_state = tf.zeros_like(state_action)
+                else:
+                    current_state = slim.layers.fully_connected(
+                        state_action,
+                        int(current_state.get_shape()[1]),
+                        scope='state_pred',
+                        activation_fn=None)
+
+                self.gen_states.append(current_state)
+
+    def fuse_trafos(self, enc6, background_image, transformed, scope, extra_masks):
+        masks = slim.layers.conv2d_transpose(
+            enc6, (self.conf['num_masks'] + extra_masks), 1, stride=1, activation_fn=None, scope=scope)
+
+        img_height = 64
+        img_width = 64
+        num_masks = self.conf['num_masks']
+
+        if self.conf['model']=='DNA':
+            if num_masks != 1:
+                raise ValueError('Only one mask is supported for DNA model.')
+
+        # the total number of masks is num_masks + extra_masks because of the background and generated pixels
+        masks = tf.reshape(
+            tf.nn.softmax(tf.reshape(masks, [-1, num_masks + extra_masks])),
+            [int(self.batch_size), int(img_height), int(img_width), num_masks + extra_masks])
+        mask_list = tf.split(axis=3, num_or_size_splits=num_masks + extra_masks, value=masks)
+        output = mask_list[0] * background_image
+
+        assert len(transformed) == len(mask_list[1:])
+        for layer, mask in zip(transformed, mask_list[1:]):
+            output += layer * mask
+
+        return output, mask_list
+
+    def fuse_pix_distrib(self, extra_masks, mask_list, pix_distributions, prev_pix_distrib,
+                         transf_distrib):
+
+        if '1stimg_bckgd' in self.conf:
+            background_pix = pix_distributions[0]
+            if len(background_pix.get_shape()) == 3:
+                background_pix = tf.expand_dims(background_pix, -1)
+            print('using pix_distrib-background from first image..')
+        else:
+            background_pix = prev_pix_distrib
+        pix_distrib_output = mask_list[0] * background_pix
+        if 'gen_pix' in self.conf:
+            pix_distrib_output += mask_list[1] * prev_pix_distrib  # assume pixels don't move when the image is generated from scratch
+        for i in range(self.num_masks):
+            pix_distrib_output += transf_distrib[i] * mask_list[i + extra_masks]
+        pix_distrib_output /= tf.reduce_sum(pix_distrib_output, axis=(1, 2), keepdims=True)
+        return pix_distrib_output
+
+    ## Utility functions
+    def stp_transformation(self, prev_image, stp_input, num_masks, reuse=None, suffix=None):
+        """Apply spatial transformer predictor (STP) to previous image.
+
+        Args:
+          prev_image: previous image to be transformed.
+          stp_input: hidden layer to be used for computing STN parameters.
+          num_masks: number of masks and hence the number of STP transformations.
+        Returns:
+          List of images transformed by the predicted STP parameters.
+        """
+        # Only import spatial transformer if needed.
+        from spatial_transformer import transformer
+
+        identity_params = tf.convert_to_tensor(
+            np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32))
+        transformed = []
+        trafos = []
+        for i in range(num_masks):
+            params = slim.layers.fully_connected(
+                stp_input, 6, scope='stp_params' + str(i) + suffix,
+                activation_fn=None,
+                reuse= reuse) + identity_params
+            outsize = (prev_image.get_shape()[1], prev_image.get_shape()[2])
+            transformed.append(transformer(prev_image, params, outsize))
+            trafos.append(params)
+
+        return transformed, trafos
+
+    def dna_transformation(self, prev_image, dna_input, DNA_KERN_SIZE):
+        """Apply dynamic neural advection to previous image.
+
+        Args:
+          prev_image: previous image to be transformed.
+          dna_input: hidden layer to be used for computing the DNA transformation.
+        Returns:
+          Image transformed by the predicted DNA kernels.
+        """
+        # Construct translated images.
+        pad_len = int(np.floor(DNA_KERN_SIZE / 2))
+        prev_image_pad = tf.pad(prev_image, [[0, 0], [pad_len, pad_len], [pad_len, pad_len], [0, 0]])
+        image_height = int(prev_image.get_shape()[1])
+        image_width = int(prev_image.get_shape()[2])
+
+        inputs = []
+        for xkern in range(DNA_KERN_SIZE):
+            for ykern in range(DNA_KERN_SIZE):
+                inputs.append(
+                    tf.expand_dims(
+                        tf.slice(prev_image_pad, [0, xkern, ykern, 0],
+                                 [-1, image_height, image_width, -1]), [3]))
+        inputs = tf.concat(axis=3, values=inputs)
+
+        # Normalize channels to 1.
+        kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
+        kernel = tf.expand_dims(
+            kernel / tf.reduce_sum(
+                kernel, [3], keepdims=True), [4])
+
+        return tf.reduce_sum(kernel * inputs, [3], keepdims=False)
+
+    def cdna_transformation(self, prev_image, cdna_input, reuse_sc=None):
+        """Apply convolutional dynamic neural advection to previous image.
+
+        Args:
+          prev_image: previous image to be transformed.
+          cdna_input: hidden layer to be used for computing CDNA kernels.
+          reuse_sc: whether to reuse the variables of the kernel-predicting layer.
+        Returns:
+          List of images transformed by the predicted CDNA kernels, and the
+          kernels themselves (for summaries). The number of masks and color
+          channels are taken from self.conf and prev_image.
+        """
+        batch_size = int(cdna_input.get_shape()[0])
+        height = int(prev_image.get_shape()[1])
+        width = int(prev_image.get_shape()[2])
+
+        DNA_KERN_SIZE = self.conf['kern_size']
+        num_masks = self.conf['num_masks']
+        color_channels = int(prev_image.get_shape()[3])
+
+        # Predict kernels using linear function of last hidden layer.
+        cdna_kerns = slim.layers.fully_connected(
+            cdna_input,
+            DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks,
+            scope='cdna_params',
+            activation_fn=None,
+            reuse=reuse_sc)
+
+        # Reshape and normalize.
+        cdna_kerns = tf.reshape(
+            cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks])
+        cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
+        norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True)
+        cdna_kerns /= norm_factor
+        cdna_kerns_summary = cdna_kerns
+
+        # Transpose and reshape.
+        cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3])
+        cdna_kerns = tf.reshape(cdna_kerns, [DNA_KERN_SIZE, DNA_KERN_SIZE, batch_size, num_masks])
+        prev_image = tf.transpose(prev_image, [3, 1, 2, 0])
+
+        transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME')
+
+        # Transpose and reshape.
+        transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks])
+        transformed = tf.transpose(transformed, [3, 1, 2, 0, 4])
+        transformed = tf.unstack(value=transformed, axis=-1)
+
+        return transformed, cdna_kerns_summary
+
+
+def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
+    """Sample batch with specified mix of ground truth and generated data_files points.
+
+    Args:
+      ground_truth_x: tensor of ground-truth data_files points.
+      generated_x: tensor of generated data_files points.
+      batch_size: batch size
+      num_ground_truth: number of ground-truth examples to include in batch.
+    Returns:
+      New batch with num_ground_truth sampled from ground_truth_x and the rest
+      from generated_x.
+    """
+    idx = tf.random_shuffle(tf.range(int(batch_size)))
+    ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
+    generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size)))
+
+    ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
+    generated_examps = tf.gather(generated_x, generated_idx)
+    return tf.dynamic_stitch([ground_truth_idx, generated_idx],
+                             [ground_truth_examps, generated_examps])
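+# A minimal usage sketch (illustrative, not part of the original code): inside
+# the prediction loop, scheduled sampling mixes ground-truth and generated
+# frames row-wise, e.g.
+#
+#     mixed = scheduled_sample(images[t], gen_images[-1], batch_size, num_ground_truth)
+#
+# `mixed` has the same shape as images[t]; num_ground_truth randomly chosen
+# samples come from the ground truth, the rest from the model's own prediction.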
+
+
+def generator_fn(inputs, mode, hparams):
+    images = tf.unstack(inputs['images'], axis=0)
+    actions = tf.unstack(inputs['actions'], axis=0)
+    states = tf.unstack(inputs['states'], axis=0)
+    pix_distributions1 = tf.unstack(inputs['pix_distribs'], axis=0) if 'pix_distribs' in inputs else None
+    iter_num = tf.to_float(tf.train.get_or_create_global_step())
+
+    if isinstance(hparams.kernel_size, (tuple, list)):
+        kernel_height, kernel_width = hparams.kernel_size
+        assert kernel_height == kernel_width
+        kern_size = kernel_height
+    else:
+        kern_size = hparams.kernel_size
+
+    schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1
+    conf = {
+        'context_frames': hparams.context_frames,  # number of frames before predictions
+        'use_state': 1,  # whether or not to give the state+action to the model
+        'ngf': hparams.ngf,
+        'model': hparams.transformation.upper(),  # model architecture to use - CDNA, DNA, or STP
+        'num_masks': hparams.num_masks,  # number of masks, usually 1 for DNA, 10 for CDNA and STP
+        'schedsamp_k': schedule_sampling_k,  # the k hyperparameter for scheduled sampling; -1 for no scheduled sampling
+        'kern_size': kern_size,  # size of the DNA kernels
+    }
+    if hparams.first_image_background:
+        conf['1stimg_bckgd'] = ''
+    if hparams.generate_scratch_image:
+        conf['gen_pix'] = ''
+
+    m = Prediction_Model(images, actions, states,
+                         pix_distributions1=pix_distributions1,
+                         iter_num=iter_num, conf=conf)
+    m.build()
+    outputs = {
+        'gen_images': tf.stack(m.gen_images, axis=0),
+        'gen_states': tf.stack(m.gen_states, axis=0),
+    }
+    if 'pix_distribs' in inputs:
+        outputs['gen_pix_distribs'] = tf.stack(m.gen_distrib1, axis=0)
+    return outputs
+
+
+class SNAVideoPredictionModel(VideoPredictionModel):
+    def __init__(self, *args, **kwargs):
+        super(SNAVideoPredictionModel, self).__init__(
+            generator_fn, *args, **kwargs)
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(SNAVideoPredictionModel, self).get_default_hparams_dict()
+        hparams = dict(
+            batch_size=32,
+            l1_weight=0.0,
+            l2_weight=1.0,
+            ngf=16,
+            transformation='cdna',
+            kernel_size=(5, 5),
+            num_masks=10,
+            first_image_background=True,
+            generate_scratch_image=True,
+            schedule_sampling_k=900.0,
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
diff --git a/video_prediction/models/sv2p_model.py b/video_prediction/models/sv2p_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7a06364178dc380456e47569787ab693ad121de
--- /dev/null
+++ b/video_prediction/models/sv2p_model.py
@@ -0,0 +1,678 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Model architecture for predictive model, including CDNA, DNA, and STP."""
+
+import itertools
+import numpy as np
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+from tensorflow.contrib.layers.python import layers as tf_layers
+from tensorflow.contrib.slim import add_arg_scope
+from tensorflow.contrib.slim import layers
+
+from video_prediction.models import VideoPredictionModel
+
+
+# Amount to use when lower bounding tensors
+RELU_SHIFT = 1e-12
+
+# kernel size for DNA and CDNA.
+DNA_KERN_SIZE = 5
+
+
+def init_state(inputs,
+               state_shape,
+               state_initializer=tf.zeros_initializer(),
+               dtype=tf.float32):
+    """Helper function to create an initial state given inputs.
+    Args:
+        inputs: input Tensor, at least 2D, the first dimension being batch_size
+        state_shape: the shape of the state.
+        state_initializer: Initializer(shape, dtype) for state Tensor.
+        dtype: Optional dtype, needed when inputs is None.
+    Returns:
+        A tensor representing the initial state.
+    """
+    if inputs is not None:
+        # Handle both the dynamic shape as well as the inferred shape.
+        inferred_batch_size = inputs.get_shape().with_rank_at_least(1)[0]
+        dtype = inputs.dtype
+    else:
+        inferred_batch_size = 0
+    initial_state = state_initializer(
+        [inferred_batch_size] + state_shape, dtype=dtype)
+    return initial_state
+
+
+@add_arg_scope
+def basic_conv_lstm_cell(inputs,
+                         state,
+                         num_channels,
+                         filter_size=5,
+                         forget_bias=1.0,
+                         scope=None,
+                         reuse=None):
+    """Basic LSTM recurrent network cell, with 2D convolution connctions.
+    We add forget_bias (default: 1) to the biases of the forget gate in order to
+    reduce the scale of forgetting in the beginning of the training.
+    It does not allow cell clipping, a projection layer, and does not
+    use peep-hole connections: it is the basic baseline.
+    Args:
+        inputs: input Tensor, 4D, batch x height x width x channels.
+        state: state Tensor, 4D, batch x height x width x channels.
+        num_channels: the number of output channels in the layer.
+        filter_size: the shape of each convolution filter.
+        forget_bias: the initial value of the forget biases.
+        scope: Optional scope for variable_scope.
+        reuse: whether or not the layer and the variables should be reused.
+    Returns:
+         a tuple of tensors representing output and the new state.
+    """
+    spatial_size = inputs.get_shape()[1:3]
+    if state is None:
+        state = init_state(inputs, list(spatial_size) + [2 * num_channels])
+    with tf.variable_scope(scope,
+                           'BasicConvLstmCell',
+                           [inputs, state],
+                           reuse=reuse):
+        inputs.get_shape().assert_has_rank(4)
+        state.get_shape().assert_has_rank(4)
+        c, h = tf.split(axis=3, num_or_size_splits=2, value=state)
+        inputs_h = tf.concat(axis=3, values=[inputs, h])
+        # Parameters of gates are concatenated into one conv for efficiency.
+        i_j_f_o = layers.conv2d(inputs_h,
+                                4 * num_channels, [filter_size, filter_size],
+                                stride=1,
+                                activation_fn=None,
+                                scope='Gates')
+
+        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+        i, j, f, o = tf.split(axis=3, num_or_size_splits=4, value=i_j_f_o)
+
+        new_c = c * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.tanh(j)
+        new_h = tf.tanh(new_c) * tf.sigmoid(o)
+
+        return new_h, tf.concat(axis=3, values=[new_c, new_h])
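+# Minimal usage sketch (illustrative): the cell is called once per timestep and
+# threads its state through the loop, e.g.
+#
+#     state = None
+#     for frame in frames:  # each frame: [batch, height, width, channels]
+#         out, state = basic_conv_lstm_cell(frame, state, num_channels=32,
+#                                           scope='lstm1', reuse=tf.AUTO_REUSE)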
+
+
+def kl_divergence(mu, log_sigma):
+    """KL divergence of diagonal gaussian N(mu,exp(log_sigma)) and N(0,1).
+
+    Args:
+        mu: mu parameter of the distribution.
+        log_sigma: log(sigma) parameter of the distribution.
+    Returns:
+        the KL loss.
+    """
+
+    return -.5 * tf.reduce_sum(1. + log_sigma - tf.square(mu) - tf.exp(log_sigma),
+                               axis=1)
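+# For a single latent dimension, the summand above is the closed form
+#     KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2),
+# with log_sigma interpreted as log(sigma^2); summing over axis=1 yields the KL
+# of the full diagonal Gaussian.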
+
+
+def construct_latent_tower(images, hparams):
+    """Builds convolutional latent tower for stochastic model.
+
+    At training time this tower generates a latent distribution (mean and std)
+    conditioned on the entire video. This latent variable will be fed to the
+    main tower as an extra variable to be used for future frames prediction.
+    At inference time, the tower is disabled and only returns latents sampled
+    from N(0,1).
+    If the multi_latent flag is on, a different latent is generated for every
+    timestep.
+
+    Args:
+        images: tensor of ground truth image sequences
+        hparams: model hyperparameters (e.g. latent_channels, latent_std_min)
+    Returns:
+        latent_mean: predicted latent mean
+        latent_std: predicted latent standard deviation
+        latent_loss: loss of the latent tower
+        samples: random samples drawn from a standard Gaussian
+    """
+
+    with slim.arg_scope([slim.conv2d], reuse=False):
+        stacked_images = tf.concat(images, 3)
+
+        latent_enc1 = slim.conv2d(
+            stacked_images,
+            32, [3, 3],
+            stride=2,
+            scope='latent_conv1',
+            normalizer_fn=tf_layers.layer_norm,
+            normalizer_params={'scope': 'latent_norm1'})
+
+        latent_enc2 = slim.conv2d(
+            latent_enc1,
+            64, [3, 3],
+            stride=2,
+            scope='latent_conv2',
+            normalizer_fn=tf_layers.layer_norm,
+            normalizer_params={'scope': 'latent_norm2'})
+
+        latent_enc3 = slim.conv2d(
+            latent_enc2,
+            64, [3, 3],
+            stride=1,
+            scope='latent_conv3',
+            normalizer_fn=tf_layers.layer_norm,
+            normalizer_params={'scope': 'latent_norm3'})
+
+        latent_mean = slim.conv2d(
+            latent_enc3,
+            hparams.latent_channels, [3, 3],
+            stride=2,
+            activation_fn=None,
+            scope='latent_mean',
+            normalizer_fn=tf_layers.layer_norm,
+            normalizer_params={'scope': 'latent_norm_mean'})
+
+        latent_std = slim.conv2d(
+            latent_enc3,
+            hparams.latent_channels, [3, 3],
+            stride=2,
+            scope='latent_std',
+            normalizer_fn=tf_layers.layer_norm,
+            normalizer_params={'scope': 'latent_std_norm'})
+
+        latent_std += hparams.latent_std_min
+
+    return latent_mean, latent_std
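+# Note (illustrative): downstream, the main tower treats `latent_std` as a log
+# variance and reparameterizes the N(0, 1) samples as
+#     z = latent_mean + exp(latent_std / 2) * eps,
+# so despite its name it is not a standard deviation on a linear scale.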
+
+
+def encoder_fn(inputs, hparams):
+    images = tf.unstack(inputs['images'], axis=0)
+    latent_mean, latent_std = construct_latent_tower(images, hparams)
+    outputs = {'zs_mu_enc': latent_mean, 'zs_log_sigma_sq_enc': latent_std}
+    return outputs
+
+
+def construct_model(images,
+                    actions=None,
+                    states=None,
+                    outputs_enc=None,
+                    iter_num=-1.0,
+                    k=-1,
+                    use_state=True,
+                    num_masks=10,
+                    stp=False,
+                    cdna=True,
+                    dna=False,
+                    context_frames=2,
+                    hparams=None):
+    """Build convolutional lstm video predictor using STP, CDNA, or DNA.
+
+    Args:
+        images: tensor of ground truth image sequences
+        actions: tensor of action sequences
+        states: tensor of ground truth state sequences
+        outputs_enc: optional encoder outputs (latent mean and log variance) from the latent tower
+        iter_num: tensor of the current training iteration (for sched. sampling)
+        k: constant used for scheduled sampling. -1 to feed in own prediction.
+        use_state: True to include state and action in prediction
+        num_masks: the number of different pixel motion predictions (and
+                   the number of masks for each of those predictions)
+        stp: True to use Spatial Transformer Predictor (STP)
+        cdna: True to use Convolutional Dynamic Neural Advection (CDNA)
+        dna: True to use Dynamic Neural Advection (DNA)
+        context_frames: number of ground truth frames to pass in before
+                        feeding in own predictions
+    Returns:
+        gen_images: predicted future image frames
+        gen_states: predicted future states
+
+    Raises:
+        ValueError: if more than one network option specified or more than 1 mask
+        specified for DNA model.
+    """
+    # Each image is being used twice, in latent tower and main tower.
+    # This is to make sure we are using the *same* image for both, ...
+    # ... given how TF queues work.
+    images = [tf.identity(image) for image in images]
+
+    if stp + cdna + dna != 1:
+        raise ValueError('More than one, or no network option specified.')
+    batch_size, img_height, img_width, color_channels = images[0].shape.as_list()
+    lstm_func = basic_conv_lstm_cell
+
+    # Generated robot states and images.
+    gen_states, gen_images = [], []
+    current_state = states[0]
+
+    if k == -1:
+        feedself = True
+    else:
+        # Scheduled sampling:
+        # Calculate number of ground-truth frames to pass in.
+        num_ground_truth = tf.to_int32(
+            tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k)))))
+        feedself = False
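+        # Worked example (illustrative): with k = 900 the expected number of
+        # ground-truth frames per batch follows the inverse-sigmoid schedule
+        #     num_ground_truth ~ batch_size * k / (k + exp(iter_num / k)),
+        # which starts near batch_size at iteration 0 and decays towards 0 as
+        # training progresses.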
+
+    # LSTM state sizes and states.
+    lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32]))
+    lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None
+    lstm_state5, lstm_state6, lstm_state7 = None, None, None
+
+    # Latent tower
+    if hparams.stochastic_model:
+        latent_shape = [batch_size, img_height // 8, img_width // 8, hparams.latent_channels]
+        if outputs_enc is None:  # equivalent to inference_time
+            latent_mean, latent_std = None, None
+        else:
+            latent_mean, latent_std = outputs_enc['zs_mu_enc'], outputs_enc['zs_log_sigma_sq_enc']
+            assert latent_mean.shape.as_list() == latent_shape
+
+        if hparams.multi_latent:
+            # timestep x batch_size x latent_size
+            samples = tf.random_normal(
+                [hparams.sequence_length - 1] + latent_shape, 0, 1,
+                dtype=tf.float32)
+        else:
+            # batch_size x latent_size
+            samples = tf.random_normal(latent_shape, 0, 1, dtype=tf.float32)
+
+    # Main tower
+    for t in range(hparams.sequence_length - 1):
+        action = actions[t]
+        # Reuse variables after the first timestep.
+        reuse = bool(gen_images)
+
+        done_warm_start = len(gen_images) > context_frames - 1
+        with slim.arg_scope(
+                [lstm_func, slim.layers.conv2d, slim.layers.fully_connected,
+                 tf_layers.layer_norm, slim.layers.conv2d_transpose],
+                reuse=reuse):
+
+            if feedself and done_warm_start:
+                # Feed in generated image.
+                prev_image = gen_images[-1]
+            elif done_warm_start:
+                # Scheduled sampling
+                prev_image = scheduled_sample(images[t], gen_images[-1], batch_size,
+                                              num_ground_truth)
+            else:
+                # Always feed in ground_truth
+                prev_image = images[t]
+
+            # Predicted state is always fed back in
+            state_action = tf.concat(axis=1, values=[action, current_state])
+
+            enc0 = slim.layers.conv2d(
+                prev_image,
+                32, [5, 5],
+                stride=2,
+                scope='scale1_conv1',
+                normalizer_fn=tf_layers.layer_norm,
+                normalizer_params={'scope': 'layer_norm1'})
+
+            hidden1, lstm_state1 = lstm_func(
+                enc0, lstm_state1, lstm_size[0], scope='state1')
+            hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2')
+            hidden2, lstm_state2 = lstm_func(
+                hidden1, lstm_state2, lstm_size[1], scope='state2')
+            hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3')
+            enc1 = slim.layers.conv2d(
+                hidden2, hidden2.get_shape()[3], [3, 3], stride=2, scope='conv2')
+
+            hidden3, lstm_state3 = lstm_func(
+                enc1, lstm_state3, lstm_size[2], scope='state3')
+            hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4')
+            hidden4, lstm_state4 = lstm_func(
+                hidden3, lstm_state4, lstm_size[3], scope='state4')
+            hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5')
+            enc2 = slim.layers.conv2d(
+                hidden4, hidden4.get_shape()[3], [3, 3], stride=2, scope='conv3')
+
+            # Pass in state and action.
+            smear = tf.reshape(
+                state_action,
+                [int(batch_size), 1, 1, int(state_action.get_shape()[1])])
+            smear = tf.tile(
+                smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1])
+            if use_state:
+                enc2 = tf.concat(axis=3, values=[enc2, smear])
+            # Setup latent
+            if hparams.stochastic_model:
+                latent = samples
+                if hparams.multi_latent:
+                    latent = samples[t]
+                if outputs_enc is not None:  # equivalent to not inference_time
+                    latent = tf.cond(iter_num < hparams.num_iterations_1st_stage,
+                                     lambda: tf.identity(latent),
+                                     lambda: latent_mean + tf.exp(latent_std / 2.0) * latent)
+                with tf.control_dependencies([latent]):
+                    enc2 = tf.concat([enc2, latent], 3)
+
+            enc3 = slim.layers.conv2d(
+                enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4')
+
+            hidden5, lstm_state5 = lstm_func(
+                enc3, lstm_state5, lstm_size[4], scope='state5')  # last 8x8
+            hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6')
+            enc4 = slim.layers.conv2d_transpose(
+                hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1')
+
+            hidden6, lstm_state6 = lstm_func(
+                enc4, lstm_state6, lstm_size[5], scope='state6')  # 16x16
+            hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7')
+            # Skip connection.
+            hidden6 = tf.concat(axis=3, values=[hidden6, enc1])  # both 16x16
+
+            enc5 = slim.layers.conv2d_transpose(
+                hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2')
+            hidden7, lstm_state7 = lstm_func(
+                enc5, lstm_state7, lstm_size[6], scope='state7')  # 32x32
+            hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8')
+
+            # Skip connection.
+            hidden7 = tf.concat(axis=3, values=[hidden7, enc0])  # both 32x32
+
+            enc6 = slim.layers.conv2d_transpose(
+                hidden7,
+                hidden7.get_shape()[3], 3, stride=2, scope='convt3', activation_fn=None,
+                normalizer_fn=tf_layers.layer_norm,
+                normalizer_params={'scope': 'layer_norm9'})
+
+            if dna:
+                # Using largest hidden state for predicting untied conv kernels.
+                enc7 = slim.layers.conv2d_transpose(
+                    enc6, DNA_KERN_SIZE ** 2, 1, stride=1, scope='convt4', activation_fn=None)
+            else:
+                # Using largest hidden state for predicting a new image layer.
+                enc7 = slim.layers.conv2d_transpose(
+                    enc6, color_channels, 1, stride=1, scope='convt4', activation_fn=None)
+                # This allows the network to also generate one image from scratch,
+                # which is useful when regions of the image become unoccluded.
+                transformed = [tf.nn.sigmoid(enc7)]
+
+            if stp:
+                stp_input0 = tf.reshape(hidden5, [int(batch_size), -1])
+                stp_input1 = slim.layers.fully_connected(
+                    stp_input0, 100, scope='fc_stp')
+                transformed += stp_transformation(prev_image, stp_input1, num_masks)
+            elif cdna:
+                cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
+                transformed += cdna_transformation(prev_image, cdna_input, num_masks,
+                                                   int(color_channels))
+            elif dna:
+                # Only one mask is supported (more should be unnecessary).
+                if num_masks != 1:
+                    raise ValueError('Only one mask is supported for DNA model.')
+                transformed = [dna_transformation(prev_image, enc7)]
+
+            masks = slim.layers.conv2d_transpose(
+                enc6, num_masks + 1, 1, stride=1, scope='convt7', activation_fn=None)
+            masks = tf.reshape(
+                tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
+                [int(batch_size), int(img_height), int(img_width), num_masks + 1])
+            mask_list = tf.split(axis=3, num_or_size_splits=num_masks + 1, value=masks)
+            output = mask_list[0] * prev_image
+            for layer, mask in zip(transformed, mask_list[1:]):
+                output += layer * mask
+            gen_images.append(output)
+
+            current_state = slim.layers.fully_connected(
+                state_action,
+                int(current_state.get_shape()[1]),
+                scope='state_pred',
+                activation_fn=None)
+            gen_states.append(current_state)
+
+    return gen_images, gen_states
+
+
+## Utility functions
+def stp_transformation(prev_image, stp_input, num_masks):
+    """Apply spatial transformer predictor (STP) to previous image.
+
+    Args:
+        prev_image: previous image to be transformed.
+        stp_input: hidden layer to be used for computing STN parameters.
+        num_masks: number of masks and hence the number of STP transformations.
+    Returns:
+        List of images transformed by the predicted STP parameters.
+    """
+    # Only import spatial transformer if needed.
+    from spatial_transformer import transformer
+
+    identity_params = tf.convert_to_tensor(
+        np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32))
+    transformed = []
+    for i in range(num_masks - 1):
+        params = slim.layers.fully_connected(
+            stp_input, 6, scope='stp_params' + str(i),
+            activation_fn=None) + identity_params
+        transformed.append(transformer(prev_image, params))
+
+    return transformed
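+# Note (illustrative): each predicted `params` vector is a 2x3 affine matrix
+# [[a, b, tx], [c, d, ty]] flattened row-wise; adding `identity_params`
+# ([1, 0, 0, 0, 1, 0]) means the network only predicts a residual deformation
+# on top of the identity transform.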
+
+
+def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
+    """Apply convolutional dynamic neural advection to previous image.
+
+    Args:
+        prev_image: previous image to be transformed.
+        cdna_input: hidden layer to be used for computing CDNA kernels.
+        num_masks: the number of masks and hence the number of CDNA transformations.
+        color_channels: the number of color channels in the images.
+    Returns:
+        List of images transformed by the predicted CDNA kernels.
+    """
+    batch_size = int(cdna_input.get_shape()[0])
+    height = int(prev_image.get_shape()[1])
+    width = int(prev_image.get_shape()[2])
+
+    # Predict kernels using linear function of last hidden layer.
+    cdna_kerns = slim.layers.fully_connected(
+        cdna_input,
+        DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks,
+        scope='cdna_params',
+        activation_fn=None)
+
+    # Reshape and normalize.
+    cdna_kerns = tf.reshape(
+        cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks])
+    cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
+    norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True)
+    cdna_kerns /= norm_factor
+
+    # Treat the color channel dimension as the batch dimension since the same
+    # transformation is applied to each color channel.
+    # Treat the batch dimension as the channel dimension so that
+    # depthwise_conv2d can apply a different transformation to each sample.
+    cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3])
+    cdna_kerns = tf.reshape(cdna_kerns, [DNA_KERN_SIZE, DNA_KERN_SIZE, batch_size, num_masks])
+    # Swap the batch and channel dimensions.
+    prev_image = tf.transpose(prev_image, [3, 1, 2, 0])
+
+    # Transform image.
+    transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME')
+
+    # Transpose the dimensions to where they belong.
+    transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks])
+    transformed = tf.transpose(transformed, [3, 1, 2, 0, 4])
+    transformed = tf.unstack(transformed, axis=-1)
+    return transformed
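+# Shape bookkeeping (illustrative): for batch size B, HxW images with C color
+# channels and M masks, the depthwise trick above convolves a [C, H, W, B]
+# "image" with a [K, K, B, M] kernel, so each sample in the batch gets its own
+# M kernels; the result is unstacked into a list of M tensors of shape
+# [B, H, W, C].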
+
+
+def dna_transformation(prev_image, dna_input):
+    """Apply dynamic neural advection to previous image.
+
+    Args:
+        prev_image: previous image to be transformed.
+        dna_input: hidden layer to be used for computing DNA transformation.
+    Returns:
+        Image transformed by the predicted DNA kernels.
+    """
+    # Construct translated images.
+    prev_image_pad = tf.pad(prev_image, [[0, 0], [2, 2], [2, 2], [0, 0]])
+    image_height = int(prev_image.get_shape()[1])
+    image_width = int(prev_image.get_shape()[2])
+
+    inputs = []
+    for xkern in range(DNA_KERN_SIZE):
+        for ykern in range(DNA_KERN_SIZE):
+            inputs.append(
+                tf.expand_dims(
+                    tf.slice(prev_image_pad, [0, xkern, ykern, 0],
+                             [-1, image_height, image_width, -1]), [3]))
+    inputs = tf.concat(axis=3, values=inputs)
+
+    # Normalize channels to 1.
+    kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
+    kernel = tf.expand_dims(
+        kernel / tf.reduce_sum(
+            kernel, [3], keepdims=True), [4])
+    return tf.reduce_sum(kernel * inputs, [3], keepdims=False)
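+# In words (illustrative): `inputs` stacks the DNA_KERN_SIZE**2 shifted copies
+# of the padded image along a new axis, `kernel` holds a non-negative,
+# normalized per-pixel weighting over those shifts, and the weighted sum
+# advects each output pixel from a 5x5 neighbourhood of the previous frame.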
+
+
+def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
+    """Sample batch with specified mix of ground truth and generated data points.
+
+    Args:
+        ground_truth_x: tensor of ground-truth data points.
+        generated_x: tensor of generated data points.
+        batch_size: batch size
+        num_ground_truth: number of ground-truth examples to include in batch.
+    Returns:
+        New batch with num_ground_truth sampled from ground_truth_x and the rest
+        from generated_x.
+    """
+    idx = tf.random_shuffle(tf.range(int(batch_size)))
+    ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
+    generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size)))
+
+    ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
+    generated_examps = tf.gather(generated_x, generated_idx)
+    return tf.dynamic_stitch([ground_truth_idx, generated_idx],
+                             [ground_truth_examps, generated_examps])
+
+
+def generator_fn(inputs, mode, hparams):
+    images = tf.unstack(inputs['images'], axis=0)
+    batch_size = images[0].shape[0].value
+    action_dim, state_dim = 4, 3
+
+    # if not use_state, use zero actions and states to match reference implementation.
+    actions = inputs.get('actions', tf.zeros([hparams.sequence_length - 1, batch_size, action_dim]))
+    actions = tf.unstack(actions, axis=0)
+    states = inputs.get('states', tf.zeros([hparams.sequence_length, batch_size, state_dim]))
+    states = tf.unstack(states, axis=0)
+    iter_num = tf.to_float(tf.train.get_or_create_global_step())
+
+    schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1
+    gen_images, gen_states = \
+        construct_model(images,
+                        actions,
+                        states,
+                        outputs_enc=None,
+                        iter_num=iter_num,
+                        k=schedule_sampling_k,
+                        use_state='actions' in inputs,
+                        num_masks=hparams.num_masks,
+                        cdna=hparams.transformation == 'cdna',
+                        dna=hparams.transformation == 'dna',
+                        stp=hparams.transformation == 'stp',
+                        context_frames=hparams.context_frames,
+                        hparams=hparams)
+    outputs = {
+        'gen_images': tf.stack(gen_images, axis=0),
+        'gen_states': tf.stack(gen_states, axis=0),
+    }
+
+    if mode == 'train':
+        outputs_enc = encoder_fn(inputs, hparams)
+        tf.get_variable_scope().reuse_variables()
+        gen_images_enc, gen_states_enc = \
+            construct_model(images,
+                            actions,
+                            states,
+                            outputs_enc=outputs_enc,
+                            iter_num=iter_num,
+                            k=schedule_sampling_k,
+                            use_state='actions' in inputs,
+                            num_masks=hparams.num_masks,
+                            cdna=hparams.transformation == 'cdna',
+                            dna=hparams.transformation == 'dna',
+                            stp=hparams.transformation == 'stp',
+                            context_frames=hparams.context_frames,
+                            hparams=hparams)
+        outputs.update({
+            'gen_images_enc': tf.stack(gen_images_enc, axis=0),
+            'gen_states_enc': tf.stack(gen_states_enc, axis=0),
+            'zs_mu_enc': outputs_enc['zs_mu_enc'],
+            'zs_log_sigma_sq_enc': outputs_enc['zs_log_sigma_sq_enc'],
+        })
+    return outputs
+
+
+class SV2PVideoPredictionModel(VideoPredictionModel):
+    """
+    Stochastic Variational Video Prediction
+    https://arxiv.org/abs/1710.11252
+
+    Reference implementation:
+    https://github.com/mbz/models/tree/master/research/video_prediction
+    """
+    def __init__(self, *args, **kwargs):
+        super(SV2PVideoPredictionModel, self).__init__(
+            generator_fn, *args, **kwargs)
+        self.deterministic = not self.hparams.stochastic_model
+
+    def get_default_hparams_dict(self):
+        default_hparams = super(SV2PVideoPredictionModel, self).get_default_hparams_dict()
+        hparams = dict(
+            batch_size=32,
+            l1_weight=0.0,
+            l2_weight=1.0,
+            kl_weight=1e-3 * 10 * 8,  # equivalent to latent_loss_multiplier up to a factor (see below)
+            transformation='cdna',
+            num_masks=10,
+            schedule_sampling_k=900.0,
+            stochastic_model=True,
+            multi_latent=False,
+            latent_std_min=-5.0,
+            latent_channels=1,
+            num_iterations_1st_stage=50000,
+            kl_anneal_steps=(100000, 120000),
+            max_steps=200000,
+            decay_steps=(0, 0),  # do not decay the learning rate (doing so produces blurrier images)
+        )
+        # Notes on equivalence with reference implementation:
+        # kl_weight is equivalent to latent_loss_multiplier * time_factor * factor, where
+        # time_factor = (sequence_length - context_frames) since the reference implementation
+        # doesn't normalize the kl divergence over time, and factor = (width // 8) / latent_channels
+        # since the reference implementation's kl_divergence sums over axis=1 instead of axis=-1.
+        # The paper and the reference implementation differ in the annealing of the kl_weight.
+        # Based on Figure 4 and the Appendix, it seems that in the 3rd stage, the kl_weight is
+        # linearly increased for the first 20k iterations of this stage.
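+        # Worked example (illustrative, not from the original source): with
+        # latent_loss_multiplier = 1e-3, a time_factor of 10 (e.g. a sequence of
+        # 12 frames with 2 context frames) and factor = 8 (e.g. 64-pixel-wide
+        # frames with latent_channels = 1, so (64 // 8) / 1 = 8), the kl_weight
+        # above evaluates to 1e-3 * 10 * 8 = 0.08.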
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def parse_hparams(self, hparams_dict, hparams):
+        # backwards compatibility
+        deprecated_hparams_keys = [
+            'num_gpus',
+            'acvideo_gan_weight',
+            'acvideo_vae_gan_weight',
+            'image_gan_weight',
+            'image_vae_gan_weight',
+            'tuple_gan_weight',
+            'tuple_vae_gan_weight',
+            'gan_weight',
+            'vae_gan_weight',
+            'video_gan_weight',
+            'video_vae_gan_weight',
+        ]
+        for deprecated_hparams_key in deprecated_hparams_keys:
+            hparams_dict.pop(deprecated_hparams_key, None)
+        return super(SV2PVideoPredictionModel, self).parse_hparams(hparams_dict, hparams)
diff --git a/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction/models/vanilla_convLSTM_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7753004348ae0ae60057a469de1e2d1421c3869
--- /dev/null
+++ b/video_prediction/models/vanilla_convLSTM_model.py
@@ -0,0 +1,162 @@
+import collections
+import functools
+import itertools
+from collections import OrderedDict
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.util import nest
+from video_prediction import ops, flow_ops
+from video_prediction.models import BaseVideoPredictionModel
+from video_prediction.models import networks
+from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from video_prediction.utils import tf_utils
+from datetime import datetime
+from pathlib import Path
+from video_prediction.layers import layer_def as ld
+from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
+
+class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
+    def __init__(self, mode='train', aggregate_nccl=None, hparams_dict=None,
+                 hparams=None, **kwargs):
+        super(VanillaConvLstmVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
+        print("Hparams_dict", self.hparams)
+        self.mode = mode
+        self.learning_rate = self.hparams.lr
+        self.gen_images_enc = None
+        self.recon_loss = None
+        self.latent_loss = None
+        self.total_loss = None
+        self.context_frames = 10
+        self.sequence_length = 20
+        self.predict_frames = self.sequence_length - self.context_frames
+        self.aggregate_nccl = aggregate_nccl
+    
+    def get_default_hparams_dict(self):
+        """
+        The keys of this dict define valid hyperparameters for instances of
+        this class. A class inheriting from this one should override this
+        method if it has a different set of hyperparameters.
+
+        Returns:
+            A dict with the following hyperparameters.
+
+            batch_size: batch size for training.
+            lr: learning rate. If decay_steps is non-zero, this is the
+                learning rate for steps <= decay_step.
+            end_lr: learning rate for steps >= end_decay_step if decay_steps
+                is non-zero, ignored otherwise.
+            decay_steps: (decay_step, end_decay_step) tuple.
+            max_steps: number of training steps.
+            context_frames: the number of ground-truth frames to pass in at
+                start. Must be specified during instantiation.
+            sequence_length: the number of frames in the video sequence,
+                including the context frames, so this model predicts
+                `sequence_length - context_frames` future frames. Must be
+                specified during instantiation.
+        """
+        default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict()
+        print ("default hparams",default_hparams)
+        hparams = dict(
+            batch_size=16,
+            lr=0.001,
+            end_lr=0.0,
+            nz=16,
+            decay_steps=(200000, 300000),
+            max_steps=350000,
+        )
+
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def build_graph(self, x):
+        self.x = x["images"]
+       
+        self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
+        original_global_variables = tf.global_variables()
+        # ARCHITECTURE
+        self.x_hat_context_frames, self.x_hat_predict_frames = self.convLSTM_network()
+        self.x_hat = tf.concat([self.x_hat_context_frames, self.x_hat_predict_frames], 1)
+
+
+        self.context_frames_loss = tf.reduce_mean(
+            tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
+        self.predict_frames_loss = tf.reduce_mean(
+            tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_predict_frames[:, :, :, :, 0]))
+        self.total_loss = self.context_frames_loss + self.predict_frames_loss
+
+        self.train_op = tf.train.AdamOptimizer(
+            learning_rate=self.learning_rate).minimize(self.total_loss, global_step=self.global_step)
+        self.outputs = {}
+        self.outputs["gen_images"] = self.x_hat
+        # Summary op
+        self.loss_summary = tf.summary.scalar("recon_loss", self.context_frames_loss)
+        self.loss_summary = tf.summary.scalar("latent_loss", self.predict_frames_loss)
+        self.loss_summary = tf.summary.scalar("total_loss", self.total_loss)
+        self.summary_op = tf.summary.merge_all()
+        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+        self.saveable_variables = [self.global_step] + global_variables
+        return
+
+
+    @staticmethod
+    def convLSTM_cell(inputs, hidden, nz=16):
+        # Encoder: strided and regular convolutions.
+        conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1", activate="leaky_relu")
+        conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2", activate="leaky_relu")
+        conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3", activate="leaky_relu")
+
+        y_0 = conv3
+        # Convolutional LSTM cell operating on the encoded feature map.
+        cell_shape = y_0.get_shape().as_list()
+        with tf.variable_scope('conv_lstm', initializer=tf.random_uniform_initializer(-.01, 0.1)):
+            cell = BasicConvLSTMCell(shape=[cell_shape[1], cell_shape[2]], filter_size=[3, 3], num_features=8)
+            if hidden is None:
+                hidden = cell.zero_state(y_0, tf.float32)
+            output, hidden = cell(y_0, hidden)
+
+        output_shape = output.get_shape().as_list()
+        z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])
+
+        # Decoder: transposed convolutions back to image resolution.
+        conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, "decode_5", activate="leaky_relu")
+        conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, "decode_6", activate="leaky_relu")
+        # Sigmoid output activation keeps the predicted frame in [0, 1].
+        x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, "decode_7", activate="sigmoid")
+
+        return x_hat, hidden
+
+    def convLSTM_network(self):
+        network_template = tf.make_template('network',
+                                            VanillaConvLstmVideoPredictionModel.convLSTM_cell)  # make the template to share the variables
+        # create network
+        x_hat_context = []
+        x_hat_predict = []
+        seq_start = 1
+        hidden = None
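+        # Rollout sketch (illustrative): with seq_start = 1, only the very first
+        # frame is taken from ground truth; every subsequent step (both the
+        # remaining "context" steps and the "predict" steps) is fed the model's
+        # previous prediction.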
+        for i in range(self.context_frames):
+            if i < seq_start:
+                x_1, hidden = network_template(self.x[:, i, :, :, :], hidden)
+            else:
+                x_1, hidden = network_template(x_1, hidden)
+            x_hat_context.append(x_1)
+
+        for i in range(self.predict_frames):
+            x_1, hidden = network_template(x_1, hidden)
+            x_hat_predict.append(x_1)
+
+        # pack them all together
+        x_hat_context = tf.stack(x_hat_context)
+        x_hat_predict = tf.stack(x_hat_predict)
+        self.x_hat_context = tf.transpose(x_hat_context, [1, 0, 2, 3, 4])  # change first dim with sec dim
+        self.x_hat_predict = tf.transpose(x_hat_predict, [1, 0, 2, 3, 4])  # change first dim with sec dim
+        return self.x_hat_context, self.x_hat_predict
diff --git a/video_prediction/models/vanilla_vae_model.py b/video_prediction/models/vanilla_vae_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..eec5598305044226280080d630313487c7d847a4
--- /dev/null
+++ b/video_prediction/models/vanilla_vae_model.py
@@ -0,0 +1,191 @@
+import collections
+import functools
+import itertools
+from collections import OrderedDict
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.util import nest
+from video_prediction import ops, flow_ops
+from video_prediction.models import BaseVideoPredictionModel
+from video_prediction.models import networks
+from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from video_prediction.utils import tf_utils
+from datetime import datetime
+from pathlib import Path
+from video_prediction.layers import layer_def as ld
+
+class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
+    def __init__(self, mode='train', aggregate_nccl=None, hparams_dict=None,
+                 hparams=None, **kwargs):
+        super(VanillaVAEVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
+        self.mode = mode
+        self.learning_rate = self.hparams.lr
+        self.nz = self.hparams.nz
+        self.aggregate_nccl = aggregate_nccl
+        self.gen_images_enc = None
+        self.train_op = None
+        self.summary_op = None
+        self.recon_loss = None
+        self.latent_loss = None
+        self.total_loss = None
+
+    def get_default_hparams_dict(self):
+        """
+        The keys of this dict define valid hyperparameters for instances of
+        this class. A class inheriting from this one should override this
+        method if it has a different set of hyperparameters.
+
+        Returns:
+            A dict with the following hyperparameters.
+
+            batch_size: batch size for training.
+            lr: learning rate. if decay steps is non-zero, this is the
+                learning rate for steps <= decay_step.
+            end_lr: learning rate for steps >= end_decay_step if decay_steps
+                is non-zero, ignored otherwise.
+            decay_steps: (decay_step, end_decay_step) tuple.
+            max_steps: number of training steps.
+            beta1: momentum term of Adam.
+            beta2: momentum term of Adam.
+            context_frames: the number of ground-truth frames to pass in at
+                start. Must be specified during instantiation.
+            sequence_length: the number of frames in the video sequence,
+                including the context frames, so this model predicts
+                `sequence_length - context_frames` future frames. Must be
+                specified during instantiation.
+        """
+        default_hparams = super(VanillaVAEVideoPredictionModel, self).get_default_hparams_dict()
+        print ("default hparams",default_hparams)
+        hparams = dict(
+            batch_size=16,
+            lr=0.001,
+            end_lr=0.0,
+            decay_steps=(200000, 300000),
+            lr_boundaries=(0,),
+            max_steps=350000,
+            nz=10,
+            context_frames=-1,
+            sequence_length=-1,
+            clip_length=10,  # Bing: TODO clarify what clip_length is; the original implementation uses 10
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def build_graph(self, x):
+        tf.set_random_seed(12345)
+        self.x = x["images"]
+
+        self.global_step = tf.Variable(0, name='global_step', trainable=False)
+        original_global_variables = tf.global_variables()
+        self.increment_global_step = tf.assign_add(self.global_step, 1, name='increment_global_step')
+
+        self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all()
+
+        # Reconstruction loss: compare frame t+1 against the decoding of frame t.
+        self.recon_loss = tf.reduce_mean(tf.square(self.x[:, 1:, :, :, 0] - self.x_hat[:, :-1, :, :, 0]))
+
+        latent_loss = -0.5 * tf.reduce_sum(
+            1 + self.z_log_sigma_sq - tf.square(self.z_mu) -
+            tf.exp(self.z_log_sigma_sq), axis = 1)
+        self.latent_loss = tf.reduce_mean(latent_loss)
+        self.total_loss = self.recon_loss + self.latent_loss
+        self.train_op = tf.train.AdamOptimizer(
+            learning_rate=self.learning_rate).minimize(self.total_loss, global_step=self.global_step)
+        # Build a saver
+
+        self.losses = {
+            'recon_loss': self.recon_loss,
+            'latent_loss': self.latent_loss,
+            'total_loss': self.total_loss,
+        }
+
+        # Summary op
+        self.loss_summary = tf.summary.scalar("recon_loss", self.recon_loss)
+        self.loss_summary = tf.summary.scalar("latent_loss", self.latent_loss)
+        self.loss_summary = tf.summary.scalar("total_loss", self.latent_loss)
+        self.summary_op = tf.summary.merge_all()
+
+
+
+        self.outputs = {}
+        self.outputs["gen_images"] = self.x_hat
+        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+        self.saveable_variables = [self.global_step] + global_variables
+
+        return
+
+
+    @staticmethod
+    def vae_arc3(x, l_name=0, nz=16):
+        seq_name = "sq_" + str(l_name) + "_"
+
+        # Encoder.
+        conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
+        conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2")
+        conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3")
+
+        conv4 = tf.layers.Flatten()(conv3)
+        conv3_shape = conv3.get_shape().as_list()
+
+        # Latent parameters and reparameterization trick.
+        z_mu = ld.fc_layer(conv4, hiddens=nz, idx=seq_name + "enc_fc4_m")
+        z_log_sigma_sq = ld.fc_layer(conv4, hiddens=nz, idx=seq_name + "enc_fc4_sigma")
+        eps = tf.random_normal(shape=tf.shape(z_log_sigma_sq), mean=0, stddev=1, dtype=tf.float32)
+        z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps
+
+        # Decoder.
+        z2 = ld.fc_layer(z, hiddens=conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx=seq_name + "decode_fc1")
+        z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])
+        conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, seq_name + "decode_5")
+        conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, seq_name + "decode_6")
+        x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8")
+
+        return x_hat, z_mu, z_log_sigma_sq, z
+
+    def vae_arc_all(self):
+        X = []
+        z_log_sigma_sq_all = []
+        z_mu_all = []
+        for i in range(20):
+            q, z_mu, z_log_sigma_sq, z = VanillaVAEVideoPredictionModel.vae_arc3(self.x[:, i, :, :, :], l_name=i, nz=self.nz)
+            X.append(q)
+            z_log_sigma_sq_all.append(z_log_sigma_sq)
+            z_mu_all.append(z_mu)
+        x_hat = tf.stack(X, axis=1)
+        z_log_sigma_sq_all = tf.stack(z_log_sigma_sq_all, axis=1)
+        z_mu_all = tf.stack(z_mu_all, axis=1)
+        return x_hat, z_log_sigma_sq_all, z_mu_all
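+    # Note (illustrative): vae_arc_all applies an independent per-frame VAE to
+    # each of the 20 input frames and stacks the results along the time axis;
+    # build_graph then compares frame t+1 against the decoding of frame t, i.e.
+    # the VAE is used as a one-step-ahead predictor.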
diff --git a/video_prediction/ops.py b/video_prediction/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d2e9f2eac608d2a9ec61e24a354882b3acce2de
--- /dev/null
+++ b/video_prediction/ops.py
@@ -0,0 +1,1098 @@
+import numpy as np
+import tensorflow as tf
+
+
+def dense(inputs, units, use_spectral_norm=False, use_bias=True):
+    with tf.variable_scope('dense'):
+        input_shape = inputs.get_shape().as_list()
+        kernel_shape = [input_shape[1], units]
+        kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02))
+        if use_spectral_norm:
+            kernel = spectral_normed_weight(kernel)
+        outputs = tf.matmul(inputs, kernel)
+        if use_bias:
+            bias = tf.get_variable('bias', [units], dtype=tf.float32, initializer=tf.zeros_initializer())
+            outputs = tf.nn.bias_add(outputs, bias)
+        return outputs
+
+
+def pad1d(inputs, size, strides=(1,), padding='SAME', mode='CONSTANT'):
+    size = list(size) if isinstance(size, (tuple, list)) else [size]
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides]
+    input_shape = inputs.get_shape().as_list()
+    assert len(input_shape) == 3
+    in_width = input_shape[1]
+    if padding in ('SAME', 'FULL'):
+        if in_width % strides[0] == 0:
+            pad_along_width = max(size[0] - strides[0], 0)
+        else:
+            pad_along_width = max(size[0] - (in_width % strides[0]), 0)
+        if padding == 'SAME':
+            pad_left = pad_along_width // 2
+            pad_right = pad_along_width - pad_left
+        else:
+            pad_left = pad_along_width
+            pad_right = pad_along_width
+        padding_pattern = [[0, 0],
+                           [pad_left, pad_right],
+                           [0, 0]]
+        outputs = tf.pad(inputs, padding_pattern, mode=mode)
+    elif padding == 'VALID':
+        outputs = inputs
+    else:
+        raise ValueError("Invalid padding scheme %s" % padding)
+    return outputs
+
+
+def conv1d(inputs, filters, kernel_size, strides=(1,), padding='SAME', kernel=None, use_bias=True):
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size]
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides]
+    input_shape = inputs.get_shape().as_list()
+    kernel_shape = list(kernel_size) + [input_shape[-1], filters]
+    if kernel is None:
+        with tf.variable_scope('conv1d'):
+            kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32,
+                                     initializer=tf.truncated_normal_initializer(stddev=0.02))
+    else:
+        if kernel_shape != kernel.get_shape().as_list():
+            raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % (tuple(kernel_shape), tuple(kernel.get_shape().as_list())))
+    if padding == 'FULL':
+        inputs = pad1d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT')
+        padding = 'VALID'
+    stride, = strides
+    outputs = tf.nn.conv1d(inputs, kernel, stride, padding=padding)
+    if use_bias:
+        with tf.variable_scope('conv1d'):
+            bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer())
+            outputs = tf.nn.bias_add(outputs, bias)
+    return outputs
+
+
+def pad2d_paddings(inputs, size, strides=(1, 1), rate=(1, 1), padding='SAME'):
+    """
+    Computes the paddings for a 4-D tensor according to the convolution padding algorithm.
+
+    See pad2d.
+
+    Reference:
+        https://www.tensorflow.org/api_guides/python/nn#convolution
+        https://www.tensorflow.org/api_docs/python/tf/nn/with_space_to_batch
+    """
+    size = np.array(size) if isinstance(size, (tuple, list)) else np.array([size] * 2)
+    strides = np.array(strides) if isinstance(strides, (tuple, list)) else np.array([strides] * 2)
+    rate = np.array(rate) if isinstance(rate, (tuple, list)) else np.array([rate] * 2)
+    if np.any(strides > 1) and np.any(rate > 1):
+        raise ValueError("strides > 1 not supported in conjunction with rate > 1")
+    input_shape = inputs.get_shape().as_list()
+    assert len(input_shape) == 4
+    input_size = np.array(input_shape[1:3])
+    if padding in ('SAME', 'FULL'):
+        if np.any(rate > 1):
+            # We have two padding contributions. The first is used for converting "SAME"
+            # to "VALID". The second is required so that the height and width of the
+            # zero-padded value tensor are multiples of rate.
+
+            # Spatial dimensions of the filters and the upsampled filters in which we
+            # introduce (rate - 1) zeros between consecutive filter values.
+            dilated_size = size + (size - 1) * (rate - 1)
+            pad = dilated_size - 1
+        else:
+            pad = np.where(input_size % strides == 0,
+                           np.maximum(size - strides, 0),
+                           np.maximum(size - (input_size % strides), 0))
+        if padding == 'SAME':
+            # When full_padding_shape is odd, we pad more at end, following the same
+            # convention as conv2d.
+            pad_start = pad // 2
+            pad_end = pad - pad_start
+        else:
+            pad_start = pad
+            pad_end = pad
+        if np.any(rate > 1):
+            # More padding so that rate divides the height and width of the input.
+            # TODO: not sure if this is correct when padding == 'FULL'
+            orig_pad_end = pad_end
+            full_input_size = input_size + pad_start + orig_pad_end
+            pad_end_extra = (rate - full_input_size % rate) % rate
+            pad_end = orig_pad_end + pad_end_extra
+        paddings = [[0, 0],
+                    [pad_start[0], pad_end[0]],
+                    [pad_start[1], pad_end[1]],
+                    [0, 0]]
+    elif padding == 'VALID':
+        paddings = [[0, 0]] * 4
+    else:
+        raise ValueError("Invalid padding scheme %s" % padding)
+    return paddings
+
+
+def pad2d(inputs, size, strides=(1, 1), rate=(1, 1), padding='SAME', mode='CONSTANT'):
+    """
+    Pads a 4-D tensor according to the convolution padding algorithm.
+
+    Convolution with a padding scheme
+        conv2d(..., padding=padding)
+    is equivalent to zero-padding of the input with such scheme, followed by
+    convolution with 'VALID' padding
+        padded = pad2d(..., padding=padding, mode='CONSTANT')
+        conv2d(padded, ..., padding='VALID')
+
+    Args:
+        inputs: A 4-D tensor of shape
+            `[batch, in_height, in_width, in_channels]`.
+        padding: A string, either 'VALID', 'SAME', or 'FULL'. The padding algorithm.
+        mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive).
+
+    Returns:
+        A 4-D tensor.
+
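+    Example (a minimal doctest-style sketch of the equivalence above; the shapes are illustrative):
+        >>> import numpy as np
+        >>> import tensorflow as tf
+        >>> from video_prediction.ops import pad2d, conv2d
+        >>> inputs = tf.get_variable("inputs", [4, 8, 8, 16])
+        >>> kernel = tf.get_variable("kernel", (3, 3, 16, 32))
+        >>> outputs_same = conv2d(inputs, 32, kernel_size=(3, 3), kernel=kernel, use_bias=False, padding='SAME')
+        >>> padded = pad2d(inputs, (3, 3), padding='SAME', mode='CONSTANT')
+        >>> outputs_valid = conv2d(padded, 32, kernel_size=(3, 3), kernel=kernel, use_bias=False, padding='VALID')
+        >>> sess = tf.Session()
+        >>> sess.run(tf.global_variables_initializer())
+        >>> assert np.allclose(*sess.run([outputs_same, outputs_valid]), atol=1e-5)
+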
+    Reference:
+        https://www.tensorflow.org/api_guides/python/nn#convolution
+    """
+    paddings = pad2d_paddings(inputs, size, strides=strides, rate=rate, padding=padding)
+    if paddings == [[0, 0]] * 4:
+        outputs = inputs
+    else:
+        outputs = tf.pad(inputs, paddings, mode=mode)
+    return outputs
+
+
+def local2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME',
+            kernel=None, flip_filters=False,
+            use_bias=True, channelwise=False):
+    """
+    2-D locally connected operation.
+
+    Works similarly to 2-D convolution except that the weights are unshared, that is, a different set of filters is
+    applied at each different patch of the input.
+
+    Args:
+        inputs: A 4-D tensor of shape
+            `[batch, in_height, in_width, in_channels]`.
+        kernel: A 6-D or 7-D tensor of shape
+            `[in_height, in_width, kernel_size[0], kernel_size[1], in_channels, filters]` or
+            `[batch, in_height, in_width, kernel_size[0], kernel_size[1], in_channels, filters]`.
+
+    Returns:
+        A 4-D tensor.
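+
+    Example (a minimal shape-only sketch; sizes are illustrative):
+        >>> import tensorflow as tf
+        >>> from video_prediction.ops import local2d
+        >>> inputs = tf.zeros([4, 8, 8, 16])
+        >>> local2d(inputs, filters=32, kernel_size=(3, 3), use_bias=False).get_shape().as_list()
+        [4, 8, 8, 32]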
+    """
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    if strides != [1, 1]:
+        raise NotImplementedError
+    if padding == 'FULL':
+        inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT')
+        padding = 'VALID'
+    input_shape = inputs.get_shape().as_list()
+    if padding == 'SAME':
+        output_shape = input_shape[:3] + [filters]
+    elif padding == 'VALID':
+        output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters]
+    else:
+        raise ValueError("Invalid padding scheme %s" % padding)
+
+    if channelwise:
+        if filters not in (input_shape[-1], 1):
+            raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise "
+                             "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1]))
+        kernel_shape = output_shape[1:3] + kernel_size + [filters]
+    else:
+        kernel_shape = output_shape[1:3] + kernel_size + [input_shape[-1], filters]
+    if kernel is None:
+        with tf.variable_scope('local2d'):
+            kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32,
+                                     initializer=tf.truncated_normal_initializer(stddev=0.02))
+    else:
+        if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape):
+            raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s"
+                             % (tuple(kernel_shape), tuple([input_shape[0]] + kernel_shape), tuple(kernel.get_shape().as_list())))
+
+    outputs = []
+    for i in range(kernel_size[0]):
+        filter_h_ind = -i-1 if flip_filters else i
+        if padding == 'VALID':
+            ii = i
+        else:
+            ii = i - (kernel_size[0] // 2)
+        input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1]))
+        output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii)
+        assert 0 <= output_h_slice.start < output_shape[1]
+        assert 0 < output_h_slice.stop <= output_shape[1]
+
+        for j in range(kernel_size[1]):
+            filter_w_ind = -j-1 if flip_filters else j
+            if padding == 'VALID':
+                jj = j
+            else:
+                jj = j - (kernel_size[1] // 2)
+            input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2]))
+            output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj)
+            assert 0 <= output_w_slice.start < output_shape[2]
+            assert 0 < output_w_slice.stop <= output_shape[2]
+            if channelwise:
+                inc = inputs[:, input_h_slice, input_w_slice, :] * \
+                      kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :]
+            else:
+                inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] *
+                                    kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :, :], axis=-2)
+            # equivalent to this
+            # outputs[:, output_h_slice, output_w_slice, :] += inc
+            paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop],
+                        [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]]
+            outputs.append(tf.pad(inc, paddings))
+    outputs = tf.add_n(outputs)
+    if use_bias:
+        with tf.variable_scope('local2d'):
+            bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, initializer=tf.zeros_initializer())
+            # the bias is per-location and per-channel (rank 3), which tf.nn.bias_add
+            # does not accept, so add it with broadcasting over the batch dimension
+            outputs = outputs + bias
+    return outputs
+
+
+def separable_local2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME',
+                      vertical_kernel=None, horizontal_kernel=None, flip_filters=False,
+                      use_bias=True, channelwise=False):
+    """
+    2-D locally connected operation with separable filters.
+
+    Note that, unlike tf.nn.separable_conv2d, this is spatial separability between dimensions 1 and 2.
+
+    Args:
+        inputs: A 4-D tensor of shape
+            `[batch, in_height, in_width, in_channels]`.
+        vertical_kernel: A 5-D or 6-D tensor of shape
+            `[in_height, in_width, kernel_size[0], in_channels, filters]` or
+            `[batch, in_height, in_width, kernel_size[0], in_channels, filters]`.
+        horizontal_kernel: A 5-D or 6-D tensor of shape
+            `[in_height, in_width, kernel_size[1], in_channels, filters]` or
+            `[batch, in_height, in_width, kernel_size[1], in_channels, filters]`.
+
+    Returns:
+        A 4-D tensor.
+    """
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    if strides != [1, 1]:
+        raise NotImplementedError
+    if padding == 'FULL':
+        inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT')
+        padding = 'VALID'
+    input_shape = inputs.get_shape().as_list()
+    if padding == 'SAME':
+        output_shape = input_shape[:3] + [filters]
+    elif padding == 'VALID':
+        output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters]
+    else:
+        raise ValueError("Invalid padding scheme %s" % padding)
+
+    kernels = [vertical_kernel, horizontal_kernel]
+    for i, (kernel_type, kernel_length, kernel) in enumerate(zip(['vertical', 'horizontal'], kernel_size, kernels)):
+        if channelwise:
+            if filters not in (input_shape[-1], 1):
+                raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise "
+                                 "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1]))
+            kernel_shape = output_shape[1:3] + [kernel_length, filters]
+        else:
+            kernel_shape = output_shape[1:3] + [kernel_length, input_shape[-1], filters]
+        if kernel is None:
+            with tf.variable_scope('separable_local2d'):
+                kernel = tf.get_variable('%s_kernel' % kernel_type, kernel_shape, dtype=tf.float32,
+                                         initializer=tf.truncated_normal_initializer(stddev=0.02))
+                kernels[i] = kernel
+        else:
+            if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape):
+                raise ValueError("Expecting %s kernel with shape %s or %s but instead got kernel with shape %s"
+                                 % (kernel_type,
+                                    tuple(kernel_shape), tuple([input_shape[0]] +kernel_shape),
+                                    tuple(kernel.get_shape().as_list())))
+
+    outputs = []
+    for i in range(kernel_size[0]):
+        filter_h_ind = -i-1 if flip_filters else i
+        if padding == 'VALID':
+            ii = i
+        else:
+            ii = i - (kernel_size[0] // 2)
+        input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1]))
+        output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii)
+        assert 0 <= output_h_slice.start < output_shape[1]
+        assert 0 < output_h_slice.stop <= output_shape[1]
+
+        for j in range(kernel_size[1]):
+            filter_w_ind = -j-1 if flip_filters else j
+            if padding == 'VALID':
+                jj = j
+            else:
+                jj = j - (kernel_size[1] // 2)
+            input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2]))
+            output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj)
+            assert 0 <= output_w_slice.start < output_shape[2]
+            assert 0 < output_w_slice.stop <= output_shape[2]
+            if channelwise:
+                inc = inputs[:, input_h_slice, input_w_slice, :] * \
+                      kernels[0][..., output_h_slice, output_w_slice, filter_h_ind, :] * \
+                      kernels[1][..., output_h_slice, output_w_slice, filter_w_ind, :]
+            else:
+                inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] *
+                                    kernels[0][..., output_h_slice, output_w_slice, filter_h_ind, :, :] *
+                                    kernels[1][..., output_h_slice, output_w_slice, filter_w_ind, :, :],
+                                    axis=-2)
+            # equivalent to this
+            # outputs[:, output_h_slice, output_w_slice, :] += inc
+            paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop],
+                        [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]]
+            outputs.append(tf.pad(inc, paddings))
+    outputs = tf.add_n(outputs)
+    if use_bias:
+        with tf.variable_scope('separable_local2d'):
+            bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, initializer=tf.zeros_initializer())
+            # the bias is per-location and per-channel (rank 3), which tf.nn.bias_add
+            # does not accept, so add it with broadcasting over the batch dimension
+            outputs = outputs + bias
+    return outputs
+
+
+def kronecker_local2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME',
+                      kernels=None, flip_filters=False, use_bias=True, channelwise=False):
+    """
+    2-D locally connected operation with filters represented as a Kronecker product of smaller filters.
+
+    Args:
+        inputs: A 4-D tensor of shape
+            `[batch, in_height, in_width, in_channels]`.
+        kernels: A list of 6-D or 7-D tensors of shape
+            `[in_height, in_width, kernel_size[i][0], kernel_size[i][1], in_channels, filters]` or
+            `[batch, in_height, in_width, kernel_size[i][0], kernel_size[i][1], in_channels, filters]`.
+
+    Returns:
+        A 4-D tensor.
+    """
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    if strides != [1, 1]:
+        raise NotImplementedError
+    if padding == 'FULL':
+        inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT')
+        padding = 'VALID'
+    input_shape = inputs.get_shape().as_list()
+    if padding == 'SAME':
+        output_shape = input_shape[:3] + [filters]
+    elif padding == 'VALID':
+        output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters]
+    else:
+        raise ValueError("Invalid padding scheme %s" % padding)
+
+    if channelwise:
+        if filters not in (input_shape[-1], 1):
+            raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise "
+                             "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1]))
+        kernel_shape = output_shape[1:3] + kernel_size + [filters]
+        factor_kernel_shape = output_shape[1:3] + [None, None, filters]
+    else:
+        kernel_shape = output_shape[1:3] + kernel_size + [input_shape[-1], filters]
+        factor_kernel_shape = output_shape[1:3] + [None, None, input_shape[-1], filters]
+    if kernels is None:
+        with tf.variable_scope('kronecker_local2d'):
+            kernels = [tf.get_variable('kernel', kernel_shape, dtype=tf.float32,
+                                       initializer=tf.truncated_normal_initializer(stddev=0.02))]
+        filter_h_lengths = [kernel_size[0]]
+        filter_w_lengths = [kernel_size[1]]
+    else:
+        for kernel in kernels:
+            if not ((len(kernel.shape) == len(factor_kernel_shape) and
+                    all(((k == f) or f is None) for k, f in zip(kernel.get_shape().as_list(), factor_kernel_shape))) or
+                    (len(kernel.shape) == (len(factor_kernel_shape) + 1) and
+                    all(((k == f) or f is None) for k, f in zip(kernel.get_shape().as_list(), [input_shape[0]] +factor_kernel_shape)))):
+                raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s"
+                                 % (tuple(factor_kernel_shape), tuple([input_shape[0]] + factor_kernel_shape),
+                                    tuple(kernel.get_shape().as_list())))
+        if channelwise:
+            filter_h_lengths, filter_w_lengths = zip(*[kernel.get_shape().as_list()[-3:-1] for kernel in kernels])
+        else:
+            filter_h_lengths, filter_w_lengths = zip(*[kernel.get_shape().as_list()[-4:-2] for kernel in kernels])
+        if [np.prod(filter_h_lengths), np.prod(filter_w_lengths)] != kernel_size:
+            raise ValueError("Expecting kernel size %s but instead got kernel size %s"
+                             % (tuple(kernel_size), tuple([np.prod(filter_h_lengths), np.prod(filter_w_lengths)])))
+
+    def get_inds(ind, lengths):
+        inds = []
+        for i in range(len(lengths)):
+            curr_ind = int(ind)
+            for j in range(len(lengths) - 1, i, -1):
+                curr_ind //= lengths[j]
+            curr_ind %= lengths[i]
+            inds.append(curr_ind)
+        return inds
+
+    outputs = []
+    for i in range(kernel_size[0]):
+        if padding == 'VALID':
+            ii = i
+        else:
+            ii = i - (kernel_size[0] // 2)
+        input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1]))
+        output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii)
+        assert 0 <= output_h_slice.start < output_shape[1]
+        assert 0 < output_h_slice.stop <= output_shape[1]
+
+        for j in range(kernel_size[1]):
+            if padding == 'VALID':
+                jj = j
+            else:
+                jj = j - (kernel_size[1] // 2)
+            input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2]))
+            output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj)
+            assert 0 <= output_w_slice.start < output_shape[2]
+            assert 0 < output_w_slice.stop <= output_shape[2]
+            kernel_slice = 1.0
+            for filter_h_ind, filter_w_ind, kernel in zip(get_inds(i, filter_h_lengths), get_inds(j, filter_w_lengths), kernels):
+                if flip_filters:
+                    filter_h_ind = -filter_h_ind-1
+                    filter_w_ind = -filter_w_ind-1
+                if channelwise:
+                    kernel_slice *= kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :]
+                else:
+                    kernel_slice *= kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :, :]
+            if channelwise:
+                inc = inputs[:, input_h_slice, input_w_slice, :] * kernel_slice
+            else:
+                inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] * kernel_slice, axis=-2)
+            # equivalent to this
+            # outputs[:, output_h_slice, output_w_slice, :] += inc
+            paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop],
+                        [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]]
+            outputs.append(tf.pad(inc, paddings))
+    outputs = tf.add_n(outputs)
+    if use_bias:
+        with tf.variable_scope('kronecker_local2d'):
+            bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, initializer=tf.zeros_initializer())
+            # the bias is per-location and per-channel (rank 3), which tf.nn.bias_add
+            # does not accept, so add it with broadcasting over the batch dimension
+            outputs = outputs + bias
+    return outputs
+
+
+def depthwise_conv2d(inputs, channel_multiplier, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True):
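+    """
+    2-D depthwise convolution.
+
+    Each input channel is convolved with its own set of `channel_multiplier` filters,
+    so the output has `in_channels * channel_multiplier` channels.
+    """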
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    input_shape = inputs.get_shape().as_list()
+    kernel_shape = kernel_size + [input_shape[-1], channel_multiplier]
+    if kernel is None:
+        with tf.variable_scope('depthwise_conv2d'):
+            kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32,
+                                     initializer=tf.truncated_normal_initializer(stddev=0.02))
+    else:
+        if kernel_shape != kernel.get_shape().as_list():
+            raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s"
+                             % (tuple(kernel_shape), tuple(kernel.get_shape().as_list())))
+    if padding == 'FULL':
+        inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT')
+        padding = 'VALID'
+    outputs = tf.nn.depthwise_conv2d(inputs, kernel, [1] + strides + [1], padding=padding)
+    if use_bias:
+        with tf.variable_scope('depthwise_conv2d'):
+            bias = tf.get_variable('bias', [input_shape[-1] * channel_multiplier], dtype=tf.float32, initializer=tf.zeros_initializer())
+            outputs = tf.nn.bias_add(outputs, bias)
+    return outputs
+
+
+def conv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, use_spectral_norm=False):
+    """
+    2-D convolution.
+
+    Args:
+        inputs: A 4-D tensor of shape
+            `[batch, in_height, in_width, in_channels]`.
+        kernel: A 4-D or 5-D tensor of shape
+            `[kernel_size[0], kernel_size[1], in_channels, filters]` or
+            `[batch, kernel_size[0], kernel_size[1], in_channels, filters]`.
+        bias: A 1-D or 2-D tensor of shape
+            `[filters]` or `[batch, filters]`.
+
+    Returns:
+        A 4-D tensor.
+    """
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    input_shape = inputs.get_shape().as_list()
+    kernel_shape = list(kernel_size) + [input_shape[-1], filters]
+    if kernel is None:
+        with tf.variable_scope('conv2d'):
+            kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32,
+                                     initializer=tf.truncated_normal_initializer(stddev=0.02))
+            if use_spectral_norm:
+                kernel = spectral_normed_weight(kernel)
+    else:
+        if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape):
+            raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s"
+                             % (tuple(kernel_shape), tuple([input_shape[0]] + kernel_shape), tuple(kernel.get_shape().as_list())))
+    if padding == 'FULL':
+        inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT')
+        padding = 'VALID'
+    if kernel.get_shape().ndims == 4:
+        outputs = tf.nn.conv2d(inputs, kernel, [1] + strides + [1], padding=padding)
+    else:
+        def conv2d_single_fn(args):
+            input_, kernel_ = args
+            input_ = tf.expand_dims(input_, axis=0)
+            output = tf.nn.conv2d(input_, kernel_, [1] + strides + [1], padding=padding)
+            output = tf.squeeze(output, axis=0)
+            return output
+        outputs = tf.map_fn(conv2d_single_fn, [inputs, kernel], dtype=tf.float32)
+    if use_bias:
+        bias_shape = [filters]
+        if bias is None:
+            with tf.variable_scope('conv2d'):
+                bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer())
+        else:
+            if bias.get_shape().as_list() not in (bias_shape, [input_shape[0]] + bias_shape):
+                raise ValueError("Expecting bias with shape %s but instead got bias with shape %s"
+                                 % (tuple(bias_shape), tuple(bias.get_shape().as_list())))
+        if bias.get_shape().ndims == 1:
+            outputs = tf.nn.bias_add(outputs, bias)
+        else:
+            outputs = tf.add(outputs, bias[:, None, None, :])
+    return outputs
+
+
+def deconv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True):
+    """
+    2-D transposed convolution.
+
+    Notes on padding:
+       The equivalent of transposed convolution with full padding is a convolution with valid padding, and
+       the equivalent of transposed convolution with valid padding is a convolution with full padding.
+
+    Reference:
+        http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
+    """
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    input_shape = inputs.get_shape().as_list()
+    kernel_shape = list(kernel_size) + [filters, input_shape[-1]]
+    if kernel is None:
+        with tf.variable_scope('deconv2d'):
+            kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32,
+                                     initializer=tf.truncated_normal_initializer(stddev=0.02))
+    else:
+        if kernel_shape != kernel.get_shape().as_list():
+            raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % (tuple(kernel_shape), tuple(kernel.get_shape().as_list())))
+    if padding == 'FULL':
+        output_h, output_w = [s * (i + 1) - k for (i, k, s) in zip(input_shape[1:3], kernel_size, strides)]
+    elif padding == 'SAME':
+        output_h, output_w = [s * i for (i, s) in zip(input_shape[1:3], strides)]
+    elif padding == 'VALID':
+        output_h, output_w = [s * (i - 1) + k for (i, k, s) in zip(input_shape[1:3], kernel_size, strides)]
+    else:
+        raise ValueError("Invalid padding scheme %s" % padding)
+    output_shape = [input_shape[0], output_h, output_w, filters]
+    outputs = tf.nn.conv2d_transpose(inputs, kernel, output_shape, [1] + strides + [1], padding=padding)
+    if use_bias:
+        with tf.variable_scope('deconv2d'):
+            bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer())
+            outputs = tf.nn.bias_add(outputs, bias)
+    return outputs
+
+
+def get_bilinear_kernel(strides):
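+    """
+    Returns a 2-D bilinear interpolation kernel for the given strides.
+
+    For example (worked out by hand), strides of 2 give the separable 1-D kernel
+    [0.25, 0.75, 0.75, 0.25] along each axis, i.e. a 4x4 outer-product kernel.
+    """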
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    strides = np.array(strides)
+    kernel_size = 2 * strides - strides % 2
+    center = strides - (kernel_size % 2 == 1) - 0.5 * (kernel_size % 2 != 1)
+    vertical_kernel = 1 - abs(np.arange(kernel_size[0]) - center[0]) / strides[0]
+    horizontal_kernel = 1 - abs(np.arange(kernel_size[1]) - center[1]) / strides[1]
+    kernel = vertical_kernel[:, None] * horizontal_kernel[None, :]
+    return kernel
+
+
+def upsample2d(inputs, strides, padding='SAME', upsample_mode='bilinear'):
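+    """
+    Spatially upsamples a 4-D tensor by the given strides, either via bilinear
+    interpolation (implemented as a transposed convolution with a fixed bilinear
+    kernel) or via nearest-neighbor replication.
+    """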
+    if upsample_mode == 'bilinear':
+        single_bilinear_kernel = get_bilinear_kernel(strides).astype(np.float32)
+        input_shape = inputs.get_shape().as_list()
+        bilinear_kernel = tf.matrix_diag(tf.tile(tf.constant(single_bilinear_kernel)[..., None], (1, 1, input_shape[-1])))
+        outputs = deconv2d(inputs, input_shape[-1], kernel_size=single_bilinear_kernel.shape,
+                           strides=strides, kernel=bilinear_kernel, padding=padding, use_bias=False)
+    elif upsample_mode == 'nearest':
+        strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+        input_shape = inputs.get_shape().as_list()
+        inputs_tiled = tf.tile(inputs[:, :, None, :, None, :], [1, 1, strides[0], 1, strides[1], 1])
+        outputs = tf.reshape(inputs_tiled, [input_shape[0], input_shape[1] * strides[0],
+                                            input_shape[2] * strides[1], input_shape[3]])
+    else:
+        raise ValueError("Unknown upsample mode %s" % upsample_mode)
+    return outputs
+
+
+def upsample2d_v2(inputs, strides, padding='SAME', upsample_mode='bilinear'):
+    """
+    Possibly less computationally efficient but more memory efficient than upsample2d.
+    """
+    if upsample_mode == 'bilinear':
+        single_kernel = get_bilinear_kernel(strides).astype(np.float32)
+    elif upsample_mode == 'nearest':
+        strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+        single_kernel = np.ones(strides, dtype=np.float32)
+    else:
+        raise ValueError("Unknown upsample mode %s" % upsample_mode)
+    input_shape = inputs.get_shape().as_list()
+    kernel = tf.constant(single_kernel)[:, :, None, None]
+    inputs = tf.transpose(inputs, [3, 0, 1, 2])[..., None]
+    outputs = tf.map_fn(lambda input: deconv2d(input, 1, kernel_size=single_kernel.shape,
+                                               strides=strides, kernel=kernel,
+                                               padding=padding, use_bias=False),
+                        inputs, parallel_iterations=input_shape[-1])
+    outputs = tf.transpose(tf.squeeze(outputs, axis=-1), [1, 2, 3, 0])
+    return outputs
+
+
+def upsample_conv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME',
+                    kernel=None, use_bias=True, bias=None, upsample_mode='bilinear'):
+    """
+    Upsamples the inputs by a factor using bilinear interpolation and then performs conv2d on the upsampled input. This
+    function is more computationally and memory efficient than a naive implementation. Unlike a naive implementation
+    that would upsample the input first, this implementation first convolves the bilinear kernel with the given kernel,
+    and then performs the convolution (actually a deconv2d) with the combined kernel. As opposed to just using deconv2d
+    directly, this function is less prone to checkerboard artifacts thanks to the implicit bilinear upsampling.
+
+    Example:
+        >>> import numpy as np
+        >>> import tensorflow as tf
+        >>> from video_prediction.ops import upsample_conv2d, upsample2d, conv2d, pad2d_paddings
+        >>> inputs_shape = [4, 8, 8, 64]
+        >>> kernel_size = [3, 3]  # for convolution
+        >>> filters = 32  # for convolution
+        >>> strides = [2, 2]  # for upsampling
+        >>> inputs = tf.get_variable("inputs", inputs_shape)
+        >>> kernel = tf.get_variable("kernel", (kernel_size[0], kernel_size[1], inputs_shape[-1], filters))
+        >>> bias = tf.get_variable("bias", (filters,))
+        >>> outputs = upsample_conv2d(inputs, filters, kernel_size=kernel_size, strides=strides, \
+                                      kernel=kernel, bias=bias)
+        >>> # upsample with bilinear interpolation
+        >>> inputs_up = upsample2d(inputs, strides=strides, padding='VALID')
+        >>> # convolve upsampled input with kernel
+        >>> outputs_up = conv2d(inputs_up, filters, kernel_size=kernel_size, strides=(1, 1), \
+                                kernel=kernel, bias=bias, padding='FULL')
+        >>> # crop appropriately
+        >>> same_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='SAME')
+        >>> full_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='FULL')
+        >>> crop_top = (strides[0] - strides[0] % 2) // 2 + full_paddings[1][1] - same_paddings[1][1]
+        >>> crop_left = (strides[1] - strides[1] % 2) // 2 + full_paddings[2][1] - same_paddings[2][1]
+        >>> outputs_up = outputs_up[:, crop_top:crop_top + strides[0] * inputs_shape[1], \
+                                    crop_left:crop_left + strides[1] * inputs_shape[2], :]
+        >>> sess = tf.Session()
+        >>> sess.run(tf.global_variables_initializer())
+        >>> assert np.allclose(*sess.run([outputs, outputs_up]), atol=1e-5)
+
+    """
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    if padding != 'SAME' or upsample_mode != 'bilinear':
+        raise NotImplementedError
+    input_shape = inputs.get_shape().as_list()
+    kernel_shape = list(kernel_size) + [input_shape[-1], filters]
+    if kernel is None:
+        with tf.variable_scope('upsample_conv2d'):
+            kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32,
+                                     initializer=tf.truncated_normal_initializer(stddev=0.02))
+    else:
+        if kernel_shape != kernel.get_shape().as_list():
+            raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" %
+                             (tuple(kernel_shape), tuple(kernel.get_shape().as_list())))
+
+    # convolve bilinear kernel with kernel
+    single_bilinear_kernel = get_bilinear_kernel(strides).astype(np.float32)
+    kernel_transposed = tf.transpose(kernel, (0, 1, 3, 2))
+    kernel_reshaped = tf.reshape(kernel_transposed, kernel_size + [1, input_shape[-1] * filters])
+    kernel_up_reshaped = conv2d(tf.constant(single_bilinear_kernel)[None, :, :, None], input_shape[-1] * filters,
+                                kernel_size=kernel_size, kernel=kernel_reshaped, padding='FULL', use_bias=False)
+    kernel_up = tf.reshape(kernel_up_reshaped,
+                           kernel_up_reshaped.get_shape().as_list()[1:3] + [filters, input_shape[-1]])
+
+    # deconvolve with the bilinearly convolved kernel
+    outputs = deconv2d(inputs, filters, kernel_size=kernel_up.get_shape().as_list()[:2], strides=strides,
+                       kernel=kernel_up, padding='SAME', use_bias=False)
+    if use_bias:
+        if bias is None:
+            with tf.variable_scope('upsample_conv2d'):
+                bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer())
+        else:
+            bias_shape = [filters]
+            if bias_shape != bias.get_shape().as_list():
+                raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" %
+                                 (tuple(bias_shape), tuple(bias.get_shape().as_list())))
+        outputs = tf.nn.bias_add(outputs, bias)
+    return outputs
+
+
+def upsample_conv2d_v2(inputs, filters, kernel_size, strides=(1, 1), padding='SAME',
+                       kernel=None, use_bias=True, bias=None, upsample_mode='bilinear'):
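+    """
+    Upsamples the input and then convolves it with the given kernel.
+
+    Possibly less computationally efficient but more memory efficient than
+    upsample_conv2d: it upsamples with upsample2d_v2, convolves with 'FULL'
+    padding, and crops the result back to the 'SAME' output size.
+    """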
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    if padding != 'SAME':
+        raise NotImplementedError
+    input_shape = inputs.get_shape().as_list()
+    kernel_shape = list(kernel_size) + [input_shape[-1], filters]
+    if kernel is None:
+        with tf.variable_scope('upsample_conv2d'):
+            kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32,
+                                     initializer=tf.truncated_normal_initializer(stddev=0.02))
+    else:
+        if kernel_shape != kernel.get_shape().as_list():
+            raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" %
+                             (tuple(kernel_shape), tuple(kernel.get_shape().as_list())))
+
+    inputs_up = upsample2d_v2(inputs, strides=strides, padding='VALID', upsample_mode=upsample_mode)
+    # convolve upsampled input with kernel
+    outputs = conv2d(inputs_up, filters, kernel_size=kernel_size, strides=(1, 1),
+                     kernel=kernel, bias=None, padding='FULL', use_bias=False)
+    # crop appropriately
+    same_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='SAME')
+    full_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='FULL')
+    crop_top = (strides[0] - strides[0] % 2) // 2 + full_paddings[1][1] - same_paddings[1][1]
+    crop_left = (strides[1] - strides[1] % 2) // 2 + full_paddings[2][1] - same_paddings[2][1]
+    outputs = outputs[:, crop_top:crop_top + strides[0] * input_shape[1],
+              crop_left:crop_left + strides[1] * input_shape[2], :]
+
+    if use_bias:
+        if bias is None:
+            with tf.variable_scope('upsample_conv2d'):
+                bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer())
+        else:
+            bias_shape = [filters]
+            if bias_shape != bias.get_shape().as_list():
+                raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" %
+                                 (tuple(bias_shape), tuple(bias.get_shape().as_list())))
+        outputs = tf.nn.bias_add(outputs, bias)
+    return outputs
+
+
+def conv3d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', use_bias=True, use_spectral_norm=False):
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 3
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 3
+    input_shape = inputs.get_shape().as_list()
+    kernel_shape = list(kernel_size) + [input_shape[-1], filters]
+    with tf.variable_scope('conv3d'):
+        kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02))
+        if use_spectral_norm:
+            kernel = spectral_normed_weight(kernel)
+    outputs = tf.nn.conv3d(inputs, kernel, [1] + strides + [1], padding=padding)
+    if use_bias:
+        with tf.variable_scope('conv3d'):
+            bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer())
+            outputs = tf.nn.bias_add(outputs, bias)
+    return outputs
+
+
+def pool2d(inputs, pool_size, strides=(1, 1), padding='SAME', pool_mode='avg'):
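+    """
+    2-D spatial pooling ('avg' or 'max'). 'FULL' padding is handled by explicitly
+    zero-padding the input and then pooling with 'VALID' padding.
+    """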
+    pool_size = list(pool_size) if isinstance(pool_size, (tuple, list)) else [pool_size] * 2
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    if padding == 'FULL':
+        inputs = pad2d(inputs, pool_size, strides=strides, padding=padding, mode='CONSTANT')
+        padding = 'VALID'
+    if pool_mode == 'max':
+        outputs = tf.nn.max_pool(inputs, [1] + pool_size + [1], [1] + strides + [1], padding=padding)
+    elif pool_mode == 'avg':
+        outputs = tf.nn.avg_pool(inputs, [1] + pool_size + [1], [1] + strides + [1], padding=padding)
+    else:
+        raise ValueError('Invalid pooling mode %s' % pool_mode)
+    return outputs
+
+
+def conv_pool2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, pool_mode='avg'):
+    """
+    Similar optimization to the one in upsample_conv2d: instead of convolving and then
+    average-pooling, the kernel itself is average-pooled and a single strided
+    convolution is performed with the pooled kernel.
+
+    Example:
+        >>> import numpy as np
+        >>> import tensorflow as tf
+        >>> from video_prediction.ops import conv_pool2d, conv2d, pool2d
+        >>> inputs_shape = [4, 16, 16, 32]
+        >>> kernel_size = [3, 3]  # for convolution
+        >>> filters = 64  # for convolution
+        >>> strides = [2, 2]  # for pooling
+        >>> inputs = tf.get_variable("inputs", inputs_shape)
+        >>> kernel = tf.get_variable("kernel", (kernel_size[0], kernel_size[1], inputs_shape[-1], filters))
+        >>> bias = tf.get_variable("bias", (filters,))
+        >>> outputs = conv_pool2d(inputs, filters, kernel_size=kernel_size, strides=strides, \
+                                  kernel=kernel, bias=bias, pool_mode='avg')
+        >>> inputs_conv = conv2d(inputs, filters, kernel_size=kernel_size, strides=(1, 1), \
+                                 kernel=kernel, bias=bias)
+        >>> outputs_pool = pool2d(inputs_conv, pool_size=strides, strides=strides, pool_mode='avg')
+        >>> sess = tf.Session()
+        >>> sess.run(tf.global_variables_initializer())
+        >>> assert np.allclose(*sess.run([outputs, outputs_pool]), atol=1e-5)
+
+    """
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    if padding != 'SAME' or pool_mode != 'avg':
+        raise NotImplementedError
+    input_shape = inputs.get_shape().as_list()
+    if input_shape[1] % strides[0] or input_shape[2] % strides[1]:
+        raise NotImplementedError("The height and width of the input should be "
+                                  "an integer multiple of the respective stride.")
+    kernel_shape = list(kernel_size) + [input_shape[-1], filters]
+    if kernel is None:
+        with tf.variable_scope('conv_pool2d'):
+            kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32,
+                                     initializer=tf.truncated_normal_initializer(stddev=0.02))
+    else:
+        if kernel_shape != kernel.get_shape().as_list():
+            raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" %
+                             (tuple(kernel_shape), tuple(kernel.get_shape().as_list())))
+
+    # pool kernel
+    kernel_reshaped = tf.reshape(kernel, [1] + kernel_size + [input_shape[-1] * filters])
+    kernel_pool_reshaped = pool2d(kernel_reshaped, pool_size=strides, padding='FULL', pool_mode='avg')
+    kernel_pool = tf.reshape(kernel_pool_reshaped,
+                             kernel_pool_reshaped.get_shape().as_list()[1:3] + [input_shape[-1], filters])
+
+    outputs = conv2d(inputs, filters, kernel_size=kernel_pool.get_shape().as_list()[:2], strides=strides,
+                     kernel=kernel_pool, padding='SAME', use_bias=False)
+    if use_bias:
+        if bias is None:
+            with tf.variable_scope('conv_pool2d'):
+                bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer())
+        else:
+            bias_shape = [filters]
+            if bias_shape != bias.get_shape().as_list():
+                raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" %
+                                 (tuple(bias_shape), tuple(bias.get_shape().as_list())))
+        outputs = tf.nn.bias_add(outputs, bias)
+    return outputs
+
+
+def conv_pool2d_v2(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, pool_mode='avg'):
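+    """
+    Same as conv_pool2d, but implemented directly as a conv2d followed by pool2d
+    instead of pooling the kernel first.
+    """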
+    kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+    strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2
+    if padding != 'SAME' or pool_mode != 'avg':
+        raise NotImplementedError
+    input_shape = inputs.get_shape().as_list()
+    if input_shape[1] % strides[0] or input_shape[2] % strides[1]:
+        raise NotImplementedError("The height and width of the input should be "
+                                  "an integer multiple of the respective stride.")
+    kernel_shape = list(kernel_size) + [input_shape[-1], filters]
+    if kernel is None:
+        with tf.variable_scope('conv_pool2d'):
+            kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32,
+                                     initializer=tf.truncated_normal_initializer(stddev=0.02))
+    else:
+        if kernel_shape != kernel.get_shape().as_list():
+            raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" %
+                             (tuple(kernel_shape), tuple(kernel.get_shape().as_list())))
+
+    inputs_conv = conv2d(inputs, filters, kernel_size=kernel_size, strides=(1, 1),
+                         kernel=kernel, bias=None, use_bias=False)
+    outputs = pool2d(inputs_conv, pool_size=strides, strides=strides, pool_mode='avg')
+
+    if use_bias:
+        if bias is None:
+            with tf.variable_scope('conv_pool2d'):
+                bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer())
+        else:
+            bias_shape = [filters]
+            if bias_shape != bias.get_shape().as_list():
+                raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" %
+                                 (tuple(bias_shape), tuple(bias.get_shape().as_list())))
+        outputs = tf.nn.bias_add(outputs, bias)
+    return outputs
+
+
+def lrelu(x, alpha):
+    """
+    Leaky ReLU activation function
+
+    Reference:
+        https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_ops.py
+    """
+    with tf.name_scope("lrelu"):
+        return tf.maximum(alpha * x, x)
+
+
+def batchnorm(input):
+    with tf.variable_scope("batchnorm"):
+        # this block looks like it has 3 inputs on the graph unless we do this
+        input = tf.identity(input)
+
+        channels = input.get_shape()[-1]
+        offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
+        scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.truncated_normal_initializer(1.0, 0.02))
+        mean, variance = tf.nn.moments(input, axes=list(range(len(input.get_shape()) - 1)), keepdims=False)
+        variance_epsilon = 1e-5
+        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
+        return normalized
+
+
+def instancenorm(input):
+    with tf.variable_scope("instancenorm"):
+        # this block looks like it has 3 inputs on the graph unless we do this
+        input = tf.identity(input)
+
+        channels = input.get_shape()[-1]
+        offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
+        scale = tf.get_variable("scale", [channels], dtype=tf.float32,
+                                initializer=tf.truncated_normal_initializer(1.0, 0.02))
+        mean, variance = tf.nn.moments(input, axes=list(range(1, len(input.get_shape()) - 1)), keepdims=True)
+        variance_epsilon = 1e-5
+        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale,
+                                               variance_epsilon=variance_epsilon)
+        return normalized
+
+
+def flatten(input, axis=1, end_axis=-1):
+    """
+    Caffe-style flatten.
+
+    Args:
+        input: An N-D tensor.
+        axis: The first axis to flatten: all preceding axes are retained in the output.
+            May be negative to index from the end (e.g., -1 for the last axis).
+        end_axis: The last axis to flatten: all following axes are retained in the output.
+            May be negative to index from the end (e.g., the default -1 for the last
+            axis).
+
+    Returns:
+        An M-D tensor, where M = N - (end_axis - axis).
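+
+    Example (an illustrative sketch):
+        >>> import tensorflow as tf
+        >>> from video_prediction.ops import flatten
+        >>> x = tf.zeros([2, 3, 4, 5])
+        >>> sess = tf.Session()
+        >>> sess.run(tf.shape(flatten(x, axis=1, end_axis=2))).tolist()
+        [2, 12, 5]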
+    """
+    input_shape = tf.shape(input)
+    input_rank = tf.shape(input_shape)[0]
+    if axis < 0:
+        axis = input_rank + axis
+    if end_axis < 0:
+        end_axis = input_rank + end_axis
+    output_shape = []
+    if axis != 0:
+        output_shape.append(input_shape[:axis])
+    output_shape.append([tf.reduce_prod(input_shape[axis:end_axis + 1])])
+    if end_axis + 1 != input_rank:
+        output_shape.append(input_shape[end_axis + 1:])
+    output_shape = tf.concat(output_shape, axis=0)
+    output = tf.reshape(input, output_shape)
+    return output
+
+
+def tile_concat(values, axis):
+    """
+    Like concat, except that it first tiles the broadcastable dimensions if necessary.
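+
+    Example (an illustrative sketch):
+        >>> import tensorflow as tf
+        >>> from video_prediction.ops import tile_concat
+        >>> a = tf.zeros([4, 8, 8, 16])
+        >>> b = tf.zeros([4, 1, 1, 3])
+        >>> tile_concat([a, b], axis=-1).get_shape().as_list()
+        [4, 8, 8, 19]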
+    """
+    shapes = [value.get_shape() for value in values]
+    # convert axis to positive form
+    ndims = shapes[0].ndims
+    for shape in shapes[1:]:
+        assert ndims == shape.ndims
+    if -ndims < axis < 0:
+        axis += ndims
+    # remove axis dimension
+    shapes = [shape.as_list() for shape in shapes]
+    dims = [shape.pop(axis) for shape in shapes]
+    shapes = [tf.TensorShape(shape) for shape in shapes]
+    # compute broadcasted shape
+    b_shape = shapes[0]
+    for shape in shapes[1:]:
+        b_shape = tf.broadcast_static_shape(b_shape, shape)
+    # add back axis dimension
+    b_shapes = [b_shape.as_list() for _ in dims]
+    for b_shape, dim in zip(b_shapes, dims):
+        b_shape.insert(axis, dim)
+    # tile values to match broadcasted shape, if necessary
+    b_values = []
+    for value, b_shape in zip(values, b_shapes):
+        multiples = []
+        for dim, b_dim in zip(value.get_shape().as_list(), b_shape):
+            if dim == b_dim:
+                multiples.append(1)
+            else:
+                assert dim == 1
+                multiples.append(b_dim)
+        if any(multiple != 1 for multiple in multiples):
+            b_value = tf.tile(value, multiples)
+        else:
+            b_value = value
+        b_values.append(b_value)
+    return tf.concat(b_values, axis=axis)
+
+
+def sigmoid_kl_with_logits(logits, targets):
+    # broadcasts the same target value across the whole batch
+    # this is implemented so awkwardly because tensorflow lacks an x log x op
+    assert isinstance(targets, float)
+    if targets in [0., 1.]:
+        entropy = 0.
+    else:
+        entropy = - targets * np.log(targets) - (1. - targets) * np.log(1. - targets)
+    return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits) * targets) - entropy
+
+
+def spectral_normed_weight(W, u=None, num_iters=1):
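+    """
+    Spectrally normalizes the weight tensor W using power iteration.
+
+    Returns W / sigma, where sigma approximates the largest singular value of W
+    reshaped into a 2-D matrix. The auxiliary vector `u` is stored as a
+    non-trainable variable and its update op is added to tf.GraphKeys.UPDATE_OPS.
+    """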
+    SPECTRAL_NORMALIZATION_VARIABLES = 'spectral_normalization_variables'
+
+    # Usually num_iters = 1 will be enough
+    W_shape = W.shape.as_list()
+    W_reshaped = tf.reshape(W, [-1, W_shape[-1]])
+    if u is None:
+        u = tf.get_variable("u", [1, W_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)
+
+    def l2normalize(v, eps=1e-12):
+        return v / (tf.norm(v) + eps)
+
+    def power_iteration(i, u_i, v_i):
+        v_ip1 = l2normalize(tf.matmul(u_i, tf.transpose(W_reshaped)))
+        u_ip1 = l2normalize(tf.matmul(v_ip1, W_reshaped))
+        return i + 1, u_ip1, v_ip1
+    _, u_final, v_final = tf.while_loop(
+        cond=lambda i, _1, _2: i < num_iters,
+        body=power_iteration,
+        loop_vars=(tf.constant(0, dtype=tf.int32),
+                   u, tf.zeros(dtype=tf.float32, shape=[1, W_reshaped.shape.as_list()[0]]))
+    )
+    sigma = tf.squeeze(tf.matmul(tf.matmul(v_final, W_reshaped), tf.transpose(u_final)))
+    W_bar_reshaped = W_reshaped / sigma
+    W_bar = tf.reshape(W_bar_reshaped, W_shape)
+
+    if u not in tf.get_collection(SPECTRAL_NORMALIZATION_VARIABLES):
+        tf.add_to_collection(SPECTRAL_NORMALIZATION_VARIABLES, u)
+        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u.assign(u_final))
+    return W_bar
+
+
+def get_activation_layer(layer_type):
+    if layer_type == 'relu':
+        layer = tf.nn.relu
+    elif layer_type == 'elu':
+        layer = tf.nn.elu
+    else:
+        raise ValueError('Invalid activation layer %s' % layer_type)
+    return layer
+
+
+def get_norm_layer(layer_type):
+    if layer_type == 'batch':
+        layer = tf.layers.batch_normalization
+    elif layer_type == 'layer':
+        layer = tf.contrib.layers.layer_norm
+    elif layer_type == 'instance':
+        from video_prediction.layers import fused_instance_norm
+        layer = fused_instance_norm
+    elif layer_type == 'none':
+        layer = tf.identity
+    else:
+        raise ValueError('Invalid normalization layer %s' % layer_type)
+    return layer
+
+
+def get_upsample_layer(layer_type):
+    if layer_type == 'deconv2d':
+        layer = deconv2d
+    elif layer_type == 'upsample_conv2d':
+        layer = upsample_conv2d
+    elif layer_type == 'upsample_conv2d_v2':
+        layer = upsample_conv2d_v2
+    else:
+        raise ValueError('Invalid upsampling layer %s' % layer_type)
+    return layer
+
+
+def get_downsample_layer(layer_type):
+    if layer_type == 'conv2d':
+        layer = conv2d
+    elif layer_type == 'conv_pool2d':
+        layer = conv_pool2d
+    elif layer_type == 'conv_pool2d_v2':
+        layer = conv_pool2d_v2
+    else:
+        raise ValueError('Invalid downsampling layer %s' % layer_type)
+    return layer
diff --git a/video_prediction/rnn_ops.py b/video_prediction/rnn_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..970fae3ec656dd3e677e906602f0d29d6650b2c7
--- /dev/null
+++ b/video_prediction/rnn_ops.py
@@ -0,0 +1,267 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Convolutional LSTM implementation."""
+
+import tensorflow as tf
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import variable_scope as vs
+
+
+class BasicConv2DLSTMCell(rnn_cell_impl.RNNCell):
+    """2D Convolutional LSTM cell with (optional) normalization and recurrent dropout.
+
+    The implementation is based on: tf.contrib.rnn.LayerNormBasicLSTMCell.
+
+    It does not allow cell clipping or a projection layer, and it does not
+    use peephole connections: it is the basic baseline.
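+
+    Example (a minimal usage sketch; the shapes are illustrative):
+        >>> import tensorflow as tf
+        >>> cell = BasicConv2DLSTMCell(input_shape=[8, 8, 16], filters=32, kernel_size=(3, 3))
+        >>> inputs = tf.zeros([4, 8, 8, 16])
+        >>> state = cell.zero_state(4, tf.float32)
+        >>> outputs, state = cell(inputs, state)
+        >>> outputs.get_shape().as_list()
+        [4, 8, 8, 32]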
+    """
+    def __init__(self, input_shape, filters, kernel_size,
+                 forget_bias=1.0, activation_fn=math_ops.tanh,
+                 normalizer_fn=None, separate_norms=True,
+                 norm_gain=1.0, norm_shift=0.0,
+                 dropout_keep_prob=1.0, dropout_prob_seed=None,
+                 skip_connection=False, reuse=None):
+        """Initializes the basic convolutional LSTM cell.
+
+        Args:
+            input_shape: int tuple, Shape of the input, excluding the batch size.
+            filters: int, The number of filters of the conv LSTM cell.
+            kernel_size: int tuple, The kernel size of the conv LSTM cell.
+            forget_bias: float, The bias added to forget gates (see above).
+            activation_fn: Activation function of the inner states.
+            normalizer_fn: If specified, this normalization will be applied before the
+                internal nonlinearities.
+            separate_norms: If set to `False`, the normalizer_fn is applied to the
+                concatenated tensor that follows the convolution, i.e. before splitting
+                the tensor. This case is slightly faster but it might be functionally
+                different, depending on the normalizer_fn (it's functionally the same
+                for instance norm but not for layer norm). Default: `True`.
+            norm_gain: float, The layer normalization gain initial value. If
+                `normalizer_fn` is `None`, this argument will be ignored.
+            norm_shift: float, The layer normalization shift initial value. If
+                `normalizer_fn` is `None`, this argument will be ignored.
+            dropout_keep_prob: unit Tensor or float between 0 and 1 representing the
+                recurrent dropout probability value. If float and 1.0, no dropout will
+                be applied.
+            dropout_prob_seed: (optional) integer, the randomness seed.
+            skip_connection: If set to `True`, concatenate the input to the
+                output of the conv LSTM. Default: `False`.
+            reuse: (optional) Python boolean describing whether to reuse variables
+                in an existing scope.  If not `True`, and the existing scope already has
+                the given variables, an error is raised.
+        """
+        super(BasicConv2DLSTMCell, self).__init__(_reuse=reuse)
+
+        self._input_shape = input_shape
+        self._filters = filters
+        self._kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+        self._forget_bias = forget_bias
+        self._activation_fn = activation_fn
+        self._normalizer_fn = normalizer_fn
+        self._separate_norms = separate_norms
+        self._g = norm_gain
+        self._b = norm_shift
+        self._keep_prob = dropout_keep_prob
+        self._seed = dropout_prob_seed
+        self._skip_connection = skip_connection
+        self._reuse = reuse
+
+        if self._skip_connection:
+            output_channels = self._filters + self._input_shape[-1]
+        else:
+            output_channels = self._filters
+        cell_size = tensor_shape.TensorShape(self._input_shape[:-1] + [self._filters])
+        self._output_size = tensor_shape.TensorShape(self._input_shape[:-1] + [output_channels])
+        self._state_size = rnn_cell_impl.LSTMStateTuple(cell_size, self._output_size)
+
+    @property
+    def output_size(self):
+        return self._output_size
+
+    @property
+    def state_size(self):
+        return self._state_size
+
+    def _norm(self, inputs, scope):
+        shape = inputs.get_shape()[-1:]
+        gamma_init = init_ops.constant_initializer(self._g)
+        beta_init = init_ops.constant_initializer(self._b)
+        with vs.variable_scope(scope):
+            # Initialize beta and gamma for use by normalizer.
+            vs.get_variable("gamma", shape=shape, initializer=gamma_init)
+            vs.get_variable("beta", shape=shape, initializer=beta_init)
+        normalized = self._normalizer_fn(inputs, reuse=True, scope=scope)
+        return normalized
+
+    def _conv2d(self, inputs):
+        output_filters = 4 * self._filters
+        input_shape = inputs.get_shape().as_list()
+        kernel_shape = list(self._kernel_size) + [input_shape[-1], output_filters]
+        kernel = vs.get_variable("kernel", kernel_shape, dtype=dtypes.float32,
+                                 initializer=init_ops.truncated_normal_initializer(stddev=0.02))
+        outputs = nn_ops.conv2d(inputs, kernel, [1] * 4, padding='SAME')
+        if not self._normalizer_fn:
+            bias = vs.get_variable('bias', [output_filters], dtype=dtypes.float32,
+                                   initializer=init_ops.zeros_initializer())
+            outputs = nn_ops.bias_add(outputs, bias)
+        return outputs
+
+    def _dense(self, inputs):
+        num_units = 4 * self._filters
+        input_shape = inputs.shape.as_list()
+        kernel_shape = [input_shape[-1], num_units]
+        kernel = vs.get_variable("weights", kernel_shape, dtype=dtypes.float32,
+                                 initializer=init_ops.truncated_normal_initializer(stddev=0.02))
+        outputs = tf.matmul(inputs, kernel)
+        return outputs
+
+    def call(self, inputs, state):
+        """2D Convolutional LSTM cell with (optional) normalization and recurrent dropout."""
+        c, h = state
+        tile_concat = isinstance(inputs, (list, tuple))
+        if tile_concat:
+            inputs, inputs_non_spatial = inputs
+        args = array_ops.concat([inputs, h], -1)
+        concat = self._conv2d(args)
+        if tile_concat:
+            concat = concat + self._dense(inputs_non_spatial)[:, None, None, :]
+
+        if self._normalizer_fn and not self._separate_norms:
+            concat = self._norm(concat, "input_transform_forget_output")
+        i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=-1)
+        if self._normalizer_fn and self._separate_norms:
+            i = self._norm(i, "input")
+            j = self._norm(j, "transform")
+            f = self._norm(f, "forget")
+            o = self._norm(o, "output")
+
+        g = self._activation_fn(j)
+        if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
+            g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
+
+        new_c = (c * math_ops.sigmoid(f + self._forget_bias)
+                 + math_ops.sigmoid(i) * g)
+        if self._normalizer_fn:
+            new_c = self._norm(new_c, "state")
+        new_h = self._activation_fn(new_c) * math_ops.sigmoid(o)
+
+        if self._skip_connection:
+            new_h = array_ops.concat([new_h, inputs], axis=-1)
+
+        new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
+        return new_h, new_state
+
+
+class Conv2DGRUCell(tf.nn.rnn_cell.RNNCell):
+    """2D Convolutional GRU cell with (optional) normalization.
+
+    Modified from these:
+    https://github.com/carlthome/tensorflow-convlstm-cell/blob/master/cell.py
+    https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn_cell_impl.py
+    """
+    def __init__(self, input_shape, filters, kernel_size,
+                 activation_fn=tf.tanh,
+                 normalizer_fn=None, separate_norms=True,
+                 bias_initializer=None, reuse=None):
+        super(Conv2DGRUCell, self).__init__(_reuse=reuse)
+        self._input_shape = input_shape
+        self._filters = filters
+        self._kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2
+        self._activation_fn = activation_fn
+        self._normalizer_fn = normalizer_fn
+        self._separate_norms = separate_norms
+        self._bias_initializer = bias_initializer
+        self._size = tensor_shape.TensorShape(self._input_shape[:-1] + [self._filters])
+
+    @property
+    def state_size(self):
+        return self._size
+
+    @property
+    def output_size(self):
+        return self._size
+
+    def _norm(self, inputs, scope, bias_initializer):
+        shape = inputs.get_shape()[-1:]
+        gamma_init = init_ops.ones_initializer()
+        beta_init = bias_initializer
+        with vs.variable_scope(scope):
+            # Initialize beta and gamma for use by normalizer.
+            vs.get_variable("gamma", shape=shape, initializer=gamma_init)
+            vs.get_variable("beta", shape=shape, initializer=beta_init)
+        normalized = self._normalizer_fn(inputs, reuse=True, scope=scope)
+        return normalized
+
+    def _conv2d(self, inputs, output_filters, bias_initializer):
+        input_shape = inputs.get_shape().as_list()
+        kernel_shape = list(self._kernel_size) + [input_shape[-1], output_filters]
+        kernel = vs.get_variable("kernel", kernel_shape, dtype=dtypes.float32,
+                                 initializer=init_ops.truncated_normal_initializer(stddev=0.02))
+        outputs = nn_ops.conv2d(inputs, kernel, [1] * 4, padding='SAME')
+        if not self._normalizer_fn:
+            bias = vs.get_variable('bias', [output_filters], dtype=dtypes.float32,
+                                   initializer=bias_initializer)
+            outputs = nn_ops.bias_add(outputs, bias)
+        return outputs
+
+    def _dense(self, inputs, num_units):
+        input_shape = inputs.shape.as_list()
+        kernel_shape = [input_shape[-1], num_units]
+        kernel = vs.get_variable("weights", kernel_shape, dtype=dtypes.float32,
+                                 initializer=init_ops.truncated_normal_initializer(stddev=0.02))
+        outputs = tf.matmul(inputs, kernel)
+        return outputs
+
+    def call(self, inputs, state):
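+        """2D convolutional GRU update: the reset/update gates `r` and `u` are computed
+        from [inputs, state], the candidate from [inputs, r * state], and the new state is
+        u * state + (1 - u) * activation(candidate)."""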
+        bias_ones = self._bias_initializer
+        if self._bias_initializer is None:
+            bias_ones = init_ops.ones_initializer()
+        tile_concat = isinstance(inputs, (list, tuple))
+        if tile_concat:
+            inputs, inputs_non_spatial = inputs
+        with vs.variable_scope('gates'):
+            inputs = array_ops.concat([inputs, state], axis=-1)
+            concat = self._conv2d(inputs, 2 * self._filters, bias_ones)
+            if tile_concat:
+                concat = concat + self._dense(inputs_non_spatial, concat.shape[-1].value)[:, None, None, :]
+            if self._normalizer_fn and not self._separate_norms:
+                concat = self._norm(concat, "reset_update", bias_ones)
+            r, u = array_ops.split(concat, 2, axis=-1)
+            if self._normalizer_fn and self._separate_norms:
+                r = self._norm(r, "reset", bias_ones)
+                u = self._norm(u, "update", bias_ones)
+            r, u = math_ops.sigmoid(r), math_ops.sigmoid(u)
+
+        bias_zeros = self._bias_initializer
+        if self._bias_initializer is None:
+            bias_zeros = init_ops.zeros_initializer()
+        with vs.variable_scope('candidate'):
+            inputs = array_ops.concat([inputs, r * state], axis=-1)
+            candidate = self._conv2d(inputs, self._filters, bias_zeros)
+            if tile_concat:
+                candidate = candidate + self._dense(inputs_non_spatial, candidate.shape[-1].value)[:, None, None, :]
+            if self._normalizer_fn:
+                candidate = self._norm(candidate, "state", bias_zeros)
+
+        c = self._activation_fn(candidate)
+        new_h = u * state + (1 - u) * c
+        return new_h, new_h
diff --git a/video_prediction/utils/__init__.py b/video_prediction/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/video_prediction/utils/ffmpeg_gif.py b/video_prediction/utils/ffmpeg_gif.py
new file mode 100644
index 0000000000000000000000000000000000000000..724933f7e1057e2e0b6174401021e6941bda5959
--- /dev/null
+++ b/video_prediction/utils/ffmpeg_gif.py
@@ -0,0 +1,95 @@
+import os
+
+import numpy as np
+
+
+def save_gif(gif_fname, images, fps):
+    """
+    To generate a gif from image files, first generate palette from images
+    and then generate the gif from the images and the palette.
+    ffmpeg -i input_%02d.jpg -vf palettegen -y palette.png
+    ffmpeg -i input_%02d.jpg -i palette.png -lavfi paletteuse -y output.gif
+
+    Alternatively, use a filter to map the input images to both the palette
+    and gif commands, while also passing the palette to the gif command.
+    ffmpeg -i input_%02d.jpg -filter_complex "[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse" -y output.gif
+
+    To directly pass in numpy images, use rawvideo format and `-i -` option.
+    """
+    from subprocess import Popen, PIPE
+    head, tail = os.path.split(gif_fname)
+    if head and not os.path.exists(head):
+        os.makedirs(head)
+    h, w, c = images[0].shape
+    cmd = ['ffmpeg', '-y',
+           '-f', 'rawvideo',
+           '-vcodec', 'rawvideo',
+           '-r', '%.02f' % fps,
+           '-s', '%dx%d' % (w, h),
+           '-pix_fmt', {1: 'gray', 3: 'rgb24', 4: 'rgba'}[c],
+           '-i', '-',
+           '-filter_complex', '[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse',
+           '-r', '%.02f' % fps,
+           '%s' % gif_fname]
+    proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
+    for image in images:
+        proc.stdin.write(image.tobytes())
+    out, err = proc.communicate()
+    if proc.returncode:
+        err = '\n'.join([' '.join(cmd), err.decode('utf8')])
+        raise IOError(err)
+    del proc
+
+
+def encode_gif(images, fps):
+    """Encodes numpy images into gif string.
+    Args:
+        images: A 5-D `uint8` `np.array` (or a list of 4-D images) of shape
+            `[batch_size, time, height, width, channels]` where `channels` is 1 or 3.
+        fps: frames per second of the animation
+    Returns:
+        The encoded gif string.
+    Raises:
+        IOError: If the ffmpeg command returns an error.
+    """
+    from subprocess import Popen, PIPE
+    h, w, c = images[0].shape
+    cmd = ['ffmpeg', '-y',
+           '-f', 'rawvideo',
+           '-vcodec', 'rawvideo',
+           '-r', '%.02f' % fps,
+           '-s', '%dx%d' % (w, h),
+           '-pix_fmt', {1: 'gray', 3: 'rgb24'}[c],
+           '-i', '-',
+           '-filter_complex', '[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse',
+           '-r', '%.02f' % fps,
+           '-f', 'gif',
+           '-']
+    proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
+    for image in images:
+        proc.stdin.write(image.tobytes())
+    out, err = proc.communicate()
+    if proc.returncode:
+        err = '\n'.join([' '.join(cmd), err.decode('utf8')])
+        raise IOError(err)
+    del proc
+    return out
+
+
+def main():
+    images_shape = (12, 64, 64, 3)  # num_frames, height, width, channels
+    images = np.random.randint(256, size=images_shape).astype(np.uint8)
+
+    save_gif('output_save.gif', images, 4)
+    with open('output_save.gif', 'rb') as f:
+        string_save = f.read()
+
+    string_encode = encode_gif(images, 4)
+    with open('output_encode.gif', 'wb') as f:
+        f.write(string_encode)
+
+    print(np.all(string_save == string_encode))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/video_prediction/utils/gif_summary.py b/video_prediction/utils/gif_summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..55f89987855c0288de827326107c9d72abc8ba6c
--- /dev/null
+++ b/video_prediction/utils/gif_summary.py
@@ -0,0 +1,114 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.ops import summary_op_util
+# from tensorflow.python.distribute.summary_op_util import skip_summary  # TODO: this import fails on JUWELS
+from video_prediction.utils import ffmpeg_gif
+
+
+def py_gif_summary(tag, images, max_outputs, fps):
+  """Outputs a `Summary` protocol buffer with gif animations.
+  Args:
+    tag: Name of the summary.
+    images: A 5-D `uint8` `np.array` of shape `[batch_size, time, height, width,
+      channels]` where `channels` is 1 or 3.
+    max_outputs: Max number of batch elements to generate gifs for.
+    fps: frames per second of the animation
+  Returns:
+    The serialized `Summary` protocol buffer.
+  Raises:
+    ValueError: If `images` is not a 5-D `uint8` array with 1 or 3 channels.
+  """
+  is_bytes = isinstance(tag, bytes)
+  if is_bytes:
+    tag = tag.decode("utf-8")
+  images = np.asarray(images)
+  if images.dtype != np.uint8:
+    raise ValueError("Tensor must have dtype uint8 for gif summary.")
+  if images.ndim != 5:
+    raise ValueError("Tensor must be 5-D for gif summary.")
+  batch_size, _, height, width, channels = images.shape
+  if channels not in (1, 3):
+    raise ValueError("Tensors must have 1 or 3 channels for gif summary.")
+
+  summ = tf.Summary()
+  num_outputs = min(batch_size, max_outputs)
+  for i in range(num_outputs):
+    image_summ = tf.Summary.Image()
+    image_summ.height = height
+    image_summ.width = width
+    image_summ.colorspace = channels  # 1: grayscale, 3: RGB
+    try:
+      image_summ.encoded_image_string = ffmpeg_gif.encode_gif(images[i], fps)
+    except (IOError, OSError) as e:
+      tf.logging.warning(
+          "Unable to encode images to a gif string because either ffmpeg is "
+          "not installed or ffmpeg returned an error: %s. Falling back to an "
+          "image summary of the first frame in the sequence.", e)
+      try:
+        from PIL import Image  # pylint: disable=g-import-not-at-top
+        import io  # pylint: disable=g-import-not-at-top
+        with io.BytesIO() as output:
+          Image.fromarray(images[i][0]).save(output, "PNG")
+          image_summ.encoded_image_string = output.getvalue()
+      except Exception:
+        tf.logging.warning(
+            "Gif summaries requires ffmpeg or PIL to be installed: %s", e)
+        image_summ.encoded_image_string = "".encode('utf-8') if is_bytes else ""
+    if num_outputs == 1:
+      summ_tag = "{}/gif".format(tag)
+    else:
+      summ_tag = "{}/gif/{}".format(tag, i)
+    summ.value.add(tag=summ_tag, image=image_summ)
+  summ_str = summ.SerializeToString()
+  return summ_str
+
+
+def gif_summary(name, tensor, max_outputs=3, fps=10, collections=None,
+                family=None):
+  """Outputs a `Summary` protocol buffer with gif animations.
+  Args:
+    name: Name of the summary.
+    tensor: A 5-D `uint8` `Tensor` of shape `[batch_size, time, height, width,
+      channels]` where `channels` is 1 or 3.
+    max_outputs: Max number of batch elements to generate gifs for.
+    fps: frames per second of the animation
+    collections: Optional list of tf.GraphKeys.  The collections to add the
+      summary to.  Defaults to [tf.GraphKeys.SUMMARIES]
+    family: Optional; if provided, used as the prefix of the summary tag name,
+      which controls the tab name used for display on Tensorboard.
+  Returns:
+    A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+    buffer.
+  """
+  tensor = tf.convert_to_tensor(tensor)
+  # if skip_summary():  # TODO: skip_summary raises import errors on JUWELS
+  #   return tf.constant("")
+  with summary_op_util.summary_scope(
+      name, family, values=[tensor]) as (tag, scope):
+    val = tf.py_func(
+        py_gif_summary,
+        [tag, tensor, max_outputs, fps],
+        tf.string,
+        stateful=False,
+        name=scope)
+    summary_op_util.collect(val, collections, [tf.GraphKeys.SUMMARIES])
+  return val
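+
+
+# Illustrative usage sketch (the tensor name `gen_images` is hypothetical): convert a
+# float image sequence in [0, 1] of shape [batch, time, height, width, 3] to uint8 and
+# attach it as a gif summary.
+#   gen_gifs = tf.image.convert_image_dtype(gen_images, tf.uint8, saturate=True)
+#   gif_summary('gen_images', gen_gifs, max_outputs=4, fps=4)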
diff --git a/video_prediction/utils/html.py b/video_prediction/utils/html.py
new file mode 100755
index 0000000000000000000000000000000000000000..4334f87b56446c238d5fd15a34ecebef889e003b
--- /dev/null
+++ b/video_prediction/utils/html.py
@@ -0,0 +1,105 @@
+import os
+
+import dominate
+from dominate.tags import *
+
+
+class HTML:
+    def __init__(self, web_dir, title, reflesh=0):
+        self.title = title
+        self.web_dir = web_dir
+        self.img_dir = os.path.join(self.web_dir, 'images')
+        if not os.path.exists(self.web_dir):
+            os.makedirs(self.web_dir)
+        if not os.path.exists(self.img_dir):
+            os.makedirs(self.img_dir)
+        # print(self.img_dir)
+
+        self.doc = dominate.document(title=title)
+        if reflesh > 0:
+            with self.doc.head:
+                meta(http_equiv="refresh", content=str(reflesh))
+        self.t = None
+
+    def get_image_dir(self):
+        return self.img_dir
+
+    def add_header1(self, text):
+        with self.doc:
+            h1(text)
+
+    def add_header2(self, text):
+        with self.doc:
+            h2(text)
+
+    def add_header3(self, text):
+        with self.doc:
+            h3(text)
+
+    def add_table(self, border=1):
+        self.t = table(border=border, style="table-layout: fixed;")
+        self.doc.add(self.t)
+
+    def add_row(self, txts, colspans=None):
+        if self.t is None:
+            self.add_table()
+        with self.t:
+            with tr():
+                if colspans:
+                    assert len(txts) == len(colspans)
+                    colspans = [dict(colspan=str(colspan)) for colspan in colspans]
+                else:
+                    colspans = [dict()] * len(txts)
+                for txt, colspan in zip(txts, colspans):
+                    style = "word-break: break-all;" if len(str(txt)) > 80 else "word-wrap: break-word;"
+                    with td(style=style, halign="center", valign="top", **colspan):
+                        with p():
+                            if txt is not None:
+                                p(txt)
+
+    def add_images(self, ims, txts, links, colspans=None, height=None, width=400):
+        image_style = ''
+        if height is not None:
+            image_style += "height:%dpx;" % height
+        if width is not None:
+            image_style += "width:%dpx;" % width
+        if self.t is None:
+            self.add_table()
+        with self.t:
+            with tr():
+                if colspans:
+                    assert len(txts) == len(colspans)
+                    colspans = [dict(colspan=str(colspan)) for colspan in colspans]
+                else:
+                    colspans = [dict()] * len(txts)
+                for im, txt, link, colspan in zip(ims, txts, links, colspans):
+                    with td(style="word-wrap: break-word;", halign="center", valign="top", **colspan):
+                        with p():
+                            if im is not None and link is not None:
+                                with a(href=os.path.join('images', link)):
+                                    img(style=image_style, src=os.path.join('images', im))
+                            if im is not None and link is not None and txt is not None:
+                                br()
+                            if txt is not None:
+                                p(txt)
+
+    def save(self):
+        html_file = '%s/index.html' % self.web_dir
+        with open(html_file, 'wt') as f:
+            f.write(self.doc.render())
+
+
+if __name__ == '__main__':
+    html = HTML('web/', 'test_html')
+    html.add_header1('hello world')
+
+    ims = []
+    txts = []
+    links = []
+    for n in range(4):
+        ims.append('image_%d.jpg' % n)
+        txts.append('text_%d' % n)
+        links.append('image_%d.jpg' % n)
+    html.add_images(ims, txts, links)
+    html.save()
diff --git a/video_prediction/utils/mcnet_utils.py b/video_prediction/utils/mcnet_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bad0131218f02e96d0e39132bc0d677547041da
--- /dev/null
+++ b/video_prediction/utils/mcnet_utils.py
@@ -0,0 +1,156 @@
+"""
+Some codes from https://github.com/Newmu/dcgan_code
+"""
+
+import cv2
+import random
+import imageio
+import scipy.misc
+import numpy as np
+
+
+def transform(image):
+    return image/127.5 - 1.
+
+
+def inverse_transform(images):
+    return (images+1.)/2.
+
+
+def save_images(images, size, image_path):
+    return imsave(inverse_transform(images)*255., size, image_path)
+
+
+def merge(images, size):
+    h, w = images.shape[1], images.shape[2]
+    img = np.zeros((h * size[0], w * size[1], 3))
+
+    for idx, image in enumerate(images):
+        i = idx % size[1]
+        j = idx // size[1]
+        img[j*h:j*h+h, i*w:i*w+w, :] = image
+
+    return img
+
+
+def imsave(images, size, path):
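+    # Note: scipy.misc.imsave is deprecated and removed in newer SciPy releases,
+    # so this helper assumes an older SciPy.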
+    return scipy.misc.imsave(path, merge(images, size))
+
+
+def get_minibatches_idx(n, minibatch_size, shuffle=False):
+    """ 
+    Used to shuffle the dataset at each iteration.
+    """
+    idx_list = np.arange(n, dtype="int32")
+
+    if shuffle:
+        random.shuffle(idx_list)
+
+    minibatches = []
+    minibatch_start = 0 
+    for i in range(n // minibatch_size):
+        minibatches.append(idx_list[minibatch_start:minibatch_start + minibatch_size])
+        minibatch_start += minibatch_size
+
+    if minibatch_start != n:
+        # Make a minibatch out of what is left
+        minibatches.append(idx_list[minibatch_start:])
+
+    return zip(range(len(minibatches)), minibatches)
+
+
+def draw_frame(img, is_input):
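+    """Draws a 2-pixel border around `img`: channel 1 (green) for input frames,
+    channel 2 for predicted frames."""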
+    if img.shape[2] == 1:
+        img = np.repeat(img, [3], axis=2)
+    if is_input:
+        img[:2,:,0]  = img[:2,:,2] = 0 
+        img[:,:2,0]  = img[:,:2,2] = 0 
+        img[-2:,:,0] = img[-2:,:,2] = 0 
+        img[:,-2:,0] = img[:,-2:,2] = 0 
+        img[:2,:,1]  = 255 
+        img[:,:2,1]  = 255 
+        img[-2:,:,1] = 255 
+        img[:,-2:,1] = 255 
+    else:
+        img[:2,:,0]  = img[:2,:,1] = 0 
+        img[:,:2,0]  = img[:,:2,1] = 0
+        img[-2:,:,0] = img[-2:,:,1] = 0 
+        img[:,-2:,0] = img[:,-2:,1] = 0 
+        img[:2,:,2]  = 255 
+        img[:,:2,2]  = 255 
+        img[-2:,:,2] = 255 
+        img[:,-2:,2] = 255 
+
+    return img 
+
+
+def load_kth_data(f_name, data_path, image_size, K, T): 
+    flip = np.random.binomial(1,.5,1)[0]
+    tokens = f_name.split()
+    vid_path = data_path + tokens[0] + "_uncomp.avi"
+    vid = imageio.get_reader(vid_path,"ffmpeg")
+    low = int(tokens[1])
+    high = np.min([int(tokens[2]),vid.get_length()])-K-T+1
+    if low == high:
+        stidx = 0 
+    else:
+        if low >= high: print(vid_path)
+        stidx = np.random.randint(low=low, high=high)
+    seq = np.zeros((image_size, image_size, K+T, 1), dtype="float32")
+    for t in range(K+T):
+        img = cv2.cvtColor(cv2.resize(vid.get_data(stidx+t),
+                           (image_size,image_size)),
+                           cv2.COLOR_RGB2GRAY)
+        seq[:,:,t] = transform(img[:,:,None])
+
+    if flip == 1:
+        seq = seq[:,::-1]
+
+    diff = np.zeros((image_size, image_size, K-1, 1), dtype="float32")
+    for t in range(1,K):
+        prev = inverse_transform(seq[:,:,t-1])
+        next = inverse_transform(seq[:,:,t])
+        diff[:,:,t-1] = next.astype("float32")-prev.astype("float32")
+
+    return seq, diff
+
+
+def load_s1m_data(f_name, data_path, trainlist, K, T):
+    flip = np.random.binomial(1,.5,1)[0]
+    vid_path = data_path + f_name  
+    img_size = [240,320]
+
+    while True:
+        try:
+            vid = imageio.get_reader(vid_path,"ffmpeg")
+            low = 1
+            high = vid.get_length()-K-T+1
+            if low == high:
+                stidx = 0
+            else:
+                stidx = np.random.randint(low=low, high=high)
+            seq = np.zeros((img_size[0], img_size[1], K+T, 3),
+                         dtype="float32")
+            for t in range(K+T):
+                img = cv2.resize(vid.get_data(stidx+t),
+                                 (img_size[1],img_size[0]))[:,:,::-1] 
+                seq[:,:,t] = transform(img)
+
+            if flip == 1:
+                seq = seq[:,::-1]
+
+            diff = np.zeros((img_size[0], img_size[1], K-1, 1),
+                          dtype="float32")
+            for t in range(1,K):
+                prev = inverse_transform(seq[:,:,t-1])*255
+                prev = cv2.cvtColor(prev.astype("uint8"),cv2.COLOR_BGR2GRAY)
+                next = inverse_transform(seq[:,:,t])*255
+                next = cv2.cvtColor(next.astype("uint8"),cv2.COLOR_BGR2GRAY)
+                diff[:,:,t-1,0] = (next.astype("float32")-prev.astype("float32"))/255.
+            break
+        except Exception:
+            # In case the current video is bad load a random one 
+            rep_idx = np.random.randint(low=0, high=len(trainlist))
+            f_name = trainlist[rep_idx]
+            vid_path = data_path + f_name
+
+    return seq, diff
diff --git a/video_prediction/utils/tf_utils.py b/video_prediction/utils/tf_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e49a54b11fb07833f0ef33278d2e894b905afc
--- /dev/null
+++ b/video_prediction/utils/tf_utils.py
@@ -0,0 +1,609 @@
+import itertools
+import os
+from collections import OrderedDict
+
+import numpy as np
+import six
+import tensorflow as tf
+import tensorflow.contrib.graph_editor as ge
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.training import device_setter
+from tensorflow.python.util import nest
+
+from video_prediction.utils import ffmpeg_gif
+from video_prediction.utils import gif_summary
+
+IMAGE_SUMMARIES = "image_summaries"
+EVAL_SUMMARIES = "eval_summaries"
+
+
+def local_device_setter(num_devices=1,
+                        ps_device_type='cpu',
+                        worker_device='/cpu:0',
+                        ps_ops=None,
+                        ps_strategy=None):
+    if ps_ops is None:
+        ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']
+
+    if ps_strategy is None:
+        ps_strategy = device_setter._RoundRobinStrategy(num_devices)
+    if not six.callable(ps_strategy):
+        raise TypeError("ps_strategy must be callable")
+
+    def _local_device_chooser(op):
+        current_device = pydev.DeviceSpec.from_string(op.device or "")
+
+        node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
+        if node_def.op in ps_ops:
+            ps_device_spec = pydev.DeviceSpec.from_string(
+                '/{}:{}'.format(ps_device_type, ps_strategy(op)))
+
+            ps_device_spec.merge_from(current_device)
+            return ps_device_spec.to_string()
+        else:
+            worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
+            worker_device_spec.merge_from(current_device)
+            return worker_device_spec.to_string()
+
+    return _local_device_chooser
+
+
+def replace_read_ops(loss_or_losses, var_list):
+    """
+    Replaces read ops of each variable in `vars` with new read ops obtained
+    from `read_value()`, thus forcing to read the most up-to-date values of
+    the variables (which might incur copies across devices).
+    The graph is seeded from the tensor(s) `loss_or_losses`.
+    """
+    # ops between var ops and the loss
+    ops = set(ge.get_walks_intersection_ops([var.op for var in var_list], loss_or_losses))
+    if not ops:  # loss_or_losses doesn't depend on any var in var_list, so there is nothing to replace
+        return
+
+    # filter out variables that are not involved in computing the loss
+    var_list = [var for var in var_list if var.op in ops]
+
+    for var in var_list:
+        output, = var.op.outputs
+        read_ops = set(output.consumers()) & ops
+        for read_op in read_ops:
+            with tf.name_scope('/'.join(read_op.name.split('/')[:-1])):
+                with tf.device(read_op.device):
+                    read_t, = read_op.outputs
+                    consumer_ops = set(read_t.consumers()) & ops
+                    # consumer_sgv might have multiple inputs, but we only care
+                    # about replacing the input that is read_t
+                    consumer_sgv = ge.sgv(consumer_ops)
+                    consumer_sgv = consumer_sgv.remap_inputs([list(consumer_sgv.inputs).index(read_t)])
+                    ge.connect(ge.sgv(var.read_value().op), consumer_sgv)
+
+
+def print_loss_info(losses, *tensors):
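+    """Prints each loss with its weight, followed by the names (taken from the
+    name-to-tensor dicts in `tensors`) of the tensors that the loss depends on."""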
+    def get_descendants(tensor, tensors):
+        descendants = []
+        for child in tensor.op.inputs:
+            if child in tensors:
+                descendants.append(child)
+            else:
+                descendants.extend(get_descendants(child, tensors))
+        return descendants
+
+    name_to_tensors = itertools.chain(*[tensor.items() for tensor in tensors])
+    tensor_to_names = OrderedDict([(v, k) for k, v in name_to_tensors])
+
+    print(tf.get_default_graph().get_name_scope())
+    for name, (loss, weight) in losses.items():
+        print('  %s (%r)' % (name, weight))
+        descendant_names = []
+        for descendant in set(get_descendants(loss, tensor_to_names.keys())):
+            descendant_names.append(tensor_to_names[descendant])
+        for descendant_name in sorted(descendant_names):
+            print('    %s' % descendant_name)
+
+
+def with_flat_batch(flat_batch_fn, ndims=4):
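+    """Wraps `flat_batch_fn` so it can be applied to tensors with extra leading dimensions:
+    all but the last `ndims - 1` dimensions are flattened into a single batch dimension,
+    the function is applied, and the original leading dimensions are restored."""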
+    def fn(x, *args, **kwargs):
+        shape = tf.shape(x)
+        flat_batch_shape = tf.concat([[-1], shape[-(ndims-1):]], axis=0)
+        flat_batch_shape.set_shape([ndims])
+        flat_batch_x = tf.reshape(x, flat_batch_shape)
+        flat_batch_r = flat_batch_fn(flat_batch_x, *args, **kwargs)
+        r = nest.map_structure(lambda x: tf.reshape(x, tf.concat([shape[:-(ndims-1)], tf.shape(x)[1:]], axis=0)),
+                               flat_batch_r)
+        return r
+    return fn
+
+
+def transpose_batch_time(x):
+    if isinstance(x, tf.Tensor) and x.shape.ndims >= 2:
+        return tf.transpose(x, [1, 0] + list(range(2, x.shape.ndims)))
+    else:
+        return x
+
+
+def dimension(inputs, axis=0):
+    shapes = [input_.shape for input_ in nest.flatten(inputs)]
+    s = tf.TensorShape([None])
+    for shape in shapes:
+        s = s.merge_with(shape[axis:axis + 1])
+    dim = s[0].value
+    return dim
+
+
+def unroll_rnn(cell, inputs, scope=None, use_dynamic_rnn=True):
+    """Chooses between dynamic_rnn and static_rnn if the leading time dimension is dynamic or not."""
+    dim = dimension(inputs, axis=0)
+    if use_dynamic_rnn or dim is None:
+        return tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32,
+                                 swap_memory=False, time_major=True, scope=scope)
+    else:
+        return static_rnn(cell, inputs, scope=scope)
+
+
+def static_rnn(cell, inputs, scope=None):
+    """Simple version of static_rnn."""
+    with tf.variable_scope(scope or "rnn") as varscope:
+        batch_size = dimension(inputs, axis=1)
+        state = cell.zero_state(batch_size, tf.float32)
+        flat_inputs = nest.flatten(inputs)
+        flat_inputs = list(zip(*[tf.unstack(flat_input, axis=0) for flat_input in flat_inputs]))
+        flat_outputs = []
+        for time, flat_input in enumerate(flat_inputs):
+            if time > 0:
+                varscope.reuse_variables()
+            input_ = nest.pack_sequence_as(inputs, flat_input)
+            output, state = cell(input_, state)
+            flat_output = nest.flatten(output)
+            flat_outputs.append(flat_output)
+        flat_outputs = [tf.stack(flat_output, axis=0) for flat_output in zip(*flat_outputs)]
+        outputs = nest.pack_sequence_as(output, flat_outputs)
+        return outputs, state
+
+
+def maybe_pad_or_slice(tensor, desired_length):
+    length = tensor.shape.as_list()[0]
+    if length < desired_length:
+        paddings = [[0, desired_length - length]] + [[0, 0]] * (tensor.shape.ndims - 1)
+        tensor = tf.pad(tensor, paddings)
+    elif length > desired_length:
+        tensor = tensor[:desired_length]
+    assert tensor.shape.as_list()[0] == desired_length
+    return tensor
+
+
+def tensor_to_clip(tensor):
+    if tensor.shape.ndims == 6:
+        # concatenate last dimension vertically
+        tensor = tf.concat(tf.unstack(tensor, axis=-1), axis=-3)
+    if tensor.shape.ndims == 5:
+        # concatenate batch dimension horizontally
+        tensor = tf.concat(tf.unstack(tensor, axis=0), axis=2)
+    if tensor.shape.ndims == 4:
+        # keep up to the first 3 channels
+        tensor = tf.image.convert_image_dtype(tensor, dtype=tf.uint8, saturate=True)
+    else:
+        raise NotImplementedError
+    return tensor
+
+
+def tensor_to_image_batch(tensor):
+    if tensor.shape.ndims == 6:
+        # concatenate last dimension vertically
+        tensor = tf.concat(tf.unstack(tensor, axis=-1), axis=-3)
+    if tensor.shape.ndims == 5:
+        # concatenate time dimension horizontally
+        tensor = tf.concat(tf.unstack(tensor, axis=1), axis=2)
+    if tensor.shape.ndims == 4:
+        # keep up to the first 3 channels
+        tensor = tf.image.convert_image_dtype(tensor, dtype=tf.uint8, saturate=True)
+    else:
+        raise NotImplementedError
+    return tensor
+
+
+def _as_name_scope_map(values):
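+    """Groups a flat {name: value} dict by the first component of each name, e.g.
+    {'gen/l1': a, 'discrim/gan': b} -> {'gen': {'gen/l1': a}, 'discrim': {'discrim/gan': b}}."""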
+    name_scope_to_values = {}
+    for name, value in values.items():
+        name_scope = name.split('/')[0]
+        name_scope_to_values.setdefault(name_scope, {})
+        name_scope_to_values[name_scope][name] = value
+    return name_scope_to_values
+
+
+def add_image_summaries(outputs, max_outputs=8, collections=None):
+    if collections is None:
+        collections = [tf.GraphKeys.SUMMARIES, IMAGE_SUMMARIES]
+    for name_scope, outputs in _as_name_scope_map(outputs).items():
+        with tf.name_scope(name_scope):
+            for name, output in outputs.items():
+                if max_outputs:
+                    output = output[:max_outputs]
+                output = tensor_to_image_batch(output)
+                if output.shape[-1] not in (1, 3):
+                    # these are feature maps, so just skip them
+                    continue
+                tf.summary.image(name, output, collections=collections)
+
+
+def add_gif_summaries(outputs, max_outputs=8, collections=None):
+    if collections is None:
+        collections = [tf.GraphKeys.SUMMARIES, IMAGE_SUMMARIES]
+    for name_scope, outputs in _as_name_scope_map(outputs).items():
+        with tf.name_scope(name_scope):
+            for name, output in outputs.items():
+                if max_outputs:
+                    output = output[:max_outputs]
+                output = tensor_to_clip(output)
+                if output.shape[-1] not in (1, 3):
+                    # these are feature maps, so just skip them
+                    continue
+                gif_summary.gif_summary(name, output[None], fps=4, collections=collections)
+
+
+def add_scalar_summaries(losses_or_metrics, collections=None):
+    for name_scope, losses_or_metrics in _as_name_scope_map(losses_or_metrics).items():
+        with tf.name_scope(name_scope):
+            for name, loss_or_metric in losses_or_metrics.items():
+                if isinstance(loss_or_metric, tuple):
+                    loss_or_metric, _ = loss_or_metric
+                tf.summary.scalar(name, loss_or_metric, collections=collections)
+
+
+def add_summaries(outputs, collections=None):
+    scalar_outputs = OrderedDict()
+    image_outputs = OrderedDict()
+    gif_outputs = OrderedDict()
+    for name, output in outputs.items():
+        if not isinstance(output, tf.Tensor):
+            continue
+        if output.shape.ndims == 0:
+            scalar_outputs[name] = output
+        elif output.shape.ndims == 4:
+            image_outputs[name] = output
+        elif output.shape.ndims > 4 and output.shape[4].value in (1, 3):
+            gif_outputs[name] = output
+    add_scalar_summaries(scalar_outputs, collections=collections)
+    add_image_summaries(image_outputs, collections=collections)
+    add_gif_summaries(gif_outputs, collections=collections)
+
+
+def plot_buf(y):
+    def _plot_buf(y):
+        from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+        from matplotlib.figure import Figure
+        import io
+        fig = Figure(figsize=(3, 3))
+        canvas = FigureCanvas(fig)
+        ax = fig.add_subplot(111)
+        ax.plot(y)
+        ax.grid(axis='y')
+        fig.tight_layout(pad=0)
+
+        buf = io.BytesIO()
+        fig.savefig(buf, format='png')
+        buf.seek(0)
+        return buf.getvalue()
+
+    s = tf.py_func(_plot_buf, [y], tf.string)
+    return s
+
+
+def add_plot_image_summaries(metrics, collections=None):
+    if collections is None:
+        collections = [IMAGE_SUMMARIES]
+    for name_scope, metrics in _as_name_scope_map(metrics).items():
+        with tf.name_scope(name_scope):
+            for name, metric in metrics.items():
+                try:
+                    buf = plot_buf(metric)
+                except Exception:
+                    continue
+                image = tf.image.decode_png(buf, channels=4)
+                image = tf.expand_dims(image, axis=0)
+                tf.summary.image(name, image, max_outputs=1, collections=collections)
+
+
+def plot_summary(name, x, y, display_name=None, description=None, collections=None):
+    """
+    Hack that uses pr_curve summaries for 2D plots.
+
+    Args:
+        x: 1-D tensor with values in increasing order.
+        y: 1-D tensor with static shape.
+
+    Note: tensorboard needs to be modified and compiled from source to disable
+    default axis range [-0.05, 1.05].
+    """
+    from tensorboard import summary as summary_lib
+    x = tf.convert_to_tensor(x)
+    y = tf.convert_to_tensor(y)
+    with tf.control_dependencies([
+        tf.assert_equal(tf.shape(x), tf.shape(y)),
+        tf.assert_equal(y.shape.ndims, 1),
+    ]):
+        y = tf.identity(y)
+    num_thresholds = y.shape[0].value
+    if num_thresholds is None:
+        raise ValueError('Size of y needs to be statically defined for num_thresholds argument')
+    summary = summary_lib.pr_curve_raw_data_op(
+        name,
+        true_positive_counts=tf.ones(num_thresholds),
+        false_positive_counts=tf.ones(num_thresholds),
+        true_negative_counts=tf.ones(num_thresholds),
+        false_negative_counts=tf.ones(num_thresholds),
+        precision=y[::-1],
+        recall=x[::-1],
+        num_thresholds=num_thresholds,
+        display_name=display_name,
+        description=description,
+        collections=collections)
+    return summary
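+
+
+# Illustrative usage sketch (the metric tensor `psnr_per_step` is hypothetical and needs a
+# statically known length): add_plot_summaries below builds one such plot per metric, e.g.
+#   add_plot_summaries({'eval/psnr_per_step': psnr_per_step})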
+
+
+def add_plot_summaries(metrics, x_offset=0, collections=None):
+    for name_scope, metrics in _as_name_scope_map(metrics).items():
+        with tf.name_scope(name_scope):
+            for name, metric in metrics.items():
+                plot_summary(name, x_offset + tf.range(tf.shape(metric)[0]), metric, collections=collections)
+
+
+def add_plot_and_scalar_summaries(metrics, x_offset=0, collections=None):
+    for name_scope, metrics in _as_name_scope_map(metrics).items():
+        with tf.name_scope(name_scope):
+            for name, metric in metrics.items():
+                tf.summary.scalar(name, tf.reduce_mean(metric), collections=collections)
+                plot_summary(name, x_offset + tf.range(tf.shape(metric)[0]), metric, collections=collections)
+
+
+def convert_tensor_to_gif_summary(summ):
+    if isinstance(summ, bytes):
+        summary_proto = tf.Summary()
+        summary_proto.ParseFromString(summ)
+        summ = summary_proto
+
+    summary = tf.Summary()
+    for value in summ.value:
+        tag = value.tag
+        try:
+            images_arr = tf.make_ndarray(value.tensor)
+        except TypeError:
+            summary.value.add(tag=tag, image=value.image)
+            continue
+
+        if len(images_arr.shape) == 5:
+            images_arr = np.concatenate(list(images_arr), axis=-2)
+        if len(images_arr.shape) != 4:
+            raise ValueError('Tensors must be 4-D or 5-D for gif summary.')
+        channels = images_arr.shape[-1]
+        if channels < 1 or channels > 4:
+            raise ValueError('Tensors must have 1, 2, 3, or 4 color channels for gif summary.')
+
+        encoded_image_string = ffmpeg_gif.encode_gif(images_arr, fps=4)
+
+        image = tf.Summary.Image()
+        image.height = images_arr.shape[-3]
+        image.width = images_arr.shape[-2]
+        image.colorspace = channels  # 1: grayscale, 2: grayscale + alpha, 3: RGB, 4: RGBA
+        image.encoded_image_string = encoded_image_string
+        summary.value.add(tag=tag, image=image)
+    return summary
+
+
+def compute_averaged_gradients(opt, tower_loss, **kwargs):
+    tower_gradvars = []
+    for loss in tower_loss:
+        with tf.device(loss.device):
+            gradvars = opt.compute_gradients(loss, **kwargs)
+            tower_gradvars.append(gradvars)
+
+    # Now compute global loss and gradients.
+    gradvars = []
+    with tf.name_scope('gradient_averaging'):
+        all_grads = {}
+        for grad, var in itertools.chain(*tower_gradvars):
+            if grad is not None:
+                all_grads.setdefault(var, []).append(grad)
+        for var, grads in all_grads.items():
+            # Average gradients on the same device as the variables
+            # to which they apply.
+            with tf.device(var.device):
+                if len(grads) == 1:
+                    avg_grad = grads[0]
+                else:
+                    avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
+            gradvars.append((avg_grad, var))
+    return gradvars
+
+
+# the next 3 function are from tensorpack:
+# https://github.com/tensorpack/tensorpack/blob/master/tensorpack/graph_builder/utils.py
+def split_grad_list(grad_list):
+    """
+    Args:
+        grad_list: K x N x 2
+
+    Returns:
+        K x N: gradients
+        K x N: variables
+    """
+    g = []
+    v = []
+    for tower in grad_list:
+        g.append([x[0] for x in tower])
+        v.append([x[1] for x in tower])
+    return g, v
+
+
+def merge_grad_list(all_grads, all_vars):
+    """
+    Args:
+        all_grads (K x N): gradients
+        all_vars(K x N): variables
+
+    Return:
+        K x N x 2: list of list of (grad, var) pairs
+    """
+    return [list(zip(gs, vs)) for gs, vs in zip(all_grads, all_vars)]
+
+
+def allreduce_grads(all_grads, average):
+    """
+    All-reduce average the gradients among K devices. Results are broadcasted to all devices.
+
+    Args:
+        all_grads (K x N): List of list of gradients. N is the number of variables.
+        average (bool): average gradients or not.
+
+    Returns:
+        K x N: same as input, but each grad is replaced by the average over K devices.
+    """
+    from tensorflow.contrib import nccl
+    nr_tower = len(all_grads)
+    if nr_tower == 1:
+        return all_grads
+    new_all_grads = []  # N x K
+    for grads in zip(*all_grads):
+        summed = nccl.all_sum(grads)
+
+        grads_for_devices = []  # K
+        for g in summed:
+            with tf.device(g.device):
+                # tensorflow/benchmarks didn't average gradients
+                if average:
+                    g = tf.multiply(g, 1.0 / nr_tower)
+            grads_for_devices.append(g)
+        new_all_grads.append(grads_for_devices)
+
+    # transpose to K x N
+    ret = list(zip(*new_all_grads))
+    return ret
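+
+
+# Illustrative sketch of how the three helpers above compose (variable names are hypothetical):
+#   all_grads, all_vars = split_grad_list(tower_gradvars)   # K x N x 2 -> two K x N lists
+#   all_grads = allreduce_grads(all_grads, average=True)    # NCCL all-reduce across devices
+#   tower_gradvars = merge_grad_list(all_grads, all_vars)   # back to K x N x 2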
+
+
+def _reduce_entries(*entries):
+    num_gpus = len(entries)
+    if entries[0] is None:
+        assert all(entry is None for entry in entries[1:])
+        reduced_entry = None
+    elif isinstance(entries[0], tf.Tensor):
+        if entries[0].shape.ndims == 0:
+            reduced_entry = tf.add_n(entries) / tf.to_float(num_gpus)
+        else:
+            reduced_entry = tf.concat(entries, axis=0)
+    elif np.isscalar(entries[0]) or isinstance(entries[0], np.ndarray):
+        if np.isscalar(entries[0]) or entries[0].ndim == 0:
+            reduced_entry = sum(entries) / float(num_gpus)
+        else:
+            reduced_entry = np.concatenate(entries, axis=0)
+    elif isinstance(entries[0], tuple) and len(entries[0]) == 2:
+        losses, weights = zip(*entries)
+        loss = tf.add_n(losses) / tf.to_float(num_gpus)
+        if isinstance(weights[0], tf.Tensor):
+            with tf.control_dependencies([tf.assert_equal(weight, weights[0]) for weight in weights[1:]]):
+                weight = tf.identity(weights[0])
+        else:
+            assert all(weight == weights[0] for weight in weights[1:])
+            weight = weights[0]
+        reduced_entry = (loss, weight)
+    else:
+        raise NotImplementedError
+    return reduced_entry
+
+
+def reduce_tensors(structures, shallow=False):
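+    """Reduces a list of identically structured nests (typically one per GPU) into a single
+    nest: scalar entries and (loss, weight) tuples are averaged, batched tensors/arrays are
+    concatenated along axis 0. With `shallow=True` only the top level is mapped."""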
+    if len(structures) == 1:
+        reduced_structure = structures[0]
+    else:
+        if shallow:
+            if isinstance(structures[0], dict):
+                shallow_tree = type(structures[0])([(k, None) for k in structures[0]])
+            else:
+                shallow_tree = type(structures[0])([None for _ in structures[0]])
+            reduced_structure = nest.map_structure_up_to(shallow_tree, _reduce_entries, *structures)
+        else:
+            reduced_structure = nest.map_structure(_reduce_entries, *structures)
+    return reduced_structure
+
+
+def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=False, restore_to_checkpoint_mapping=None):
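+    """Builds a `tf.train.Saver` over the variables that exist both in the graph (or in
+    `var_list`) and in `checkpoint`, and returns it along with the resolved checkpoint path."""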
+    if os.path.isdir(checkpoint):
+        # latest_checkpoint doesn't work when the path has special characters
+        checkpoint = tf.train.latest_checkpoint(checkpoint)
+    checkpoint_reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint)
+    checkpoint_var_names = checkpoint_reader.get_variable_to_shape_map().keys()
+    restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name, _: name.split(':')[0])
+    if not var_list:
+        var_list = tf.global_variables()
+    restore_vars = {restore_to_checkpoint_mapping(var.name, checkpoint_var_names): var for var in var_list}
+    if skip_global_step and 'global_step' in restore_vars:
+        del restore_vars['global_step']
+    # restore variables that are both in the global graph and in the checkpoint
+    restore_and_checkpoint_vars = {name: var for name, var in restore_vars.items() if name in checkpoint_var_names}
+    # print out information regarding variables that were not restored or used for restoring
+    restore_not_in_checkpoint_vars = {name: var for name, var in restore_vars.items() if
+                                      name not in checkpoint_var_names}
+    checkpoint_not_in_restore_var_names = [name for name in checkpoint_var_names if name not in restore_vars]
+    if skip_global_step and 'global_step' in checkpoint_not_in_restore_var_names:
+        checkpoint_not_in_restore_var_names.remove('global_step')
+    if restore_not_in_checkpoint_vars:
+        print("global variables that were not restored because they are "
+              "not in the checkpoint:")
+        for name, _ in sorted(restore_not_in_checkpoint_vars.items()):
+            print("    ", name)
+    if checkpoint_not_in_restore_var_names:
+        print("checkpoint variables that were not used for restoring "
+              "because they are not in the graph:")
+        for name in sorted(checkpoint_not_in_restore_var_names):
+            print("    ", name)
+
+
+    restore_saver = tf.train.Saver(max_to_keep=1, var_list=restore_and_checkpoint_vars, filename=checkpoint)
+
+    return restore_saver, checkpoint
+
+
+def pixel_distribution(pos, height, width):
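+    """Bilinearly distributes a unit of mass at each continuous (y, x) position in `pos`
+    (shape [batch_size, 2]) over the four neighbouring pixels of a height x width grid,
+    returning a [batch_size, height, width] tensor whose entries sum to one per example."""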
+    batch_size = pos.get_shape().as_list()[0]
+    y, x = tf.unstack(pos, 2, axis=1)
+
+    x0 = tf.cast(tf.floor(x), 'int32')
+    x1 = x0 + 1
+    y0 = tf.cast(tf.floor(y), 'int32')
+    y1 = y0 + 1
+
+    Ia = tf.reshape(tf.one_hot(y0 * width + x0, height * width), [batch_size, height, width])
+    Ib = tf.reshape(tf.one_hot(y1 * width + x0, height * width), [batch_size, height, width])
+    Ic = tf.reshape(tf.one_hot(y0 * width + x1, height * width), [batch_size, height, width])
+    Id = tf.reshape(tf.one_hot(y1 * width + x1, height * width), [batch_size, height, width])
+
+    x0_f = tf.cast(x0, 'float32')
+    x1_f = tf.cast(x1, 'float32')
+    y0_f = tf.cast(y0, 'float32')
+    y1_f = tf.cast(y1, 'float32')
+    wa = ((x1_f - x) * (y1_f - y))[:, None, None]
+    wb = ((x1_f - x) * (y - y0_f))[:, None, None]
+    wc = ((x - x0_f) * (y1_f - y))[:, None, None]
+    wd = ((x - x0_f) * (y - y0_f))[:, None, None]
+
+    return tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
+
+
+def flow_to_rgb(flows):
+    """The last axis should have dimension 2, for x and y values."""
+
+    def cartesian_to_polar(x, y):
+        magnitude = tf.sqrt(tf.square(x) + tf.square(y))
+        angle = tf.atan2(y, x)
+        return magnitude, angle
+
+    mag, ang = cartesian_to_polar(*tf.unstack(flows, axis=-1))
+    ang_normalized = (ang + np.pi) / (2 * np.pi)
+    mag_min = tf.reduce_min(mag)
+    mag_max = tf.reduce_max(mag)
+    mag_normalized = (mag - mag_min) / (mag_max - mag_min)
+    hsv = tf.stack([ang_normalized, tf.ones_like(ang), mag_normalized], axis=-1)
+    rgb = tf.image.hsv_to_rgb(hsv)
+    return rgb