diff --git a/video_prediction_tools/HPC_scripts/meta_postprocess_era5_template.sh b/video_prediction_tools/HPC_scripts/meta_postprocess_template.sh similarity index 100% rename from video_prediction_tools/HPC_scripts/meta_postprocess_era5_template.sh rename to video_prediction_tools/HPC_scripts/meta_postprocess_template.sh diff --git a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step2_template.sh b/video_prediction_tools/HPC_scripts/preprocess_data_era5_step2_template.sh deleted file mode 100644 index daa48d352ce6b1eca9c2f76692e68ca3e786273e..0000000000000000000000000000000000000000 --- a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step2_template.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/bin/bash -x -#SBATCH --account=<your_project> -#SBATCH --nodes=1 -#SBATCH --ntasks=13 -##SBATCH --ntasks-per-node=13 -#SBATCH --cpus-per-task=1 -#SBATCH --output=DataPreprocess_era5_step2-out.%j -#SBATCH --error=DataPreprocess_era5_step2-err.%j -#SBATCH --time=04:00:00 -#SBATCH --gres=gpu:0 -#SBATCH --partition=batch -#SBATCH --mail-type=ALL -#SBATCH --mail-user=me@somewhere.com - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# auxiliary variables -WORK_DIR="$(pwd)" -BASE_DIR=$(dirname "$WORK_DIR") -# Name of virtual environment -VIRT_ENV_NAME="my_venv" -# !!! ADAPAT DEPENDING ON USAGE OF CONTAINER !!! -# For container usage, comment in the follwoing lines -# Name of container image (must be available in working directory) -CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif" -WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh" - -# sanity checks -if [[ ! -f ${CONTAINER_IMG} ]]; then - echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." - exit 1 -fi - -if [[ ! -f ${WRAPPER} ]]; then - echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." - exit 1 -fi - -# clean-up modules to avoid conflicts between host and container settings -module purge - -# declare directory-variables which will be modified by config_runscript.py -source_dir=/my/path/to/pkl/files/ -destination_dir=/my/path/to/tfrecords/files - -sequence_length=20 -sequences_per_file=10 -# run Preprocessing (step 2 where Tf-records are generated) -export CUDA_VISIBLE_DEVICES=0 -## One node, single GPU -srun --mpi=pspmix --cpu-bind=none \ - singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \ - python3 ../main_scripts/main_preprocess_data_step2.py -source_dir ${source_dir} -dest_dir ${destination_dir} \ - -sequence_length ${sequence_length} -sequences_per_file ${sequences_per_file} - -# WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above) -# Activate virtual environment if needed (and possible) -#if [ -z ${VIRTUAL_ENV} ]; then -# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then -# echo "Activating virtual environment..." -# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate -# else -# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." 
-# exit 1 -# fi -#fi -# -# Loading modules -#module purge -#source ../env_setup/modules_train.sh -#export CUDA_VISIBLE_DEVICES=0 -# -# srun python3 ../main_scripts/main_preprocess_data_step2.py -source_dir ${source_dir} -dest_dir ${destination_dir} \ -# -sequence_length ${sequence_length} -sequences_per_file ${sequences_per_file} diff --git a/video_prediction_tools/HPC_scripts/preprocess_data_moving_mnist_template.sh b/video_prediction_tools/HPC_scripts/preprocess_data_moving_mnist_template.sh deleted file mode 100644 index f72950255efa181ca95b9b4f13c81efafe1e7733..0000000000000000000000000000000000000000 --- a/video_prediction_tools/HPC_scripts/preprocess_data_moving_mnist_template.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/bin/bash -x -#SBATCH --account=<your_project> -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -##SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=1 -#SBATCH --output=DataPreprocess_moving_mnist-out.%j -#SBATCH --error=DataPreprocess_moving_mnist-err.%j -#SBATCH --time=04:00:00 -#SBATCH --partition=batch -#SBATCH --mail-type=ALL -#SBATCH --mail-user=me@somewhere.com - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# Name of virtual environment -VIRT_ENV_NAME="my_venv" - -# !!! ADAPAT DEPENDING ON USAGE OF CONTAINER !!! -# For container usage, comment in the follwoing lines -# Name of container image (must be available in working directory) -CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif" -WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh" - -# sanity checks -if [[ ! -f ${CONTAINER_IMG} ]]; then - echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." - exit 1 -fi - -if [[ ! -f ${WRAPPER} ]]; then - echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." - exit 1 -fi - -# clean-up modules to avoid conflicts between host and container settings -module purge - -# declare directory-variables which will be modified generate_runscript.py -source_dir=/my/path/to/mnist/raw/data/ -destination_dir=/my/path/to/mnist/tfrecords/ - -# run Preprocessing (step 2 where Tf-records are generated) -# run postprocessing/generation of model results including evaluation metrics -export CUDA_VISIBLE_DEVICES=0 -## One node, single GPU -srun --mpi=pspmix --cpu-bind=none \ - singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \ - python3 ../video_prediction/datasets/moving_mnist.py ${source_dir} ${destination_dir} - -# WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above) -# Activate virtual environment if needed (and possible) -#if [ -z ${VIRTUAL_ENV} ]; then -# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then -# echo "Activating virtual environment..." -# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate -# else -# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." 
-# exit 1 -# fi -#fi -# -# Loading modules -#module purge -#source ../env_setup/modules_train.sh -#export CUDA_VISIBLE_DEVICES=0 -# -# srun python3 .../video_prediction/datasets/moving_mnist.py ${source_dir} ${destination_dir} \ No newline at end of file diff --git a/video_prediction_tools/HPC_scripts/train_model_moving_mnist_template.sh b/video_prediction_tools/HPC_scripts/train_model_moving_mnist_template.sh deleted file mode 100755 index 322d0fac362119032f558232e8161321434d2f2f..0000000000000000000000000000000000000000 --- a/video_prediction_tools/HPC_scripts/train_model_moving_mnist_template.sh +++ /dev/null @@ -1,82 +0,0 @@ -#!/bin/bash -x -#SBATCH --account=<your_project> -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -##SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=1 -#SBATCH --output=train_moving_mnist-out.%j -#SBATCH --error=train_moving_mnist-err.%j -#SBATCH --time=00:20:00 -#SBATCH --gres=gpu:1 -#SBATCH --partition=gpus -#SBATCH --mail-type=ALL -#SBATCH --mail-user=me@somewhere.com - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# auxiliary variables -WORK_DIR=`pwd` -BASE_DIR=$(dirname "$WORK_DIR") -# Name of virtual environment -VIRT_ENV_NAME="my_venv" -# !!! ADAPAT DEPENDING ON USAGE OF CONTAINER !!! -# For container usage, comment in the follwoing lines -# Name of container image (must be available in working directory) -CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif" -WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh" - -# sanity checks -if [[ ! -f ${CONTAINER_IMG} ]]; then - echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." - exit 1 -fi - -if [[ ! -f ${WRAPPER} ]]; then - echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." - exit 1 -fi - -# clean-up modules to avoid conflicts between host and container settings -module purge - -# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py) - -source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist -destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist - -# for choosing the model, convLSTM,savp, mcnet,vae -model=convLSTM -dataset=moving_mnist -model_hparams=../hparams/${dataset}/${model}/model_hparams.json -destination_dir=${destination_dir}/${model}/"$(date +"%Y%m%dT%H%M")_"$USER"" - -# run training in container -export CUDA_VISIBLE_DEVICES=0 -## One node, single GPU -srun --mpi=pspmix --cpu-bind=none \ - singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \ - python ../main_scripts/train.py --input_dir ${source_dir}/tfrecords/ --dataset ${dataset} --model ${model} \ - --model_hparams_dict ${model_hparams} --output_dir "${destination_dir}"/ - -# WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above) -# Activate virtual environment if needed (and possible) -#if [ -z ${VIRTUAL_ENV} ]; then -# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then -# echo "Activating virtual environment..." -# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate -# else -# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." 
-# exit 1 -# fi -#fi -# -# Loading modules -#module purge -#source ../env_setup/modules_train.sh -#export CUDA_VISIBLE_DEVICES=0 -# -# srun python3 ../main_scripts/train.py --input_dir ${source_dir}/tfrecords/ --dataset ${dataset} --model ${model} \ -# --model_hparams_dict ${model_hparams} --output_dir "${destination_dir}"/ \ No newline at end of file diff --git a/video_prediction_tools/HPC_scripts/train_model_era5_template.sh b/video_prediction_tools/HPC_scripts/train_model_template.sh similarity index 100% rename from video_prediction_tools/HPC_scripts/train_model_era5_template.sh rename to video_prediction_tools/HPC_scripts/train_model_template.sh diff --git a/video_prediction_tools/HPC_scripts/train_model_weatherbench_template.sh b/video_prediction_tools/HPC_scripts/train_model_weatherbench_template.sh deleted file mode 100644 index 44ccf018d2896553ad360d5c5dbd0c398b7b54d8..0000000000000000000000000000000000000000 --- a/video_prediction_tools/HPC_scripts/train_model_weatherbench_template.sh +++ /dev/null @@ -1,78 +0,0 @@ -#!/bin/bash -x -#SBATCH --account=<your_project> -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -#SBATCH --output=train_model_era5-out.%j -#SBATCH --error=train_model_era5-err.%j -#SBATCH --time=24:00:00 -#SBATCH --gres=gpu:1 -#SBATCH --partition=some_partition -#SBATCH --mail-type=ALL -#SBATCH --mail-user=me@somewhere.com - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# auxiliary variables -WORK_DIR="$(pwd)" -BASE_DIR=$(dirname "$WORK_DIR") -# Name of virtual environment -VIRT_ENV_NAME="my_venv" -# !!! ADAPAT DEPENDING ON USAGE OF CONTAINER !!! -# For container usage, comment in the follwoing lines -# Name of container image (must be available in working directory) -CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif" -WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh" - -# sanity checks -if [[ ! -f ${CONTAINER_IMG} ]]; then - echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." - exit 1 -fi - -if [[ ! -f ${WRAPPER} ]]; then - echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." - exit 1 -fi - -# clean-up modules to avoid conflicts between host and container settings -module purge - -# declare directory-variables which will be modified by generate_runscript.py -source_dir=/my/path/to/tfrecords/files -destination_dir=/my/model/output/path - -# valid identifiers for model-argument are: convLSTM, savp, mcnet and vae -model=convLSTM -datasplit_dict=${destination_dir}/data_split.json -model_hparams=${destination_dir}/model_hparams.json - -# run training in container -export CUDA_VISIBLE_DEVICES=0 -## One node, single GPU -srun --mpi=pspmix --cpu-bind=none \ - singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \ - python3 "${BASE_DIR}"/main_scripts/main_train_models.py --input_dir ${source_dir} --datasplit_dict ${datasplit_dict} \ - --dataset weatherbench --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/ - -# WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above) -# Activate virtual environment if needed (and possible) -#if [ -z ${VIRTUAL_ENV} ]; then -# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then -# echo "Activating virtual environment..." -# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate -# else -# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." 
-# exit 1 -# fi -#fi -# -# Loading modules -#module purge -#source ../env_setup/modules_train.sh -#export CUDA_VISIBLE_DEVICES=0 -# -# srun python3 "${BASE_DIR}"/main_scripts/main_train_models.py --input_dir ${source_dir} --datasplit_dict ${datasplit_dict} \ -# --dataset era5 --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/ \ No newline at end of file diff --git a/video_prediction_tools/HPC_scripts/visualize_postprocess_moving_mnist_template.sh b/video_prediction_tools/HPC_scripts/visualize_postprocess_moving_mnist_template.sh deleted file mode 100755 index 142193121fb12ea792d0350eac859652512438a1..0000000000000000000000000000000000000000 --- a/video_prediction_tools/HPC_scripts/visualize_postprocess_moving_mnist_template.sh +++ /dev/null @@ -1,80 +0,0 @@ -#!/bin/bash -x -#SBATCH --account=<your_project> -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -##SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=1 -#SBATCH --output=generate_era5-out.%j -#SBATCH --error=generate_era5-err.%j -#SBATCH --time=00:20:00 -#SBATCH --gres=gpu:1 -#SBATCH --partition=develgpus -#SBATCH --mail-type=ALL -#SBATCH --mail-user=me@somewhere.com - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# auxiliary variables -WORK_DIR="$(pwd)" -BASE_DIR=$(dirname "$WORK_DIR") -# Name of virtual environment -VIRT_ENV_NAME="my_venv" -# !!! ADAPAT DEPENDING ON USAGE OF CONTAINER !!! -# For container usage, comment in the follwoing lines -# Name of container image (must be available in working directory) -CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif" -WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh" - -# sanity checks -if [[ ! -f ${CONTAINER_IMG} ]]; then - echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." - exit 1 -fi - -if [[ ! -f ${WRAPPER} ]]; then - echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." - exit 1 -fi - -# clean-up modules to avoid conflicts between host and container settings -module purge - -# declare directory-variables which will be modified by config_runscript.py -source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist -checkpoint_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist -results_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist -# name of model -model=convLSTM - -# run postprocessing/generation of model results including evaluation metrics -export CUDA_VISIBLE_DEVICES=0 -## One node, single GPU -srun --mpi=pspmix --cpu-bind=none \ - singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \ - python3 ../scripts/generate_movingmnist.py --input_dir ${source_dir}/ --dataset_hparams sequence_length=20 \ - --checkpoint ${checkpoint_dir}/${model} --mode test --model ${model} --results_dir ${results_dir}/${model} \ - --batch_size 2 --dataset era5 > generate_era5-out."${SLURM_JOB_ID}" - -# WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above) -# Activate virtual environment if needed (and possible) -#if [ -z ${VIRTUAL_ENV} ]; then -# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then -# echo "Activating virtual environment..." -# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate -# else -# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." 
-# exit 1 -# fi -#fi -# -# Loading modules -#module purge -#source ../env_setup/modules_train.sh -#export CUDA_VISIBLE_DEVICES=0 -# -# srun python3 ../scripts/generate_movingmnist.py --input_dir ${source_dir}/ --dataset_hparams sequence_length=20 \ -# --checkpoint ${checkpoint_dir}/${model} --mode test --model ${model} --results_dir ${results_dir}/${model} \ -# --batch_size 2 --dataset era5 > generate_era5-out."${SLURM_JOB_ID}" \ No newline at end of file diff --git a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh b/video_prediction_tools/HPC_scripts/visualize_postprocess_template.sh similarity index 100% rename from video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh rename to video_prediction_tools/HPC_scripts/visualize_postprocess_template.sh diff --git a/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json b/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json deleted file mode 100644 index bc5f8983a5aa6b0b2ba3d560bc4c2391995794a4..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json +++ /dev/null @@ -1,10 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "max_epochs": 2, - "context_frames": 12 -} - - - diff --git a/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json b/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json deleted file mode 100644 index 0ccf44e6370f765857204317f172c866865b4b35..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "batch_size": 16, - "lr": 0.0002, - "beta1": 0.5, - "beta2": 0.999, - "l1_weight": 100.0, - "l2_weight": 0.0, - "kl_weight": 0.0, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.1, - "vae_gan_feature_cdist_weight": 0.0, - "gan_feature_cdist_weight": 10.0, - "state_weight": 0.0, - "nz": 32, - "max_epochs":2, - "context_frames":12 -} diff --git a/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json b/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json deleted file mode 100644 index 770f9ff516a630ff031b94bb2c8a2b41c1686eec..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.001, - "beta1": 0.9, - "beta2": 0.999, - "l1_weight": 1.0, - "l2_weight": 0.0, - "kl_weight": 1e-05, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.0, - "state_weight": 0.0, - "nz": 32, - "max_epochs":2, - "context_frames":12 -} diff --git a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json deleted file mode 100644 index f36e1c0b44279ad2e4f9e741c7bfade0a5aa0a05..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.0002, - "beta1": 0.5, - "beta2": 0.999, - "l1_weight": 100.0, - "l2_weight": 0.0, - "kl_weight": 0.01, - "video_sn_vae_gan_weight": 0.1, - "video_sn_gan_weight": 0.1, - "vae_gan_feature_cdist_weight": 10.0, - "gan_feature_cdist_weight": 0.0, - "state_weight": 0.0, - "nz": 16, - "max_epochs":4, - "context_frames": 12, - "opt_var": "0", - "decay_steps":[3000,9000], - "end_lr": 0.00000008 -} - - diff --git 
a/video_prediction_tools/hparams/era5/vae/model_hparams_template.json b/video_prediction_tools/hparams/era5/vae/model_hparams_template.json deleted file mode 100644 index 1306627e24bec0888600fb88fcaa937e5f01dbd7..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/era5/vae/model_hparams_template.json +++ /dev/null @@ -1,14 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "nz":16, - "max_epochs":2, - "context_frames":12, - "weight_recon":1, - "loss_fun": "rmse", - "shuffle_on_val": true -} - - - diff --git a/video_prediction_tools/hparams/gzprcp/savp/model_hparams_template.json b/video_prediction_tools/hparams/gzprcp/savp/model_hparams_template.json deleted file mode 100644 index 7c1ab72eea7ad1b341a66a76c4a88d1524450417..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/gzprcp/savp/model_hparams_template.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "batch_size": 16, - "lr": 0.0002, - "beta1": 0.5, - "beta2": 0.999, - "l1_weight": 100.0, - "l2_weight": 0.0, - "kl_weight": 0.01, - "video_sn_vae_gan_weight": 0.1, - "video_sn_gan_weight": 0.1, - "vae_gan_feature_cdist_weight": 10.0, - "gan_feature_cdist_weight": 0.0, - "state_weight": 0.0, - "nz": 16, - "max_epochs":2, - "context_frames":10, - "sequence_length":30 -} - - diff --git a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json deleted file mode 100644 index 2b341c11e9853c974c84a7573a70aead6b985d65..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json +++ /dev/null @@ -1,12 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "max_epochs":20, - "context_frames":10, - "loss_fun":"cross_entropy", - "opt_var": "all" -} - - - diff --git a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json deleted file mode 100644 index 6cda5552d437b9b283a40bdde77eb1d3b3497b36..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json +++ /dev/null @@ -1,11 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "max_epochs":20, - "context_frames":10, - "loss_fun":"cross_entropy" -} - - - diff --git a/video_prediction_tools/hparams/weatherbench/mcnet/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/mcnet/model_hparams_template.json deleted file mode 100644 index bc5f8983a5aa6b0b2ba3d560bc4c2391995794a4..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/weatherbench/mcnet/model_hparams_template.json +++ /dev/null @@ -1,10 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "max_epochs": 2, - "context_frames": 12 -} - - - diff --git a/video_prediction_tools/hparams/weatherbench/ours_gan/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/ours_gan/model_hparams_template.json deleted file mode 100644 index 0ccf44e6370f765857204317f172c866865b4b35..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/weatherbench/ours_gan/model_hparams_template.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "batch_size": 16, - "lr": 0.0002, - "beta1": 0.5, - "beta2": 0.999, - "l1_weight": 100.0, - "l2_weight": 0.0, - "kl_weight": 0.0, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.1, - "vae_gan_feature_cdist_weight": 0.0, - "gan_feature_cdist_weight": 10.0, - "state_weight": 0.0, - "nz": 32, - 
"max_epochs":2, - "context_frames":12 -} diff --git a/video_prediction_tools/hparams/weatherbench/ours_vae_l1/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/ours_vae_l1/model_hparams_template.json deleted file mode 100644 index 770f9ff516a630ff031b94bb2c8a2b41c1686eec..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/weatherbench/ours_vae_l1/model_hparams_template.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.001, - "beta1": 0.9, - "beta2": 0.999, - "l1_weight": 1.0, - "l2_weight": 0.0, - "kl_weight": 1e-05, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.0, - "state_weight": 0.0, - "nz": 32, - "max_epochs":2, - "context_frames":12 -} diff --git a/video_prediction_tools/hparams/weatherbench/savp/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/savp/model_hparams_template.json deleted file mode 100644 index f36e1c0b44279ad2e4f9e741c7bfade0a5aa0a05..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/weatherbench/savp/model_hparams_template.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.0002, - "beta1": 0.5, - "beta2": 0.999, - "l1_weight": 100.0, - "l2_weight": 0.0, - "kl_weight": 0.01, - "video_sn_vae_gan_weight": 0.1, - "video_sn_gan_weight": 0.1, - "vae_gan_feature_cdist_weight": 10.0, - "gan_feature_cdist_weight": 0.0, - "state_weight": 0.0, - "nz": 16, - "max_epochs":4, - "context_frames": 12, - "opt_var": "0", - "decay_steps":[3000,9000], - "end_lr": 0.00000008 -} - - diff --git a/video_prediction_tools/hparams/weatherbench/vae/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/vae/model_hparams_template.json deleted file mode 100644 index 1306627e24bec0888600fb88fcaa937e5f01dbd7..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/weatherbench/vae/model_hparams_template.json +++ /dev/null @@ -1,14 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "nz":16, - "max_epochs":2, - "context_frames":12, - "weight_recon":1, - "loss_fun": "rmse", - "shuffle_on_val": true -} - - - diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index cf9fc63f01483cc9d0d9fe6925e10fdc2607cd4a..a93642776beeb7b6a8ad9db0968b9be08594fb4b 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -113,6 +113,7 @@ class TrainModel(object): """ if self.model_hparams_dict: with open(self.model_hparams_dict, 'r') as f: + print("self.model_hparams_dict",self.model_hparams_dict) self.model_hparams_dict_load = json.loads(f.read()) else: raise FileNotFoundError("hparam directory doesn't exist! 
please check {}!".format(self.model_hparams_dict)) @@ -171,7 +172,7 @@ class TrainModel(object): :param mode: "train" used the model graph in train process; "test" for postprocessing step """ VideoPredictionModel = models.get_model_class(self.model) - self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict, mode=mode) + self.video_model = VideoPredictionModel(hparams_dict_config=self.model_hparams_dict, mode=mode) def setup_graph(self): """ @@ -209,7 +210,8 @@ class TrainModel(object): with open(os.path.join(self.output_dir, "dataset_hparams.json"), "w") as f: f.write(json.dumps(dataset.hparams, sort_keys=True, indent=4)) with open(os.path.join(self.output_dir, "model_hparams.json"), "w") as f: - f.write(json.dumps(video_model.hparams, sort_keys=True, indent=4)) + print("video_model.get_hparams",video_model.get_hparams) + f.write(json.dumps(video_model.get_hparams, sort_keys=True, indent=4)) #with open(os.path.join(self.output_dir, "data_dict.json"), "w") as f: # f.write(json.dumps(dataset.data_dict, sort_keys=True, indent=4)) @@ -407,45 +409,24 @@ class TrainModel(object): Fetch variables in the graph, this can be custermized based on models and also the needs of users """ # This is the basic fetch for all the models - fetch_list = ["train_op", "summary_op", "global_step"] + fetch_list = ["train_op", "summary_op", "global_step","total_loss"] # Append fetches depending on model to be trained - if self.video_model.__class__.__name__ == "McNetVideoPredictionModel": - fetch_list = fetch_list + ["L_p", "L_gdl", "L_GAN"] - self.saver_loss = fetch_list[-3] # ML: Is this a reasonable choice? + if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": + self.saver_loss = fetch_list[-1] self.saver_loss_name = "Loss" if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": - fetch_list = fetch_list + ["inputs", "total_loss"] - self.saver_loss = fetch_list[-1] - self.saver_loss_name = "Total loss" - if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": - fetch_list = fetch_list + ["g_losses", "d_losses", "d_loss", "g_loss", ("g_losses", "gen_l1_loss")] - # Add loss that is tracked - self.saver_loss = fetch_list[-1][1] - self.saver_loss_dict = fetch_list[-1][0] - self.saver_loss_name = "Generator L1 loss" - if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": - fetch_list = fetch_list + ["latent_loss", "recon_loss", "total_loss"] - self.saver_loss = fetch_list[-2] - self.saver_loss_name = "Reconstruction loss" - if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel": - fetch_list = fetch_list + ["inputs", "total_loss"] - self.saver_loss = fetch_list[-1] - self.saver_loss_name = "Total loss" - if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": - fetch_list = fetch_list + ["inputs", "total_loss"] + fetch_list = fetch_list + ["inputs"] self.saver_loss = fetch_list[-1] +<<<<<<< HEAD self.saver_loss_name = "Total loss" if self.video_model.__class__.__name__ == "WeatherBenchModel": fetch_list = fetch_list + ["total_loss"] self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss" - else: - raise ("self.saver_loss is not set up for your video model class {}".format(self.video_model.__class__.__name__ )) - - + else: + self.saver_loss = "total_loss" self.fetches = self.generate_fetches(fetch_list) - return self.fetches def create_fetches_for_val(self): @@ -502,27 +483,16 @@ class TrainModel(object): Print the training results /validation results from the 
training step. """ method = TrainModel.print_results.__name__ - train_epoch = step/self.steps_per_epoch print("%{0}: Progress global step {1:d} epoch {2:.1f}".format(method, step + 1, train_epoch)) - if self.video_model.__class__.__name__ == "McNetVideoPredictionModel": - print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"], results["L_p"], - results["L_gdl"],results["L_GAN"])) - elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel" or self.video_model.__class__.__name__ == "WeatherBenchModel": + if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": print ("Total_loss:{}".format(results["total_loss"])) - elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": - print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}, gen_l1_loss: {}" - .format(results["g_losses"], results["d_losses"], results["g_loss"], results["d_loss"], - results["gen_l1_loss"])) - elif self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": - print("Total_loss:{}; latent_losses:{}; reconst_loss:{}" - .format(results["total_loss"], results["latent_loss"], results["recon_loss"])) elif self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": print("Total_loss:{}" .format(results["total_loss"])) else: - print("%{0}: Printing results of model '{1}' is not implemented yet".format(method, self.video_model.__class__.__name__)) - + print("Total_loss:{}" + .format(results["total_loss"])) @staticmethod def plot_train(train_losses, val_losses, loss_name, output_dir): """ diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py index 9ab53f1918b2dd57d43e77ee8f4dd5a7556b2d5a..d2bcc4d0d93533e7306d30a6ff655cb2d676be12 100644 --- a/video_prediction_tools/model_modules/model_architectures.py +++ b/video_prediction_tools/model_modules/model_architectures.py @@ -1,20 +1,12 @@ def known_models(): """ An auxilary function - ours_vae_l1 and ours_gan are from savp papers :return: dictionary of known model architectures """ model_mappings = { - 'ground_truth': 'GroundTruthVideoPredictionModel', - 'savp': 'SAVPVideoPredictionModel', - 'vae': 'VanillaVAEVideoPredictionModel', 'convLSTM': 'VanillaConvLstmVideoPredictionModel', - 'mcnet': 'McNetVideoPredictionModel', 'convLSTM_gan': "ConvLstmGANVideoPredictionModel", - 'ours_vae_l1': 'SAVPVideoPredictionModel', - 'ours_gan': 'SAVPVideoPredictionModel', "weatherBench": "WeatherBenchModel" - 'precrnn_v2': 'PredRNNv2VideoPredictionModel' } return model_mappings diff --git a/video_prediction_tools/model_modules/video_prediction/flow_ops.py b/video_prediction_tools/model_modules/video_prediction/flow_ops.py deleted file mode 100644 index 88114a58f51524fc9dad48a74a201ab4805f3c15..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/flow_ops.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-FileCopyrightText: 2017 Simon Meister -# -# SPDX-License-Identifier: MIT - -import tensorflow as tf - - -def image_warp(im, flow): - """Performs a backward warp of an image using the predicted flow. - - Args: - im: Batch of images. [num_batch, height, width, channels] - flow: Batch of flow vectors. [num_batch, height, width, 2] - Returns: - warped: transformed image of the same shape as the input image. 
- - Implementation taken from here: https://github.com/simonmeister/UnFlow - """ - with tf.variable_scope('image_warp'): - - num_batch, height, width, channels = tf.unstack(tf.shape(im)) - max_x = tf.cast(width - 1, 'int32') - max_y = tf.cast(height - 1, 'int32') - zero = tf.zeros([], dtype='int32') - - # We have to flatten our tensors to vectorize the interpolation - im_flat = tf.reshape(im, [-1, channels]) - flow_flat = tf.reshape(flow, [-1, 2]) - - # Floor the flow, as the final indices are integers - # The fractional part is used to control the bilinear interpolation. - flow_floor = tf.to_int32(tf.floor(flow_flat)) - bilinear_weights = flow_flat - tf.floor(flow_flat) - - # Construct base indices which are displaced with the flow - pos_x = tf.tile(tf.range(width), [height * num_batch]) - grid_y = tf.tile(tf.expand_dims(tf.range(height), 1), [1, width]) - pos_y = tf.tile(tf.reshape(grid_y, [-1]), [num_batch]) - - x = flow_floor[:, 0] - y = flow_floor[:, 1] - xw = bilinear_weights[:, 0] - yw = bilinear_weights[:, 1] - - # Compute interpolation weights for 4 adjacent pixels - # expand to num_batch * height * width x 1 for broadcasting in add_n below - wa = tf.expand_dims((1 - xw) * (1 - yw), 1) # top left pixel - wb = tf.expand_dims((1 - xw) * yw, 1) # bottom left pixel - wc = tf.expand_dims(xw * (1 - yw), 1) # top right pixel - wd = tf.expand_dims(xw * yw, 1) # bottom right pixel - - x0 = pos_x + x - x1 = x0 + 1 - y0 = pos_y + y - y1 = y0 + 1 - - x0 = tf.clip_by_value(x0, zero, max_x) - x1 = tf.clip_by_value(x1, zero, max_x) - y0 = tf.clip_by_value(y0, zero, max_y) - y1 = tf.clip_by_value(y1, zero, max_y) - - dim1 = width * height - batch_offsets = tf.range(num_batch) * dim1 - base_grid = tf.tile(tf.expand_dims(batch_offsets, 1), [1, dim1]) - base = tf.reshape(base_grid, [-1]) - - base_y0 = base + y0 * width - base_y1 = base + y1 * width - idx_a = base_y0 + x0 - idx_b = base_y1 + x0 - idx_c = base_y0 + x1 - idx_d = base_y1 + x1 - - Ia = tf.gather(im_flat, idx_a) - Ib = tf.gather(im_flat, idx_b) - Ic = tf.gather(im_flat, idx_c) - Id = tf.gather(im_flat, idx_d) - - warped_flat = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id]) - warped = tf.reshape(warped_flat, [num_batch, height, width, channels]) - warped.set_shape(im.shape) - - return warped diff --git a/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py index fd743cb9052605e009c8ed0c31c5722d32238dd7..919b0d5ac64f4f9cc081d8aba681974870e91f3f 100644 --- a/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py +++ b/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py @@ -95,16 +95,8 @@ class BasicConvLSTMCell(ConvRNNCell): #Bing20200930#replace with non-linear convolutional layers #concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True) input_h_con = tf.concat(axis = 3, values = input_h) - concat = conv_layer(input_h_con, self.filter_size, 1, self.num_features*4, "decode_1", activate="sigmoid") - - print("concat1:",concat) - # i = input_gate, j = new_input, f = forget_gate, o = output_gate + concat = conv_layer(input_h_con, self.filter_size, 1, self.num_features*4, "decode_1", activate="sigmoid") i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat) - print("input gate i:",i) - print("new_input j:",j) - print("forget gate:",f) - print("output gate:",o) - new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * 
self._activation(j)) new_h = self._activation(new_c) * tf.nn.sigmoid(o) @@ -113,8 +105,7 @@ class BasicConvLSTMCell(ConvRNNCell): new_state = LSTMStateTuple(new_c, new_h) else: new_state = tf.concat(axis = 3, values = [new_c, new_h]) - print("new h", new_h) - print("new state",new_state) + return new_h, new_state @@ -150,14 +141,11 @@ def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=No matrix = tf.get_variable( "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype = dtype) if len(args) == 1: - print("args[0]:",args[0]) + res = tf.nn.conv2d(args[0], matrix, strides = [1, 1, 1, 1], padding = 'SAME') - print("res1:",res) + else: - print("matrix:",matrix) - print("tf.concat(axis = 3, values = args):",tf.concat(axis = 3, values = args)) res = tf.nn.conv2d(tf.concat(axis = 3, values = args), matrix, strides = [1, 1, 1, 1], padding = 'SAME') - print("res2:",res) if not bias: return res bias_term = tf.get_variable( diff --git a/video_prediction_tools/model_modules/video_prediction/layers/__init__.py b/video_prediction_tools/model_modules/video_prediction/layers/__init__.py index 8530ffd70f0899c8d8e0832d0dcd377b78bbe349..8b137891791fe96927ad78e64b0aad7bded08bdc 100644 --- a/video_prediction_tools/model_modules/video_prediction/layers/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/layers/__init__.py @@ -1 +1 @@ -from .normalization import fused_instance_norm + diff --git a/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py b/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py index 79d7653a5cc55d72abb7ea2fbdb22ca8be4c3b67..6f7c4f38b222afecb2b21d36bedc938b7813399a 100644 --- a/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py +++ b/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py @@ -58,11 +58,11 @@ def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.l def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer() , activate="relu"): - print("conv_layer activation function",activate) + with tf.variable_scope('{0}_conv'.format(idx)) as scope: input_channels = inputs.get_shape()[-1] - weights = _variable_with_weight_decay('weights',shape = [kernel_size, kernel_size, + weights = _variable_with_weight_decay('weights', shape = [kernel_size, kernel_size, input_channels, num_features], stddev = 0.01, wd = weight_decay) biases = _variable_on_gpu('biases', [num_features], initializer) @@ -97,7 +97,7 @@ def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, initial output_shape = tf.stack( [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features]) - print ("output_shape",output_shape) + conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides = [1, stride, stride, 1], padding = 'SAME') conv_biased = tf.nn.bias_add(conv, biases) if activate == "linear": @@ -140,6 +140,7 @@ def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01, else: ip = tf.add(tf.matmul(inputs_processed, weights), biases) return tf.nn.elu(ip, name = str(idx) + '_fc') + def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None): with tf.variable_scope('{0}_bn'.format(idx)) as scope: @@ -159,5 +160,20 @@ def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None): else: return tf.nn.batch_normalization(inputs,pop_mean,pop_var,beta,scale,epsilon) -def 
bn_layers_wrapper(inputs, is_training): - pass + +class batch_norm(object): + def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"): + with tf.variable_scope(name): + self.epsilon = epsilon + self.momentum = momentum + self.name = name + + def __call__(self, x, train=True): + return tf.contrib.layers.batch_norm(x, + decay=self.momentum, + updates_collections=None, + epsilon=self.epsilon, + scale=True, + is_training=train, + scope=self.name) + diff --git a/video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py b/video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py deleted file mode 100644 index 0b2e41ea2ebcef4247cceaa547fc97b73799ccc1..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py +++ /dev/null @@ -1,182 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -import math -import numpy as np -import tensorflow as tf - -from tensorflow.python.framework import ops -from model_modules.video_prediction.utils.mcnet_utils import * - - -def batch_norm(inputs, name, train=True, reuse=False): - return tf.contrib.layers.batch_norm(inputs=inputs,is_training=train, - reuse=reuse,scope=name,scale=True) - - -def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name="conv2d", reuse=False, padding='SAME'): - with tf.variable_scope(name, reuse=reuse): - w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], - initializer=tf.contrib.layers.xavier_initializer()) - conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding=padding) - - biases = tf.get_variable('biases', [output_dim], - initializer=tf.constant_initializer(0.0)) - conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) - - return conv - - -def deconv2d(input_, output_shape,k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, - name="deconv2d", reuse=False, with_w=False, padding='SAME'): - with tf.variable_scope(name, reuse=reuse): - # filter : [height, width, output_channels, in_channels] - w = tf.get_variable('w', [k_h, k_h, output_shape[-1], - input_.get_shape()[-1]], - initializer=tf.contrib.layers.xavier_initializer()) - - try: - deconv = tf.nn.conv2d_transpose(input_, w, - output_shape=output_shape, - strides=[1, d_h, d_w, 1], - padding=padding) - - # Support for verisons of TensorFlow before 0.7.0 - except AttributeError: - deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, - strides=[1, d_h, d_w, 1]) - biases = tf.get_variable('biases', [output_shape[-1]], - initializer=tf.constant_initializer(0.0)) - deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) - - if with_w: - return deconv, w, biases - else: - return deconv - - -def lrelu(x, leak=0.2, name="lrelu"): - with tf.variable_scope(name): - f1 = 0.5 * (1 + leak) - f2 = 0.5 * (1 - leak) - return f1 * x + f2 * abs(x) - - -def relu(x): - return tf.nn.relu(x) - - -def tanh(x): - return tf.nn.tanh(x) - - -def shape2d(a): - """ - a: a int or tuple/list of length 2 - """ - if type(a) == int: - return [a, a] - if isinstance(a, (list, tuple)): - assert len(a) == 2 - return list(a) - raise RuntimeError("Illegal shape: {}".format(a)) - - -def shape4d(a): - # for use with tensorflow - return [1] + shape2d(a) + [1] - - -def UnPooling2x2ZeroFilled(x): - out = tf.concat(axis=3, values=[x, tf.zeros_like(x)]) - out = tf.concat(axis=2, values=[out, tf.zeros_like(out)]) - - sh = x.get_shape().as_list() - if None not in 
sh[1:]: - out_size = [-1, sh[1] * 2, sh[2] * 2, sh[3]] - return tf.reshape(out, out_size) - else: - sh = tf.shape(x) - return tf.reshape(out, [-1, sh[1] * 2, sh[2] * 2, sh[3]]) - - -def MaxPooling(x, shape, stride=None, padding='VALID'): - """ - MaxPooling on images. - :param input: NHWC tensor. - :param shape: int or [h, w] - :param stride: int or [h, w]. default to be shape. - :param padding: 'valid' or 'same'. default to 'valid' - :returns: NHWC tensor. - """ - padding = padding.upper() - shape = shape4d(shape) - if stride is None: - stride = shape - else: - stride = shape4d(stride) - - return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding) - - -#@layer_register() -def FixedUnPooling(x, shape): - """ - Unpool the input with a fixed mat to perform kronecker product with. - :param input: NHWC tensor - :param shape: int or [h, w] - :returns: NHWC tensor - """ - shape = shape2d(shape) - - # a faster implementation for this special case - return UnPooling2x2ZeroFilled(x) - - -def gdl(gen_frames, gt_frames, alpha): - """ - Calculates the sum of GDL losses between the predicted and gt frames. - @param gen_frames: The predicted frames at each scale. - @param gt_frames: The ground truth frames at each scale - @param alpha: The power to which each gradient term is raised. - @return: The GDL loss. - """ - # create filters [-1, 1] and [[1],[-1]] - # for diffing to the left and down respectively. - pos = tf.constant(np.identity(3), dtype=tf.float32) - neg = -1 * pos - # [-1, 1] - filter_x = tf.expand_dims(tf.stack([neg, pos]), 0) - # [[1],[-1]] - filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) - strides = [1, 1, 1, 1] # stride of (1, 1) - padding = 'SAME' - - gen_dx = tf.abs(tf.nn.conv2d(gen_frames, filter_x, strides, padding=padding)) - gen_dy = tf.abs(tf.nn.conv2d(gen_frames, filter_y, strides, padding=padding)) - gt_dx = tf.abs(tf.nn.conv2d(gt_frames, filter_x, strides, padding=padding)) - gt_dy = tf.abs(tf.nn.conv2d(gt_frames, filter_y, strides, padding=padding)) - - grad_diff_x = tf.abs(gt_dx - gen_dx) - grad_diff_y = tf.abs(gt_dy - gen_dy) - - gdl_loss = tf.reduce_mean((grad_diff_x ** alpha + grad_diff_y ** alpha)) - - # condense into one tensor and avg - return gdl_loss - - -def linear(input_, output_size, name, stddev=0.02, bias_start=0.0, - reuse=False, with_w=False): - shape = input_.get_shape().as_list() - - with tf.variable_scope(name, reuse=reuse): - matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, - tf.random_normal_initializer(stddev=stddev)) - bias = tf.get_variable("bias", [output_size], - initializer=tf.constant_initializer(bias_start)) - if with_w: - return tf.matmul(input_, matrix) + bias, matrix, bias - else: - return tf.matmul(input_, matrix) + bias diff --git a/video_prediction_tools/model_modules/video_prediction/layers/normalization.py b/video_prediction_tools/model_modules/video_prediction/layers/normalization.py deleted file mode 100644 index e6f79effdb83ec43224dcf808e8d7e7d55fcdc60..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/layers/normalization.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Contains the normalization layer classes and their functional aliases.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -from tensorflow.contrib.framework.python.ops import variables -from tensorflow.contrib.layers.python.layers import utils -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import variable_scope - - -DATA_FORMAT_NCHW = 'NCHW' -DATA_FORMAT_NHWC = 'NHWC' - - -def fused_instance_norm(inputs, - center=True, - scale=True, - epsilon=1e-6, - activation_fn=None, - param_initializers=None, - reuse=None, - variables_collections=None, - outputs_collections=None, - trainable=True, - data_format=DATA_FORMAT_NHWC, - scope=None): - """Functional interface for the instance normalization layer. - - Reference: https://arxiv.org/abs/1607.08022. - - "Instance Normalization: The Missing Ingredient for Fast Stylization" - Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky - - Args: - inputs: A tensor with 2 or more dimensions, where the first dimension has - `batch_size`. The normalization is over all but the last dimension if - `data_format` is `NHWC` and the second dimension if `data_format` is - `NCHW`. - center: If True, add offset of `beta` to normalized tensor. If False, `beta` - is ignored. - scale: If True, multiply by `gamma`. If False, `gamma` is - not used. When the next layer is linear (also e.g. `nn.relu`), this can be - disabled since the scaling can be done by the next layer. - epsilon: Small float added to variance to avoid dividing by zero. - activation_fn: Activation function, default set to None to skip it and - maintain a linear activation. - param_initializers: Optional initializers for beta, gamma, moving mean and - moving variance. - reuse: Whether or not the layer and its variables should be reused. To be - able to reuse the layer scope must be given. - variables_collections: Optional collections for the variables. - outputs_collections: Collections to add the outputs. - trainable: If `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - data_format: A string. `NHWC` (default) and `NCHW` are supported. - scope: Optional scope for `variable_scope`. - - Returns: - A `Tensor` representing the output of the operation. - - Raises: - ValueError: If `data_format` is neither `NHWC` nor `NCHW`. - ValueError: If the rank of `inputs` is undefined. - ValueError: If rank or channels dimension of `inputs` is undefined. - """ - inputs = ops.convert_to_tensor(inputs) - inputs_shape = inputs.shape - inputs_rank = inputs.shape.ndims - - if inputs_rank is None: - raise ValueError('Inputs %s has undefined rank.' 
% inputs.name) - if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): - raise ValueError('data_format has to be either NCHW or NHWC.') - - with variable_scope.variable_scope( - scope, 'InstanceNorm', [inputs], reuse=reuse) as sc: - if data_format == DATA_FORMAT_NCHW: - reduction_axis = 1 - # For NCHW format, rather than relying on implicit broadcasting, we - # explicitly reshape the params to params_shape_broadcast when computing - # the moments and the batch normalization. - params_shape_broadcast = list( - [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)]) - else: - reduction_axis = inputs_rank - 1 - params_shape_broadcast = None - moments_axes = list(range(inputs_rank)) - del moments_axes[reduction_axis] - del moments_axes[0] - params_shape = inputs_shape[reduction_axis:reduction_axis + 1] - if not params_shape.is_fully_defined(): - raise ValueError('Inputs %s has undefined channels dimension %s.' % ( - inputs.name, params_shape)) - - # Allocate parameters for the beta and gamma of the normalization. - beta, gamma = None, None - dtype = inputs.dtype.base_dtype - if param_initializers is None: - param_initializers = {} - if center: - beta_collections = utils.get_variable_collections( - variables_collections, 'beta') - beta_initializer = param_initializers.get( - 'beta', init_ops.zeros_initializer()) - beta = variables.model_variable('beta', - shape=params_shape, - dtype=dtype, - initializer=beta_initializer, - collections=beta_collections, - trainable=trainable) - if params_shape_broadcast: - beta = array_ops.reshape(beta, params_shape_broadcast) - if scale: - gamma_collections = utils.get_variable_collections( - variables_collections, 'gamma') - gamma_initializer = param_initializers.get( - 'gamma', init_ops.ones_initializer()) - gamma = variables.model_variable('gamma', - shape=params_shape, - dtype=dtype, - initializer=gamma_initializer, - collections=gamma_collections, - trainable=trainable) - if params_shape_broadcast: - gamma = array_ops.reshape(gamma, params_shape_broadcast) - - if data_format == DATA_FORMAT_NHWC: - inputs = array_ops.transpose(inputs, list(range(1, reduction_axis)) + [0, reduction_axis]) - if data_format == DATA_FORMAT_NCHW: - inputs = array_ops.transpose(inputs, list(range(2, inputs_rank)) + [0, reduction_axis]) - hw, n, c = inputs.shape.as_list()[:-2], inputs.shape[-2].value, inputs.shape[-1].value - inputs = array_ops.reshape(inputs, [1] + hw + [n * c]) - if inputs.shape.ndims != 4: - # combine all the spatial dimensions into only two, e.g. 
[D, H, W] -> [DH, W] - if inputs.shape.ndims > 4: - inputs_ndims4_shape = [1, hw[0], -1, n * c] - else: - inputs_ndims4_shape = [1, 1, -1, n * c] - inputs = array_ops.reshape(inputs, inputs_ndims4_shape) - beta = array_ops.reshape(array_ops.tile(beta[None, :], [n, 1]), [-1]) - gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [n, 1]), [-1]) - - outputs, _, _ = nn.fused_batch_norm( - inputs, gamma, beta, epsilon=epsilon, - data_format=DATA_FORMAT_NHWC, name='instancenorm') - - outputs = array_ops.reshape(outputs, hw + [n, c]) - if data_format == DATA_FORMAT_NHWC: - outputs = array_ops.transpose(outputs, [inputs_rank - 2] + list(range(inputs_rank - 2)) + [inputs_rank - 1]) - if data_format == DATA_FORMAT_NCHW: - outputs = array_ops.transpose(outputs, [inputs_rank - 2, inputs_rank - 1] + list(range(inputs_rank - 2))) - - # if data_format == DATA_FORMAT_NHWC: - # inputs = array_ops.transpose(inputs, [0, reduction_axis] + list(range(1, reduction_axis))) - # inputs_nchw_shape = inputs.shape - # inputs = array_ops.reshape(inputs, [1, -1] + inputs_nchw_shape.as_list()[2:]) - # if inputs.shape.ndims != 4: - # # combine all the spatial dimensions into only two, e.g. [D, H, W] -> [DH, W] - # if inputs.shape.ndims > 4: - # inputs_ndims4_shape = inputs.shape.as_list()[:2] + [-1, inputs_nchw_shape.as_list()[-1]] - # else: - # inputs_ndims4_shape = inputs.shape.as_list()[:2] + [1, -1] - # inputs = array_ops.reshape(inputs, inputs_ndims4_shape) - # beta = array_ops.reshape(array_ops.tile(beta[None, :], [inputs_nchw_shape[0].value, 1]), [-1]) - # gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [inputs_nchw_shape[0].value, 1]), [-1]) - # - # outputs, _, _ = nn.fused_batch_norm( - # inputs, gamma, beta, epsilon=epsilon, - # data_format=DATA_FORMAT_NCHW, name='instancenorm') - # - # outputs = array_ops.reshape(outputs, inputs_nchw_shape) - # if data_format == DATA_FORMAT_NHWC: - # outputs = array_ops.transpose(outputs, [0] + list(range(2, inputs_rank)) + [1]) - - if activation_fn is not None: - outputs = activation_fn(outputs) - return utils.collect_named_outputs(outputs_collections, sc.name, outputs) diff --git a/video_prediction_tools/model_modules/video_prediction/losses.py b/video_prediction_tools/model_modules/video_prediction/losses.py index 4ba586d5a51e8ff065c98afebae26981b961f997..bdf4f651fce3896fd01fcf34092cfff9c68a029c 100644 --- a/video_prediction_tools/model_modules/video_prediction/losses.py +++ b/video_prediction_tools/model_modules/video_prediction/losses.py @@ -4,8 +4,6 @@ import tensorflow as tf -from model_modules.video_prediction.ops import sigmoid_kl_with_logits - def l1_loss(pred, target): return tf.reduce_mean(tf.abs(target - pred)) @@ -57,15 +55,3 @@ def gan_loss(logits, labels, gan_loss_type): raise ValueError('Unknown GAN loss type %s' % gan_loss_type) return loss - -def kl_loss(mu, log_sigma_sq, mu2=None, log_sigma2_sq=None): - if mu2 is None and log_sigma2_sq is None: - sigma_sq = tf.exp(log_sigma_sq) - return -0.5 * tf.reduce_mean(tf.reduce_sum(1 + log_sigma_sq - tf.square(mu) - sigma_sq, axis=-1)) - else: - mu1 = mu - log_sigma1_sq = log_sigma_sq - return tf.reduce_mean(tf.reduce_sum( - (log_sigma2_sq - log_sigma1_sq) / 2 - + (tf.exp(log_sigma1_sq) + tf.square(mu1 - mu2)) / (2 * tf.exp(log_sigma2_sq)) - - 1 / 2, axis=-1)) diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py index 
1bb913f18ccc7b9af2053b446584594bc6875c1b..10c5c72e932e38597731a61555f56803d639bde6 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py @@ -2,21 +2,12 @@ # # SPDX-License-Identifier: MIT -from .base_model import BaseVideoPredictionModel -from .base_model import VideoPredictionModel -from .non_trainable_model import NonTrainableVideoPredictionModel -from .non_trainable_model import GroundTruthVideoPredictionModel -from .non_trainable_model import RepeatVideoPredictionModel -from .savp_model import SAVPVideoPredictionModel -from .vanilla_vae_model import VanillaVAEVideoPredictionModel + from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel -from .mcnet_model import McNetVideoPredictionModel from .test_model import TestModelVideoPredictionModel from model_modules.model_architectures import known_models from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel from .weatherBench3DCNN import WeatherBenchModel -#from .vanilla_predrnnv2 import PredRNNv2VideoPredictionModel - def get_model_class(model): model_mappings = known_models() diff --git a/video_prediction_tools/model_modules/video_prediction/models/base_model.py b/video_prediction_tools/model_modules/video_prediction/models/base_model.py deleted file mode 100644 index 8ac05c98732eeca53c0356845faa3c8db59cb527..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/base_model.py +++ /dev/null @@ -1,880 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import functools -import itertools -import os -import re -from collections import OrderedDict -import numpy as np -import tensorflow as tf -print('tensorflow version: {}'.format(tf.__version__)) -print(tf.contrib.training.HParams) -from tensorflow.contrib.training import HParams -from tensorflow.python.util import nest -import model_modules.video_prediction as vp -from model_modules.video_prediction.utils import tf_utils -from model_modules.video_prediction.utils.tf_utils import compute_averaged_gradients, reduce_tensors, local_device_setter, \ - replace_read_ops, print_loss_info, transpose_batch_time, add_gif_summaries, add_scalar_summaries, \ - add_plot_and_scalar_summaries, add_summaries - - -class BaseVideoPredictionModel(object): - def __init__(self, mode='train', hparams_dict=None, hparams=None, - num_gpus=None, eval_num_samples=100, - eval_num_samples_for_diversity=10, eval_parallel_iterations=1): - """ - Base video prediction model. - - Trainable and non-trainable video prediction models can be derived - from this base class. - - Args: - mode: `'train'` or `'test'`. - hparams_dict: a dict of `name=value` pairs, where `name` must be - defined in `self.get_default_hparams()`. - hparams: a string of comma separated list of `name=value` pairs, - where `name` must be defined in `self.get_default_hparams()`. - These values overrides any values in hparams_dict (if any). 
- """ - if mode not in ('train', 'test'): - raise ValueError('mode must be train or test, but %s given' % mode) - self.mode = mode - cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0') - if cuda_visible_devices == '': - max_num_gpus = 0 - else: - max_num_gpus = len(cuda_visible_devices.split(',')) - if num_gpus is None: - num_gpus = max_num_gpus - elif num_gpus > max_num_gpus: - raise ValueError('num_gpus=%d is greater than the number of visible devices %d' % (num_gpus, max_num_gpus)) - self.num_gpus = num_gpus - self.eval_num_samples = eval_num_samples - self.eval_num_samples_for_diversity = eval_num_samples_for_diversity - self.eval_parallel_iterations = eval_parallel_iterations - self.hparams = self.parse_hparams(hparams_dict, hparams) - if self.hparams.context_frames == -1: - raise ValueError('Invalid context_frames %r. It might have to be ' - 'specified.' % self.hparams.context_frames) - if self.hparams.sequence_length == -1: - raise ValueError('Invalid sequence_length %r. It might have to be ' - 'specified.' % self.hparams.sequence_length) - - # should be overriden by descendant class if the model is stochastic - self.deterministic = True - - # member variables that should be set by `self.build_graph` - self.inputs = None - self.gen_images = None - self.outputs = None - self.metrics = None - self.eval_outputs = None - self.eval_metrics = None - self.accum_eval_metrics = None - self.saveable_variables = None - self.post_init_ops = None - # ML 2021-06-23: Do not hide global step in self.saveable_variables - self.global_step = None - - def get_default_hparams_dict(self): - """ - The keys of this dict define valid hyperparameters for instances of - this class. A class inheriting from this one should override this - method if it has a different set of hyperparameters. - - Returns: - A dict with the following hyperparameters. - - context_frames: the number of ground-truth frames to pass in at - start. Must be specified during instantiation. - sequence_length: the number of frames in the video sequence, - including the context frames, so this model predicts - `sequence_length - context_frames` future frames. Must be - specified during instantiation. - repeat: the number of repeat actions (if applicable). 
- opt_var :string: "0","1",..."n", or "all", the target variable to be optimized in the loss function, if "all" means optimize all the variables and channels - - """ - hparams = dict( - context_frames=-1, - sequence_length=-1, - repeat=1, - opt_var="0" - ) - return hparams - - def get_default_hparams(self): - return HParams(**self.get_default_hparams_dict()) - - def parse_hparams(self, hparams_dict, hparams): - parsed_hparams = self.get_default_hparams().override_from_dict(hparams_dict or {}) - if hparams: - if not isinstance(hparams, (list, tuple)): - hparams = [hparams] - for hparam in hparams: - parsed_hparams.parse(hparam) - return parsed_hparams - - def build_graph(self, inputs): - self.inputs = inputs - - def metrics_fn(self, inputs, outputs): - metrics = OrderedDict() - sequence_length = tf.shape(inputs['images'])[0] - context_frames = self.hparams.context_frames - future_length = sequence_length - context_frames - # target_images and pred_images include only the future frames - target_images = inputs['images'][-future_length:] - pred_images = outputs['gen_images'][-future_length:] - metric_fns = [ - ('psnr', vp.metrics.psnr), - ('mse', vp.metrics.mse), - ('ssim', vp.metrics.ssim), - #('lpips', vp.metrics.lpips), #bing : remove lpips metric course the url fetching issue - ] - for metric_name, metric_fn in metric_fns: - metrics[metric_name] = tf.reduce_mean(metric_fn(target_images, pred_images)) - return metrics - - def eval_outputs_and_metrics_fn(self, inputs, outputs, num_samples=None, - num_samples_for_diversity=None, parallel_iterations=None): - num_samples = num_samples or self.eval_num_samples - num_samples_for_diversity = num_samples_for_diversity or self.eval_num_samples_for_diversity - parallel_iterations = parallel_iterations or self.eval_parallel_iterations - - sequence_length, batch_size = inputs['images'].shape[:2].as_list() - if batch_size is None: - batch_size = tf.shape(inputs['images'])[1] - if sequence_length is None: - sequence_length = tf.shape(inputs['images'])[0] - context_frames = self.hparams.context_frames - future_length = sequence_length - context_frames - # the outputs include all the frames, whereas the metrics include only the future frames - eval_outputs = OrderedDict() - eval_metrics = OrderedDict() - metric_fns = [ - ('psnr', vp.metrics.psnr), - ('mse', vp.metrics.mse), - ('ssim', vp.metrics.ssim), - # ('lpips', vp.metrics.lpips), #bing - ] - # images and gen_images include all the frames - images = inputs['images'] - gen_images = outputs['gen_images'] - # target_images and pred_images include only the future frames - target_images = inputs['images'][-future_length:] - pred_images = outputs['gen_images'][-future_length:] - # ground truth is the same for deterministic and stochastic models - eval_outputs['eval_images'] = images - if self.deterministic: - for metric_name, metric_fn in metric_fns: - metric = metric_fn(target_images, pred_images) - eval_metrics['eval_%s/min' % metric_name] = metric - eval_metrics['eval_%s/avg' % metric_name] = metric - eval_metrics['eval_%s/max' % metric_name] = metric - eval_outputs['eval_gen_images'] = gen_images - else: - def where_axis1(cond, x, y): - return transpose_batch_time(tf.where(cond, transpose_batch_time(x), transpose_batch_time(y))) - - def sort_criterion(x): - return tf.reduce_mean(x, axis=0) - - def accum_gen_images_and_metrics_fn(a, unused): - with tf.variable_scope(self.generator_scope, reuse=True): - outputs_sample = self.generator_fn(inputs) - gen_images_sample = outputs_sample['gen_images'] - 
pred_images_sample = gen_images_sample[-future_length:] - # set the posisbly static shape since it might not have been inferred correctly - pred_images_sample = tf.reshape(pred_images_sample, tf.shape(a['eval_pred_images_last'])) - for name, metric_fn in metric_fns: - metric = metric_fn(target_images, pred_images_sample) # time, batch_size - cond_min = tf.less(sort_criterion(metric), sort_criterion(a['eval_%s/min' % name])) - cond_max = tf.greater(sort_criterion(metric), sort_criterion(a['eval_%s/max' % name])) - a['eval_%s/min' % name] = where_axis1(cond_min, metric, a['eval_%s/min' % name]) - a['eval_%s/sum' % name] = metric + a['eval_%s/sum' % name] - a['eval_%s/max' % name] = where_axis1(cond_max, metric, a['eval_%s/max' % name]) - a['eval_gen_images_%s/min' % name] = where_axis1(cond_min, gen_images_sample, a['eval_gen_images_%s/min' % name]) - a['eval_gen_images_%s/sum' % name] = gen_images_sample + a['eval_gen_images_%s/sum' % name] - a['eval_gen_images_%s/max' % name] = where_axis1(cond_max, gen_images_sample, a['eval_gen_images_%s/max' % name]) - #bing - # a['eval_diversity'] = tf.cond( - # tf.logical_and(tf.less(0, a['eval_sample_ind']), - # tf.less_equal(a['eval_sample_ind'], num_samples_for_diversity)), - # lambda: -vp.metrics.lpips(a['eval_pred_images_last'], pred_images_sample) + a['eval_diversity'], - # lambda: a['eval_diversity']) - a['eval_sample_ind'] = 1 + a['eval_sample_ind'] - a['eval_pred_images_last'] = pred_images_sample - return a - - initializer = {} - for name, _ in metric_fns: - initializer['eval_gen_images_%s/min' % name] = tf.zeros_like(gen_images) - initializer['eval_gen_images_%s/sum' % name] = tf.zeros_like(gen_images) - initializer['eval_gen_images_%s/max' % name] = tf.zeros_like(gen_images) - initializer['eval_%s/min' % name] = tf.fill([future_length, batch_size], float('inf')) - initializer['eval_%s/sum' % name] = tf.zeros([future_length, batch_size]) - initializer['eval_%s/max' % name] = tf.fill([future_length, batch_size], float('-inf')) - #initializer['eval_diversity'] = tf.zeros([future_length, batch_size]) - initializer['eval_sample_ind'] = tf.zeros((), dtype=tf.int32) - initializer['eval_pred_images_last'] = tf.zeros_like(pred_images) - - eval_outputs_and_metrics = tf.foldl( - accum_gen_images_and_metrics_fn, tf.zeros([num_samples, 0]), initializer=initializer, back_prop=False, - parallel_iterations=parallel_iterations) - - for name, _ in metric_fns: - eval_outputs['eval_gen_images_%s/min' % name] = eval_outputs_and_metrics['eval_gen_images_%s/min' % name] - eval_outputs['eval_gen_images_%s/avg' % name] = eval_outputs_and_metrics['eval_gen_images_%s/sum' % name] / float(num_samples) - eval_outputs['eval_gen_images_%s/max' % name] = eval_outputs_and_metrics['eval_gen_images_%s/max' % name] - eval_metrics['eval_%s/min' % name] = eval_outputs_and_metrics['eval_%s/min' % name] - eval_metrics['eval_%s/avg' % name] = eval_outputs_and_metrics['eval_%s/sum' % name] / float(num_samples) - eval_metrics['eval_%s/max' % name] = eval_outputs_and_metrics['eval_%s/max' % name] - #eval_metrics['eval_diversity'] = eval_outputs_and_metrics['eval_diversity'] / float(num_samples_for_diversity) - return eval_outputs, eval_metrics - - def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): - - method = BaseVideoPredictionModel.restore.__name__ - if checkpoints: - var_list = self.saveable_variables - # possibly restore from multiple checkpoints. useful if subset of weights - # (e.g. generator or discriminator) are on different checkpoints. 
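The stochastic evaluation branch deleted a few lines above folds over num_samples generator draws and keeps, per batch element, the whole time series of the best and worst sample (selected by its time-mean) together with a running sum for the average. A NumPy sketch of that accumulator, with the LPIPS-based diversity term left out; shapes and the random metric values are illustrative only:

    import numpy as np

    def accumulate_metric(sample_metrics):
        """Fold over per-sample metric arrays of shape (time, batch): keep, per
        batch element, the sample whose time-mean is lowest / highest, plus a
        running sum for the average (mirrors the deleted tf.foldl body)."""
        best = np.full_like(sample_metrics[0], np.inf)
        worst = np.full_like(sample_metrics[0], -np.inf)
        total = np.zeros_like(sample_metrics[0])
        for metric in sample_metrics:                        # one entry per stochastic sample
            cond_min = metric.mean(axis=0) < best.mean(axis=0)
            cond_max = metric.mean(axis=0) > worst.mean(axis=0)
            best = np.where(cond_min[None, :], metric, best)
            worst = np.where(cond_max[None, :], metric, worst)
            total += metric
        return {"min": best, "avg": total / len(sample_metrics), "max": worst}

    rng = np.random.default_rng(0)
    samples = [rng.random((8, 4)) for _ in range(5)]         # 5 samples, (future_length, batch)
    stats = accumulate_metric(samples)
    assert stats["min"].shape == stats["avg"].shape == (8, 4)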
- if not isinstance(checkpoints, (list, tuple)): - checkpoints = [checkpoints] - # automatically skip global_step if more than one checkpoint is provided - skip_global_step = len(checkpoints) > 1 - savers = [] - for checkpoint in checkpoints: - print("%{0}: Creating restore saver from checkpoint '{1}'".format(method, checkpoint)) - saver, _ = tf_utils.get_checkpoint_restore_saver( - checkpoint, var_list, skip_global_step=skip_global_step, - restore_to_checkpoint_mapping=restore_to_checkpoint_mapping) - savers.append(saver) - restore_op = [saver.saver_def.restore_op_name for saver in savers] - sess.run(restore_op) - return True - else: - return False - -class VideoPredictionModel(BaseVideoPredictionModel): - def __init__(self, - generator_fn, - discriminator_fn=None, - generator_scope='generator', - discriminator_scope='discriminator', - aggregate_nccl=False, - mode='train', - hparams_dict=None, - hparams=None, - **kwargs): - """ - Trainable video prediction model with CPU and multi-GPU support. - - If num_gpus <= 1, the devices for the ops in `self.build_graph` are - automatically chosen by TensorFlow (i.e. `tf.device` is not specified), - otherwise they are explicitly chosen. - - Args: - generator_fn: callable that takes in inputs and returns a dict of - tensors. - discriminator_fn: callable that takes in fake/real data (and - optionally conditioned on inputs) and returns a dict of - tensors. - hparams_dict: a dict of `name=value` pairs, where `name` must be - defined in `self.get_default_hparams()`. - hparams: a string of comma separated list of `name=value` pairs, - where `name` must be defined in `self.get_default_hparams()`. - These values overrides any values in hparams_dict (if any). - """ - super(VideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) - self.generator_fn = functools.partial(generator_fn, mode=self.mode, hparams=self.hparams) - self.discriminator_fn = functools.partial(discriminator_fn, mode=self.mode, hparams=self.hparams) if discriminator_fn else None - self.generator_scope = generator_scope - self.discriminator_scope = discriminator_scope - self.aggregate_nccl = aggregate_nccl - - if any(self.hparams.lr_boundaries): - global_step = tf.train.get_or_create_global_step() - lr_values = list(self.hparams.lr * 0.1 ** np.arange(len(self.hparams.lr_boundaries) + 1)) - self.learning_rate = tf.train.piecewise_constant(global_step, self.hparams.lr_boundaries, lr_values) - elif any(self.hparams.decay_steps): - lr, end_lr = self.hparams.lr, self.hparams.end_lr - start_step, end_step = self.hparams.decay_steps - if start_step == end_step: - schedule = tf.cond(tf.less(tf.train.get_or_create_global_step(), start_step), - lambda: 0.0, lambda: 1.0) - else: - step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step) - schedule = tf.to_float(step - start_step) / tf.to_float(end_step - start_step) - self.learning_rate = lr + (end_lr - lr) * schedule - else: - self.learning_rate = self.hparams.lr - - if self.hparams.kl_weight: - if self.hparams.kl_anneal == 'none': - self.kl_weight = tf.constant(self.hparams.kl_weight, tf.float32) - elif self.hparams.kl_anneal == 'sigmoid': - k = self.hparams.kl_anneal_k - if k == -1.0: - raise ValueError('Invalid kl_anneal_k %d when kl_anneal is sigmoid.' 
% k) - iter_num = tf.train.get_or_create_global_step() - self.kl_weight = self.hparams.kl_weight / (1 + k * tf.exp(-tf.to_float(iter_num) / k)) - elif self.hparams.kl_anneal == 'linear': - start_step, end_step = self.hparams.kl_anneal_steps - step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step) - self.kl_weight = self.hparams.kl_weight * tf.to_float(step - start_step) / tf.to_float(end_step - start_step) - else: - raise NotImplementedError - else: - self.kl_weight = None - - # member variables that should be set by `self.build_graph` - # (in addition to the ones in the base class) - self.gen_images_enc = None - self.g_losses = None - self.d_losses = None - self.g_loss = None - self.d_loss = None - self.g_vars = None - self.d_vars = None - self.train_op = None - self.summary_op = None - self.image_summary_op = None - self.eval_summary_op = None - self.accum_eval_summary_op = None - self.accum_eval_metrics_reset_op = None - - def get_default_hparams_dict(self): - """ - The keys of this dict define valid hyperparameters for instances of - this class. A class inheriting from this one should override this - method if it has a different set of hyperparameters. - - Returns: - A dict with the following hyperparameters. - - batch_size: batch size for training. - lr: learning rate. if decay steps is non-zero, this is the - learning rate for steps <= decay_step. - end_lr: learning rate for steps >= end_decay_step if decay_steps - is non-zero, ignored otherwise. - decay_steps: (decay_step, end_decay_step) tuple. - max_steps: number of training steps. - beta1: momentum term of Adam. - beta2: momentum term of Adam. - context_frames: the number of ground-truth frames to pass in at - start. Must be specified during instantiation. - sequence_length: the number of frames in the video sequence, - including the context frames, so this model predicts - `sequence_length - context_frames` future frames. Must be - specified during instantiation. - """ - default_hparams = super(VideoPredictionModel, self).get_default_hparams_dict() - hparams = dict( - batch_size=16, - lr=0.001, - end_lr=0.0, - decay_steps=(200000, 300000), - lr_boundaries=(0,), - max_epochs=35, - beta1=0.9, - beta2=0.999, - context_frames=-1, - sequence_length=-1, - clip_length=10, - l1_weight=0.0, - l2_weight=1.0, - vgg_cdist_weight=0.0, - feature_l2_weight=0.0, - ae_l2_weight=0.0, - state_weight=0.0, - tv_weight=0.0, - image_sn_gan_weight=0.0, - image_sn_vae_gan_weight=0.0, - images_sn_gan_weight=0.0, - images_sn_vae_gan_weight=0.0, - video_sn_gan_weight=0.0, - video_sn_vae_gan_weight=0.0, - gan_feature_l2_weight=0.0, - gan_feature_cdist_weight=0.0, - vae_gan_feature_l2_weight=0.0, - vae_gan_feature_cdist_weight=0.0, - gan_loss_type='LSGAN', - joint_gan_optimization=False, - kl_weight=0.0, - kl_anneal='linear', - kl_anneal_k=-1.0, - kl_anneal_steps=(50000, 100000), - z_l1_weight=0.0, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - def tower_fn(self, inputs): - """ - This method doesn't have side-effects. `inputs`, `targets`, and - `outputs` are batch-major but internal calculations use time-major - tensors. 
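tower_fn's contract above hinges on transpose_batch_time (imported from tf_utils at the top of the deleted file): the public tensors are batch-major, while everything inside the tower is time-major. A NumPy stand-in for that helper, assuming it simply swaps the first two axes and passes lower-rank values through unchanged:

    import numpy as np

    def transpose_batch_time(x):
        """(batch, time, ...) <-> (time, batch, ...): swap the two leading axes
        and leave everything else untouched."""
        x = np.asarray(x)
        if x.ndim < 2:                 # scalars / vectors pass through (assumption)
            return x
        return np.swapaxes(x, 0, 1)

    batch_major = np.zeros((16, 20, 64, 64, 3))   # (batch, sequence_length, H, W, C)
    time_major = transpose_batch_time(batch_major)
    assert time_major.shape == (20, 16, 64, 64, 3)
    assert transpose_batch_time(time_major).shape == batch_major.shape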
- """ - # batch-major to time-major - inputs = nest.map_structure(transpose_batch_time, inputs) - - with tf.variable_scope(self.generator_scope): - gen_outputs = self.generator_fn(inputs) - - if self.discriminator_fn: - with tf.variable_scope(self.discriminator_scope) as discrim_scope: - discrim_outputs = self.discriminator_fn(inputs, gen_outputs) - # post-update discriminator tensors (i.e. after the discriminator weights have been updated) - with tf.variable_scope(discrim_scope, reuse=True): - discrim_outputs_post = self.discriminator_fn(inputs, gen_outputs) - else: - discrim_outputs = {} - discrim_outputs_post = {} - - outputs = [gen_outputs, discrim_outputs] - total_num_outputs = sum([len(output) for output in outputs]) - outputs = OrderedDict(itertools.chain(*[output.items() for output in outputs])) - assert len(outputs) == total_num_outputs # ensure no output is lost because of repeated keys - - if isinstance(self.learning_rate, tf.Tensor): - outputs['learning_rate'] = self.learning_rate - if isinstance(self.kl_weight, tf.Tensor): - outputs['kl_weight'] = self.kl_weight - - if self.mode == 'train': - with tf.name_scope("discriminator_loss"): - d_losses = self.discriminator_loss_fn(inputs, outputs) - print_loss_info(d_losses, inputs, outputs) - with tf.name_scope("generator_loss"): - g_losses = self.generator_loss_fn(inputs, outputs) - print_loss_info(g_losses, inputs, outputs) - if discrim_outputs_post: - outputs_post = OrderedDict(itertools.chain(gen_outputs.items(), discrim_outputs_post.items())) - # generator losses after the discriminator weights have been updated - g_losses_post = self.generator_loss_fn(inputs, outputs_post) - else: - g_losses_post = g_losses - else: - d_losses = {} - g_losses = {} - g_losses_post = {} - with tf.name_scope("metrics"): - metrics = self.metrics_fn(inputs, outputs) - with tf.name_scope("eval_outputs_and_metrics"): - eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs) - - # time-major to batch-major - outputs_tuple = (outputs, eval_outputs) - outputs_tuple = nest.map_structure(transpose_batch_time, outputs_tuple) - losses_tuple = (d_losses, g_losses, g_losses_post) - losses_tuple = nest.map_structure(tf.convert_to_tensor, losses_tuple) - loss_tuple = tuple(tf.accumulate_n([loss * weight for loss, weight in losses.values()]) - if losses else tf.zeros(()) for losses in losses_tuple) - metrics_tuple = (metrics, eval_metrics) - metrics_tuple = nest.map_structure(transpose_batch_time, metrics_tuple) - return outputs_tuple, losses_tuple, loss_tuple, metrics_tuple - - def build_graph(self, inputs,finetune=False): - BaseVideoPredictionModel.build_graph(self, inputs) - - global_step = tf.train.get_or_create_global_step() - # ML 2021-06-23: Do not hide global step in self.saveable_variables - self.global_step = global_step - # Capture the variables created from here until the train_op for the - # saveable_variables. Note that if variables are being reused (e.g. - # they were created by a previously built model), those variables won't - # be captured here. 
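Inside tower_fn, every entry of the loss dictionaries is a (loss, weight) pair, and the scalar objective handed to the optimizer is their weighted sum (the tf.accumulate_n call above). The same convention is reused later by generator_loss_fn and discriminator_loss_fn. A small illustration of that reduction with made-up loss names and values:

    def total_loss(named_losses):
        """Collapse a mapping of name -> (loss, weight) into one scalar,
        as tower_fn does with tf.accumulate_n([loss * weight, ...])."""
        if not named_losses:
            return 0.0
        return sum(loss * weight for loss, weight in named_losses.values())

    g_losses = {
        "gen_l2_loss": (0.031, 1.0),             # hypothetical values
        "gen_image_sn_gan_loss": (0.52, 0.1),
    }
    print(round(total_loss(g_losses), 6))        # 0.083 = 0.031 * 1.0 + 0.52 * 0.1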
- original_global_variables = tf.global_variables() - - if self.num_gpus <= 1: # cpu or 1 gpu - outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(self.inputs) - self.outputs, self.eval_outputs = outputs_tuple - self.d_losses, self.g_losses, g_losses_post = losses_tuple - self.d_loss, self.g_loss, g_loss_post = loss_tuple - self.metrics, self.eval_metrics = metrics_tuple - - self.d_vars = tf.trainable_variables(self.discriminator_scope) - self.g_vars = tf.trainable_variables(self.generator_scope) - g_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) - d_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) - - - if self.mode == 'train' and (self.d_losses or self.g_losses): - with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): - if self.d_losses: - with tf.name_scope('d_compute_gradients'): - d_gradvars = d_optimizer.compute_gradients(self.d_loss, var_list=self.d_vars) - with tf.name_scope('d_apply_gradients'): - d_train_op = d_optimizer.apply_gradients(d_gradvars) - - else: - d_train_op = tf.no_op() - with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []): - if g_losses_post: - if not self.hparams.joint_gan_optimization: - replace_read_ops(g_loss_post, self.d_vars) - with tf.name_scope('g_compute_gradients'): - g_gradvars = g_optimizer.compute_gradients(g_loss_post, var_list=self.g_vars) - with tf.name_scope('g_apply_gradients'): - g_train_op = g_optimizer.apply_gradients(g_gradvars) - else: - g_train_op = tf.no_op() - with tf.control_dependencies([g_train_op]): - train_op = tf.assign_add(global_step, 1) - self.train_op = train_op - else: - self.train_op = None - - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [global_step] + global_variables - self.post_init_ops = [] - else: - if tf.get_variable_scope().name: - # This is because how variable scope works with empty strings when it's not the root scope, causing - # repeated forward slashes. 
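The single-GPU training branch above strings the updates together with control dependencies: the discriminator step runs first (a tf.no_op when there are no discriminator losses), the generator step runs after it, and only then is the global step incremented. A framework-free sketch of that ordering with hypothetical update callables, ignoring the joint_gan_optimization case:

    def train_step(d_update, g_update, state):
        """One optimisation step in the same order as the control-dependency
        chain above: discriminator, then generator, then the step counter."""
        if d_update is not None:
            d_update(state)              # d_optimizer.apply_gradients(...), else a no-op
        if g_update is not None:
            g_update(state)              # g_optimizer.apply_gradients(...), else a no-op
        state["global_step"] += 1        # tf.assign_add(global_step, 1)
        return state

    state = {"global_step": 0}
    train_step(lambda s: None, lambda s: None, state)
    assert state["global_step"] == 1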
- raise NotImplementedError('Unable to handle multi-gpu model created within a non-root variable scope.') - - tower_inputs = [OrderedDict() for _ in range(self.num_gpus)] - for name, input in self.inputs.items(): - input_splits = tf.split(input, self.num_gpus) # assumes batch_size is divisible by num_gpus - for i in range(self.num_gpus): - tower_inputs[i][name] = input_splits[i] - - tower_outputs_tuple = [] - tower_d_losses = [] - tower_g_losses = [] - tower_g_losses_post = [] - tower_d_loss = [] - tower_g_loss = [] - tower_g_loss_post = [] - tower_metrics_tuple = [] - for i in range(self.num_gpus): - worker_device = '/gpu:%d' % i - if self.aggregate_nccl: - scope_name = '' if i == 0 else 'v%d' % i - scope_reuse = False - device_setter = worker_device - else: - scope_name = '' - scope_reuse = i > 0 - device_setter = local_device_setter(worker_device=worker_device) - with tf.variable_scope(scope_name, reuse=scope_reuse): - with tf.device(device_setter): - outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(tower_inputs[i]) - tower_outputs_tuple.append(outputs_tuple) - d_losses, g_losses, g_losses_post = losses_tuple - tower_d_losses.append(d_losses) - tower_g_losses.append(g_losses) - tower_g_losses_post.append(g_losses_post) - d_loss, g_loss, g_loss_post = loss_tuple - tower_d_loss.append(d_loss) - tower_g_loss.append(g_loss) - tower_g_loss_post.append(g_loss_post) - tower_metrics_tuple.append(metrics_tuple) - self.d_vars = tf.trainable_variables(self.discriminator_scope) - self.g_vars = tf.trainable_variables(self.generator_scope) - - if self.aggregate_nccl: - scope_replica = lambda scope, i: ('' if i == 0 else 'v%d/' % i) + scope - tower_d_vars = [tf.trainable_variables( - scope_replica(self.discriminator_scope, i)) for i in range(self.num_gpus)] - tower_g_vars = [tf.trainable_variables( - scope_replica(self.generator_scope, i)) for i in range(self.num_gpus)] - assert self.d_vars == tower_d_vars[0] - assert self.g_vars == tower_g_vars[0] - tower_d_optimizer = [tf.train.AdamOptimizer( - self.learning_rate, self.hparams.beta1, self.hparams.beta2) for _ in range(self.num_gpus)] - tower_g_optimizer = [tf.train.AdamOptimizer( - self.learning_rate, self.hparams.beta1, self.hparams.beta2) for _ in range(self.num_gpus)] - - if self.mode == 'train' and (any(tower_d_losses) or any(tower_g_losses)): - tower_d_gradvars = [] - tower_g_gradvars = [] - tower_d_train_op = [] - tower_g_train_op = [] - with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): - if any(tower_d_losses): - for i in range(self.num_gpus): - with tf.device('/gpu:%d' % i): - with tf.name_scope(scope_replica('d_compute_gradients', i)): - d_gradvars = tower_d_optimizer[i].compute_gradients( - tower_d_loss[i], var_list=tower_d_vars[i]) - tower_d_gradvars.append(d_gradvars) - - all_d_grads, all_d_vars = tf_utils.split_grad_list(tower_d_gradvars) - all_d_grads = tf_utils.allreduce_grads(all_d_grads, average=True) - tower_d_gradvars = tf_utils.merge_grad_list(all_d_grads, all_d_vars) - - for i in range(self.num_gpus): - with tf.device('/gpu:%d' % i): - with tf.name_scope(scope_replica('d_apply_gradients', i)): - d_train_op = tower_d_optimizer[i].apply_gradients(tower_d_gradvars[i]) - tower_d_train_op.append(d_train_op) - d_train_op = tf.group(*tower_d_train_op) - else: - d_train_op = tf.no_op() - with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []): - if any(tower_g_losses_post): - for i in range(self.num_gpus): - with tf.device('/gpu:%d' % i): - if not 
self.hparams.joint_gan_optimization: - replace_read_ops(tower_g_loss_post[i], tower_d_vars[i]) - - with tf.name_scope(scope_replica('g_compute_gradients', i)): - g_gradvars = tower_g_optimizer[i].compute_gradients( - tower_g_loss_post[i], var_list=tower_g_vars[i]) - tower_g_gradvars.append(g_gradvars) - - all_g_grads, all_g_vars = tf_utils.split_grad_list(tower_g_gradvars) - all_g_grads = tf_utils.allreduce_grads(all_g_grads, average=True) - tower_g_gradvars = tf_utils.merge_grad_list(all_g_grads, all_g_vars) - - for i, g_gradvars in enumerate(tower_g_gradvars): - with tf.device('/gpu:%d' % i): - with tf.name_scope(scope_replica('g_apply_gradients', i)): - g_train_op = tower_g_optimizer[i].apply_gradients(g_gradvars) - tower_g_train_op.append(g_train_op) - g_train_op = tf.group(*tower_g_train_op) - else: - g_train_op = tf.no_op() - with tf.control_dependencies([g_train_op]): - train_op = tf.assign_add(global_step, 1) - self.train_op = train_op - else: - self.train_op = None - - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - tower_saveable_vars = [[] for _ in range(self.num_gpus)] - for var in global_variables: - m = re.match('v(\d+)/.*', var.name) - i = int(m.group(1)) if m else 0 - tower_saveable_vars[i].append(var) - self.saveable_variables = [global_step] + tower_saveable_vars[0] - - post_init_ops = [] - for i, saveable_vars in enumerate(tower_saveable_vars[1:], 1): - assert len(saveable_vars) == len(tower_saveable_vars[0]) - for var, var0 in zip(saveable_vars, tower_saveable_vars[0]): - assert var.name == 'v%d/%s' % (i, var0.name) - post_init_ops.append(var.assign(var0.read_value())) - self.post_init_ops = post_init_ops - else: # not self.aggregate_nccl (i.e. aggregation in cpu) - g_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) - d_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) - - if self.mode == 'train' and (any(tower_d_losses) or any(tower_g_losses)): - with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): - if any(tower_d_losses): - with tf.name_scope('d_compute_gradients'): - d_gradvars = compute_averaged_gradients( - d_optimizer, tower_d_loss, var_list=self.d_vars) - with tf.name_scope('d_apply_gradients'): - d_train_op = d_optimizer.apply_gradients(d_gradvars) - else: - d_train_op = tf.no_op() - with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []): - if any(tower_g_losses_post): - for g_loss_post in tower_g_loss_post: - if not self.hparams.joint_gan_optimization: - replace_read_ops(g_loss_post, self.d_vars) - with tf.name_scope('g_compute_gradients'): - g_gradvars = compute_averaged_gradients( - g_optimizer, tower_g_loss_post, var_list=self.g_vars) - with tf.name_scope('g_apply_gradients'): - g_train_op = g_optimizer.apply_gradients(g_gradvars) - else: - g_train_op = tf.no_op() - with tf.control_dependencies([g_train_op]): - train_op = tf.assign_add(global_step, 1) - self.train_op = train_op - else: - self.train_op = None - - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [global_step] + global_variables - self.post_init_ops = [] - - # Device that runs the ops to apply global gradient updates. 
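Both multi-GPU paths shown above reduce to the same idea before gradients are applied: average the per-tower gradients variable by variable (allreduce_grads(average=True) on the NCCL path, compute_averaged_gradients on the CPU path). A NumPy sketch of that averaging step, ignoring device placement and the (gradient, variable) pairing used by the optimizers:

    import numpy as np

    def average_tower_gradients(tower_grads):
        """Average per-tower gradients position by position;
        tower_grads[i][j] is tower i's gradient for variable j."""
        return [np.mean(np.stack(grads_for_var), axis=0)
                for grads_for_var in zip(*tower_grads)]

    tower_grads = [
        [np.array([0.2, -0.4]), np.array([[1.0]])],   # gradients computed on gpu:0
        [np.array([0.4, -0.2]), np.array([[3.0]])],   # gradients computed on gpu:1
    ]
    averaged = average_tower_gradients(tower_grads)
    print(averaged[0], averaged[1])                   # [ 0.3 -0.3] [[2.]]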
- consolidation_device = '/cpu:0' - with tf.device(consolidation_device): - with tf.name_scope('consolidation'): - self.outputs, self.eval_outputs = reduce_tensors(tower_outputs_tuple) - self.d_losses = reduce_tensors(tower_d_losses, shallow=True) - self.g_losses = reduce_tensors(tower_g_losses, shallow=True) - self.metrics, self.eval_metrics = reduce_tensors(tower_metrics_tuple) - self.d_loss = reduce_tensors(tower_d_loss) - self.g_loss = reduce_tensors(tower_g_loss) - - original_local_variables = set(tf.local_variables()) - self.accum_eval_metrics = OrderedDict() - for name, eval_metric in self.eval_metrics.items(): - _, self.accum_eval_metrics['accum_' + name] = tf.metrics.mean_tensor(eval_metric) - local_variables = set(tf.local_variables()) - original_local_variables - self.accum_eval_metrics_reset_op = tf.group([tf.assign(v, tf.zeros_like(v)) for v in local_variables]) - - original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - add_summaries(self.inputs) - add_summaries(self.outputs) - add_scalar_summaries(self.d_losses) - add_scalar_summaries(self.g_losses) - add_scalar_summaries(self.metrics) - if self.d_losses: - add_scalar_summaries({'d_loss': self.d_loss}) - if self.g_losses: - add_scalar_summaries({'g_loss': self.g_loss}) - if self.d_losses and self.g_losses: - add_scalar_summaries({'loss': self.d_loss + self.g_loss}) - summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries - # split summaries into non-image summaries and image summaries - self.summary_op = tf.summary.merge(list(summaries - set(tf.get_collection(tf_utils.IMAGE_SUMMARIES)))) - self.image_summary_op = tf.summary.merge(list(summaries & set(tf.get_collection(tf_utils.IMAGE_SUMMARIES)))) - - original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - add_gif_summaries(self.eval_outputs) - add_plot_and_scalar_summaries( - {name: tf.reduce_mean(metric, axis=0) for name, metric in self.eval_metrics.items()}, - x_offset=self.hparams.context_frames + 1) - summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries - self.eval_summary_op = tf.summary.merge(list(summaries)) - - original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - add_plot_and_scalar_summaries( - {name: tf.reduce_mean(metric, axis=0) for name, metric in self.accum_eval_metrics.items()}, - x_offset=self.hparams.context_frames + 1) - summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries - self.accum_eval_summary_op = tf.summary.merge(list(summaries)) - - def generator_loss_fn(self, inputs, outputs): - hparams = self.hparams - opt_var = self.hparams.opt_var - gen_losses = OrderedDict() - if opt_var == "all": - gen_images = outputs.get("gen_images_enc", outputs["gen_images"]) - target_images = inputs["images"][1:] - print("The model is optimized on all variables/channels in the loss function") - elif opt_var != "all" and isinstance(opt_var,str): - opt_var = int(opt_var) - print("The model is optimized on the {} variable/channel in the loss function".format(opt_var)) - gen_images = outputs.get("gen_images_enc", outputs["gen_images"])[:, :, :, :, opt_var:opt_var+1] - target_images = inputs["images"][1:][:, :, :, :, opt_var:opt_var+1] - else: - raise ValueError("The opt_var in the hyper-parameters setting should be int or 'all'") - - - if hparams.l1_weight: - gen_l1_loss = vp.losses.l1_loss(gen_images, target_images) - gen_losses["gen_l1_loss"] = (gen_l1_loss, hparams.l1_weight) - if hparams.l2_weight: - gen_l2_loss = vp.losses.l2_loss(gen_images, 
target_images) - gen_losses["gen_l2_loss"] = (gen_l2_loss, hparams.l2_weight) - if hparams.vgg_cdist_weight: - gen_vgg_cdist_loss = vp.metrics.vgg_cosine_distance(gen_images, target_images) - gen_losses['gen_vgg_cdist_loss'] = (gen_vgg_cdist_loss, hparams.vgg_cdist_weight) - if hparams.feature_l2_weight: - gen_features = outputs.get('gen_features_enc', outputs['gen_features']) - target_features = outputs['features'][1:] - gen_feature_l2_loss = vp.losses.l2_loss(gen_features, target_features) - gen_losses["gen_feature_l2_loss"] = (gen_feature_l2_loss, hparams.feature_l2_weight) - if hparams.ae_l2_weight: - gen_images_dec = outputs.get('gen_images_dec_enc', outputs['gen_images_dec']) # they both should be the same - target_images = inputs['images'] - gen_ae_l2_loss = vp.losses.l2_loss(gen_images_dec, target_images) - gen_losses["gen_ae_l2_loss"] = (gen_ae_l2_loss, hparams.ae_l2_weight) - if hparams.state_weight: - gen_states = outputs.get('gen_states_enc', outputs['gen_states']) - target_states = inputs['states'][1:] - gen_state_loss = vp.losses.l2_loss(gen_states, target_states) - gen_losses["gen_state_loss"] = (gen_state_loss, hparams.state_weight) - if hparams.tv_weight: - gen_flows = outputs.get('gen_flows_enc', outputs['gen_flows']) - flow_diff1 = gen_flows[..., 1:, :, :, :] - gen_flows[..., :-1, :, :, :] - flow_diff2 = gen_flows[..., :, 1:, :, :] - gen_flows[..., :, :-1, :, :] - # sum over the multiple transformations but take the mean for the other dimensions - gen_tv_loss = (tf.reduce_mean(tf.reduce_sum(tf.abs(flow_diff1), axis=(-2, -1))) + - tf.reduce_mean(tf.reduce_sum(tf.abs(flow_diff2), axis=(-2, -1)))) - gen_losses['gen_tv_loss'] = (gen_tv_loss, hparams.tv_weight) - gan_weights = {'_image_sn': hparams.image_sn_gan_weight, - '_images_sn': hparams.images_sn_gan_weight, - '_video_sn': hparams.video_sn_gan_weight} - for infix, gan_weight in gan_weights.items(): - if gan_weight: - gen_gan_loss = vp.losses.gan_loss(outputs['discrim%s_logits_fake' % infix], 1.0, hparams.gan_loss_type) - gen_losses["gen%s_gan_loss" % infix] = (gen_gan_loss, gan_weight) - if gan_weight and (hparams.gan_feature_l2_weight or hparams.gan_feature_cdist_weight): - i_feature = 0 - discrim_features_fake = [] - discrim_features_real = [] - while True: - discrim_feature_fake = outputs.get('discrim%s_feature%d_fake' % (infix, i_feature)) - discrim_feature_real = outputs.get('discrim%s_feature%d_real' % (infix, i_feature)) - if discrim_feature_fake is None or discrim_feature_real is None: - break - discrim_features_fake.append(discrim_feature_fake) - discrim_features_real.append(discrim_feature_real) - i_feature += 1 - if hparams.gan_feature_l2_weight: - gen_gan_feature_l2_loss = sum([vp.losses.l2_loss(discrim_feature_fake, discrim_feature_real) - for discrim_feature_fake, discrim_feature_real in zip(discrim_features_fake, discrim_features_real)]) - gen_losses["gen%s_gan_feature_l2_loss" % infix] = (gen_gan_feature_l2_loss, hparams.gan_feature_l2_weight) - if hparams.gan_feature_cdist_weight: - gen_gan_feature_cdist_loss = sum([vp.losses.cosine_distance(discrim_feature_fake, discrim_feature_real) - for discrim_feature_fake, discrim_feature_real in zip(discrim_features_fake, discrim_features_real)]) - gen_losses["gen%s_gan_feature_cdist_loss" % infix] = (gen_gan_feature_cdist_loss, hparams.gan_feature_cdist_weight) - vae_gan_weights = {'_image_sn': hparams.image_sn_vae_gan_weight, - '_images_sn': hparams.images_sn_vae_gan_weight, - '_video_sn': hparams.video_sn_vae_gan_weight} - for infix, vae_gan_weight in 
vae_gan_weights.items(): - if vae_gan_weight: - gen_vae_gan_loss = vp.losses.gan_loss(outputs['discrim%s_logits_enc_fake' % infix], 1.0, hparams.gan_loss_type) - gen_losses["gen%s_vae_gan_loss" % infix] = (gen_vae_gan_loss, vae_gan_weight) - if vae_gan_weight and (hparams.vae_gan_feature_l2_weight or hparams.vae_gan_feature_cdist_weight): - i_feature = 0 - discrim_features_enc_fake = [] - discrim_features_enc_real = [] - while True: - discrim_feature_enc_fake = outputs.get('discrim%s_feature%d_enc_fake' % (infix, i_feature)) - discrim_feature_enc_real = outputs.get('discrim%s_feature%d_enc_real' % (infix, i_feature)) - if discrim_feature_enc_fake is None or discrim_feature_enc_real is None: - break - discrim_features_enc_fake.append(discrim_feature_enc_fake) - discrim_features_enc_real.append(discrim_feature_enc_real) - i_feature += 1 - if hparams.vae_gan_feature_l2_weight: - gen_vae_gan_feature_l2_loss = sum([vp.losses.l2_loss(discrim_feature_enc_fake, discrim_feature_enc_real) - for discrim_feature_enc_fake, discrim_feature_enc_real in zip(discrim_features_enc_fake, discrim_features_enc_real)]) - gen_losses["gen%s_vae_gan_feature_l2_loss" % infix] = (gen_vae_gan_feature_l2_loss, hparams.vae_gan_feature_l2_weight) - if hparams.vae_gan_feature_cdist_weight: - gen_vae_gan_feature_cdist_loss = sum([vp.losses.cosine_distance(discrim_feature_enc_fake, discrim_feature_enc_real) - for discrim_feature_enc_fake, discrim_feature_enc_real in zip(discrim_features_enc_fake, discrim_features_enc_real)]) - gen_losses["gen%s_vae_gan_feature_cdist_loss" % infix] = (gen_vae_gan_feature_cdist_loss, hparams.vae_gan_feature_cdist_weight) - if hparams.kl_weight: - gen_kl_loss = vp.losses.kl_loss(outputs['zs_mu_enc'], outputs['zs_log_sigma_sq_enc'], - outputs.get('zs_mu_prior'), outputs.get('zs_log_sigma_sq_prior')) - gen_losses["gen_kl_loss"] = (gen_kl_loss, self.kl_weight) # possibly annealed kl_weight - return gen_losses - - def discriminator_loss_fn(self, inputs, outputs): - hparams = self.hparams - discrim_losses = OrderedDict() - gan_weights = {'_image_sn': hparams.image_sn_gan_weight, - '_images_sn': hparams.images_sn_gan_weight, - '_video_sn': hparams.video_sn_gan_weight} - for infix, gan_weight in gan_weights.items(): - if gan_weight: - discrim_gan_loss_real = vp.losses.gan_loss(outputs['discrim%s_logits_real' % infix], 1.0, hparams.gan_loss_type) - discrim_gan_loss_fake = vp.losses.gan_loss(outputs['discrim%s_logits_fake' % infix], 0.0, hparams.gan_loss_type) - discrim_gan_loss = discrim_gan_loss_real + discrim_gan_loss_fake - discrim_losses["discrim%s_gan_loss" % infix] = (discrim_gan_loss, gan_weight) - vae_gan_weights = {'_image_sn': hparams.image_sn_vae_gan_weight, - '_images_sn': hparams.images_sn_vae_gan_weight, - '_video_sn': hparams.video_sn_vae_gan_weight} - for infix, vae_gan_weight in vae_gan_weights.items(): - if vae_gan_weight: - discrim_vae_gan_loss_real = vp.losses.gan_loss(outputs['discrim%s_logits_enc_real' % infix], 1.0, hparams.gan_loss_type) - discrim_vae_gan_loss_fake = vp.losses.gan_loss(outputs['discrim%s_logits_enc_fake' % infix], 0.0, hparams.gan_loss_type) - discrim_vae_gan_loss = discrim_vae_gan_loss_real + discrim_vae_gan_loss_fake - discrim_losses["discrim%s_vae_gan_loss" % infix] = (discrim_vae_gan_loss, vae_gan_weight) - return discrim_losses diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py index 
c4035963dc37b840e55cdacb2c7288b10e458cfd..5bad1430f9c7499965899fb948bd46bb8e7686c9 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py @@ -3,35 +3,27 @@ # SPDX-License-Identifier: MIT __email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong,Yanji" +__author__ = "Bing Gong" __date__ = "2021-04-13" -from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames import tensorflow as tf -from model_modules.video_prediction.layers import layer_def as ld -from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames +from model_modules.video_prediction.layers.layer_def import batch_norm +from model_modules.video_prediction.models.vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel as convLSTM from .our_base_model import BaseModels class ConvLstmGANVideoPredictionModel(BaseModels): - def __init__(self, hparams_dict=None, mode='train', **kwargs): - """ - This is class for building convLSTM_GAN architecture by using updated hparameters - args: - mode :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model - hparams_dict: dict, the dictionary contains the hparaemters names and values - """ - super().__init__(hparams_dict) - self.hparams = self.get_hparams() - self.mode = mode + + def __init__(self, hparams_dict_config=None, mode='train'): + super().__init__(hparams_dict_config, mode) self.bd1 = batch_norm(name = "bd1") self.bd2 = batch_norm(name = "bd2") + self.bd3 = batch_norm(name = "dis3") - def get_hparams(self): + def parse_hparams(self, hparams): """ - obtain the hparams from the dict to the class variables + Obtain the hparams from the dict to the class variables """ - method = BaseModels.get_hparams.__name__ - try: self.context_frames = self.hparams.context_frames self.max_epochs = self.hparams.max_epochs @@ -41,168 +33,153 @@ class ConvLstmGANVideoPredictionModel(BaseModels): self.recon_weight = self.hparams.recon_weight self.learning_rate = self.hparams.lr self.sequence_length = self.hparams.sequence_length + self.opt_var = self.hparams.opt_var self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.ngf = self.hparams.ngf self.ndf = self.hparams.ndf - except Exception as error: - print("Method %{}: error: {}".format(method,error)) - raise("Method %{}: the hparameter dictionary must include parameters above".format(method)) + print("error: {}".format(error)) + raise ValueError("Method %{}: the hyper-parameter dictionary must include parameters above") def build_graph(self, x: tf.Tensor): - self.is_build_graph = False self.inputs = x + self.global_step = tf.train.get_or_create_global_step() original_global_variables = tf.global_variables() - # Architecture - self.build_model() - # define loss function - self.total_loss = (1-self.recon_weight) * self.G_loss + self.recon_weight*self.recon_loss - self.D_loss = (1-self.recon_weight) * self.D_loss + + #Build graph + x_hat = self.build_model(x) + + #Get losses (reconstruciton loss, total loss and descriminator loss) + self.total_loss = self.get_loss(x, x_hat) + + #Define optimizer + self.train_op = self.optimizer(self.total_loss) + + #Save to outputs + self.outputs["gen_images"] = x_hat + self.outputs["total_loss"] = self.total_loss + # Summary op 
+ sum_dict = {"total_loss": self.total_loss, + "D_loss": self.D_loss, + "G_loss": self.G_loss, + "D_loss_fake": self.D_loss_fake, + "D_loss_real": self.D_loss_real, + "recon_loss": self.recon_loss} + + self.summary_op = self.summary(**sum_dict) + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self.is_build_graph = True + return self.is_build_graph + + def get_loss(self, x: tf.Tensor, x_hat: tf.Tensor): + """ + We use the loss from vanilla convolutional LSTM as reconstruction loss + """ + self.G_loss = self.get_gen_loss() + self.D_loss = self.get_disc_loss() + self._get_vars() + #self.recon_loss = self.get_loss(self, x, x_hat) #use the loss from vanilla convLSTM + + if self.opt_var == "all": + x = x[:, self.context_frames:, :, :, :] + print("The model is optimzied on all the variables in the loss function") + elif self.opt_var != "all" and isinstance(self.opt_var, str): + self.opt_var = int(self.opt_var) + print("The model is optimized on the {} variable in the loss function".format(self.opt_var)) + x = x[:, self.context_frames:, :, :, self.opt_var] + x_hat = x_hat[:, :, :, :, self.opt_var] + else: + raise ValueError("The opt var in the hyperparameters setup should be '0','1','2' indicate the index of target variable to be optimised or 'all' indicating optimize all the variables") + + if self.loss_fun == "mse": + self.recon_loss = tf.reduce_mean(tf.square(x - x_hat)) + elif self.loss_fun == "cross_entropy": + x_flatten = tf.reshape(x, [-1]) + x_hat_predict_frames_flatten = tf.reshape(x_hat, [-1]) + bce = tf.keras.losses.BinaryCrossentropy() + self.recon_loss = bce(x_flatten, x_hat_predict_frames_flatten) + else: + raise ValueError("Loss function is not selected properly, you should chose either 'mse' or 'cross_entropy'") + + self.D_loss = (1 - self.recon_weight) * self.D_loss + total_loss = (1-self.recon_weight) * self.G_loss + self.recon_weight*self.recon_loss + return total_loss + + def optimizer(self, *args): if self.mode == "train": if self.recon_weight == 1: - print("Only train generator- convLSTM") - self.train_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + print("Only train generator- ConvLSTM") + train_op = tf.train.AdamOptimizer(learning_rate = + self.learning_rate).\ + minimize(self.total_loss, var_list=self.gen_vars) else: - print("Training distriminator") - self.D_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.D_loss, var_list=self.disc_vars) + print("Training discriminator") + self.D_solver = tf.train.AdamOptimizer(learning_rate =self.learning_rate).\ + minimize(self.D_loss, var_list=self.disc_vars) with tf.control_dependencies([self.D_solver]): print("Training generator....") - self.G_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + self.G_solver = tf.train.AdamOptimizer(learning_rate =self.learning_rate).\ + minimize(self.total_loss, var_list=self.gen_vars) with tf.control_dependencies([self.G_solver]): - self.train_op = tf.assign_add(self.global_step,1) + train_op = tf.assign_add(self.global_step, 1) else: - self.train_op = None + train_op = None + return train_op - self.outputs["gen_images"] = self.gen_images - self.outputs["total_loss"] = self.total_loss - # Summary op - tf.summary.scalar("total_loss", self.total_loss) - tf.summary.scalar("D_loss", self.D_loss) - 
tf.summary.scalar("G_loss", self.G_loss) - tf.summary.scalar("D_loss_fake", self.D_loss_fake) - tf.summary.scalar("D_loss_real", self.D_loss_real) - tf.summary.scalar("recon_loss",self.recon_loss) - self.summary_op = tf.summary.merge_all() - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - self.is_build_graph = True - return self.is_build_graph - @staticmethod - def Unet_ConvLSTM_cell(x: tf.Tensor, ngf: int, hidden: tf.Tensor): + def build_model(self, x): """ - Build up a Unet ConvLSTM cell for each time stamp i - params: x: the input at timestamp i - params: ngf: the numnber of filters for convoluational layers - params: hidden: the hidden state from the previous timestamp t-1 - return: - outputs: the predict frame at timestamp i - hidden: the hidden state at current timestamp i + Define gan architectures """ - input_shape = x.get_shape().as_list() - num_channels = input_shape[3] - with tf.variable_scope("down_scale", reuse = tf.AUTO_REUSE): - conv1f = ld.conv_layer(x, 3 , 1, ngf, 1, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - conv1s = ld.conv_layer(conv1f, 3, 1, ngf, 2, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - pool1 = tf.layers.max_pooling2d(conv1s, pool_size=(2, 2), strides=(2, 2)) - print('pool1 shape: ',pool1.shape) - - conv2f = ld.conv_layer(pool1, 3, 1, ngf * 2, 3, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - conv2s = ld.conv_layer(conv2f, 3, 1, ngf * 2, 4, initializer = tf.contrib.layers.xavier_initializer(), activate = "relu") - pool2 = tf.layers.max_pooling2d(conv2s, pool_size=(2, 2), strides=(2, 2)) - print('pool2 shape: ',pool2.shape) - - conv3f = ld.conv_layer(pool2, 3, 1, ngf * 4, 5, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - conv3s = ld.conv_layer(conv3f, 3, 1, ngf * 4, 6, initializer = tf.contrib.layers.xavier_initializer(), activate = "relu") - pool3 = tf.layers.max_pooling2d(conv3s, pool_size=(2, 2), strides=(2, 2)) - print('pool3 shape: ',pool3.shape) - - convLSTM_input = pool3 - #convLSTM_input = tf.layers.dropout(pool2, 0.8) - - convLSTM4, hidden = ConvLstmGANVideoPredictionModel.convLSTM_cell(convLSTM_input, hidden) - print('convLSTM4 shape: ',convLSTM4.shape) - - with tf.variable_scope("upscale", reuse = tf.AUTO_REUSE): - deconv5 = ld.transpose_conv_layer(convLSTM4, 2, 2, ngf * 4, 1, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - print('deconv5 shape: ',deconv5.shape) - up5 = tf.concat([deconv5, conv3s], axis=3) - print('up5 shape: ',up5.shape) - - conv5f = ld.conv_layer(up5, 3, 1, ngf * 4, 2, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") - conv5s = ld.conv_layer(conv5f, 3, 1, ngf * 4, 3, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") - print('conv5s shape:',conv5s.shape) - - deconv6 = ld.transpose_conv_layer(conv5s, 2, 2, ngf * 2, 4, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - print('deconv6 shape: ',deconv6.shape) - up6 = tf.concat([deconv6, conv2s], axis=3) - print('up6 shape: ',up6.shape) - - conv6f = ld.conv_layer(up6, 3, 1, ngf * 2, 5, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") - conv6s = ld.conv_layer(conv6f, 3, 1, ngf * 2, 6, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") - print('conv6s shape:',conv6s.shape) - - deconv7 = ld.transpose_conv_layer(conv6s, 2, 2, ngf, 7, initializer = 
tf.contrib.layers.xavier_initializer(), activate="relu") - print('deconv7 shape: ',deconv7.shape) - up7 = tf.concat([deconv7, conv1s], axis=3) - print('up7 shape: ',up7.shape) - - conv7f = ld.conv_layer(up7, 3, 1, ngf, 8, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") - conv7s = ld.conv_layer(conv7f, 3, 1, ngf, 9, initializer = tf.contrib.layers.xavier_initializer(),activate= "relu") - print('conv7s shape:',conv7s.shape) - - conv7t = ld.conv_layer(conv7s, 3, 1, num_channels, 10, initializer = tf.contrib.layers.xavier_initializer(),activate="relu") - outputs = ld.conv_layer(conv7t, 1, 1, num_channels, 11, initializer = tf.contrib.layers.xavier_initializer(),activate="linear") - print('outputs shape: ',outputs.shape) - - return outputs, hidden - - def generator(self, x: tf.Tensor): + #conditional GAN + x_hat = self.generator(x) + + self.D_real, self.D_real_logits = self.discriminator(self.inputs[:, self.context_frames:, :, :, 0:1]) + self.D_fake, self.D_fake_logits = self.discriminator(x_hat[:, self.context_frames - 1:, :, :, 0:1]) + + return x_hat + + + def generator(self,x): """ - Function to build up the generator architecture, here we take Unet_ConvLSTM as generator + Function to build up the generator architecture args: input images: a input tensor with dimension (n_batch,sequence_length,height,width,channel) - output images: (n_batch,forecast_length,height,width,channel) """ - network_template = tf.make_template('network', ConvLstmGANVideoPredictionModel.Unet_ConvLSTM_cell) with tf.variable_scope("generator", reuse = tf.AUTO_REUSE): - # create network - x_hat = [] - #This is for training (optimization of convLSTM layer) - hidden_g = None - for i in range(self.sequence_length-1): - print('i: ',i) - if i < self.context_frames: - x_1_g, hidden_g = network_template(x[:, i, :, :, :], self.ngf, hidden_g) - else: - x_1_g, hidden_g = network_template(x_1_g, self.ngf, hidden_g) - x_hat.append(x_1_g) - # pack them all together - x_hat = tf.stack(x_hat) - self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) - print('self.x_hat shape is: ',self.x_hat.shape) - return self.x_hat - - def discriminator(self, x): + network_template = tf.make_template('network', + convLSTM.convLSTM_cell) # make the template to share the variables + + x_hat = convLSTM.convLSTM_network(self.inputs, + self.sequence_length, + self.context_frames, + network_template) + return x_hat + + + def discriminator(self, vid): """ - Function that get discriminator architecture + Function that get discriminator architecture """ with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): - conv1 = tf.layers.conv3d(x, 4, kernel_size=[4,4,4], strides=[1,2,2], padding="SAME", name="dis1") - conv1 = ConvLstmGANVideoPredictionModel.lrelu(conv1) - #conv2 = tf.layers.conv3d(conv1, 1, kernel_size=[4,4,4], strides=[1,2,2], padding="SAME", name="dis2") - conv2 = tf.reshape(conv1, [-1,1]) - #fc1 = ConvLstmGANVideoPredictionModel.lrelu(self.bd1(ConvLstmGANVideoPredictionModel.linear(conv2, output_size=256, scope='d_fc1'))) - fc2 = ConvLstmGANVideoPredictionModel.lrelu(self.bd2(ConvLstmGANVideoPredictionModel.linear(conv2, output_size=64, scope='d_fc2'))) - out_logit = ConvLstmGANVideoPredictionModel.linear(fc2, 1, scope='d_fc3') - out = tf.nn.sigmoid(out_logit) - #out,out_logit = self.Conv3Dnet(x,self.ndf) - return out, out_logit + conv1 = tf.layers.conv3d(vid,64,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME", name="dis1") + conv1 = self._lrelu(conv1) + conv2 = tf.layers.conv3d(conv1, 128, 
kernel_size=[4,4,4],strides=[2,2,2],padding="SAME", name="dis2") + conv2 = self._lrelu(self.bd1(conv2)) + conv3 = tf.layers.conv3d(conv2, 256, kernel_size=[4,4,4],strides=[2,2,2],padding="SAME" ,name="dis3") + conv3 = self._lrelu(self.bd2(conv3)) + conv4 = tf.layers.conv3d(conv3, 512, kernel_size=[4,4,4],strides=[2,2,2],padding="SAME", name="dis4") + conv4 = self._lrelu(self.bd3(conv4)) + conv5 = tf.layers.conv3d(conv4, 1, kernel_size=[2,4,4],strides=[1,1,1],padding="SAME", name="dis5") + conv5 = tf.reshape(conv5, [-1,1]) + conv5sigmoid = tf.nn.sigmoid(conv5) + return conv5sigmoid, conv5 def get_disc_loss(self): """ @@ -212,8 +189,10 @@ class ConvLstmGANVideoPredictionModel(BaseModels): gen_labels = tf.zeros_like(self.D_fake) self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real_logits, labels=real_labels)) self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=gen_labels)) - self.D_loss = self.D_loss_real + self.D_loss_fake - return self.D_loss + D_loss = self.D_loss_real + self.D_loss_fake + return D_loss + + def get_gen_loss(self): """ @@ -223,129 +202,24 @@ class ConvLstmGANVideoPredictionModel(BaseModels): Return the loss of generator given inputs """ real_labels = tf.ones_like(self.D_fake) - self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=real_labels)) - return self.G_loss + G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, + labels=real_labels)) + return G_loss - def get_vars(self): + def _get_vars(self): """ Get trainable variables from discriminator and generator """ self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] - - def build_model(self): - """ - Define gan architectures - """ - self.gen_images = self.generator(self.inputs) - self.D_real, self.D_real_logits = self.discriminator(self.inputs[:,self.context_frames:, :, :, 0:1]) # use the first varibale as targeted - #self.D_fake, self.D_fake_logits = self.discriminator(self.gen_images[:,:,:,:,0:1]) #0:1 - self.D_fake, self.D_fake_logits = self.discriminator(self.gen_images[:,self.context_frames-1:, :, :, 0:1]) #0:1 - - self.get_gen_loss() - self.get_disc_loss() - self.get_vars() - if self.loss_fun == "rmse": - #self.recon_loss = tf.reduce_mean(tf.square(self.inputs[:, self.context_frames:,:,:,0] - self.gen_images[:,:,:,:,0])) - self.recon_loss = tf.reduce_mean(tf.square(self.inputs[:, self.context_frames:, :, :, 0] - self.gen_images[:, self.context_frames-1:, :, :, 0])) - elif self.loss_fun == "cross_entropy": - x_flatten = tf.reshape(self.inputs[:, self.context_frames:,:,:,0],[-1]) - #x_hat_predict_frames_flatten = tf.reshape(self.gen_images[:,:,:,:,0],[-1]) - x_hat_predict_frames_flatten = tf.reshape(self.gen_images[:,self.context_frames-1:, :, :, 0], [-1]) - bce = tf.keras.losses.BinaryCrossentropy() - self.recon_loss = bce(x_flatten, x_hat_predict_frames_flatten) - else: - raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'") - - @staticmethod - def convLSTM_cell(inputs, hidden): - y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset - # conv lstm cell - cell_shape = y_0.get_shape().as_list() - channels = cell_shape[-1] - with tf.variable_scope('conv_lstm', initializer = 
tf.random_uniform_initializer(-.01, 0.1)): - cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64) - if hidden is None: - hidden = cell.zero_state(y_0, tf.float32) - output, hidden = cell(y_0, hidden) - return output, hidden - #output_shape = output.get_shape().as_list() - #z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) - ###we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction - #x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") - #print('x_hat shape is: ',x_hat.shape) - #return x_hat, hidden - - def get_noise(self, x, sigma=0.2): - """ - Function for creating noise: Given the dimensions (n_batch,n_seq, n_height, n_width, channel) - """ - x_shape = x.get_shape().as_list() - noise = sigma * tf.random.uniform(minval=-1., maxval=1., shape=x_shape) - x = x + noise - return x - - def Conv3Dnet_v1(self, x, ndf): - conv1 = tf.layers.conv3d(x, ndf, kernel_size = [4, 4, 4], strides = [1, 2, 2], padding = "SAME", name = 'conv1') - conv1 = self.lrelu(conv1) - # conv2 = tf.layers.conv3d(conv1,ndf*2,kernel_size=[4,4,4],strides=[1,2,2],padding="SAME",name='conv2') - # conv2 = self.lrelu(conv2) - conv3 = tf.layers.conv3d(conv1, 1, kernel_size = [4, 4, 4], strides = [1, 1, 1], padding = "SAME", name = 'conv3') - fl = tf.reshape(conv3, [-1, 1]) - print('fl shape: ', fl.shape) - fc1 = self.lrelu(self.bd1(self.linear(fl, 256, scope = 'fc1'))) - print('fc1 shape: ', fc1.shape) - fc2 = self.lrelu(self.bd2(self.linear(fc1, 64, scope = 'fc2'))) - print('fc2 shape: ', fc2.shape) - out_logit = self.linear(fc2, 1, scope = 'out') - out = tf.nn.sigmoid(out_logit) - return out, out_logit - - - def Conv3Dnet_v2(self, x, ndf): - """ - args: - input images: a input tensor with dimension (n_batch,forecast_length,height,width,channel) - output images: - """ - conv1 = Conv3D(ndf, 4, strides = (1, 2, 2), padding = 'same', kernel_initializer = 'he_normal')(x) - bn1 = BatchNormalization()(conv1) - bn1 = LeakyReLU(0.2)(bn1) - pool1 = MaxPooling3D(pool_size = (1, 2, 2), padding = 'same')(bn1) - noise1 = self.get_noise(pool1) - - conv2 = Conv3D(ndf * 2, 4, strides = (1, 2, 2), padding = 'same', kernel_initializer = 'he_normal')(noise1) - bn2 = BatchNormalization()(conv2) - bn2 = LeakyReLU(0.2)(bn2) - pool2 = MaxPooling3D(pool_size = (1, 2, 2), padding = 'same')(bn2) - noise2 = self.get_noise(pool2) - - conv3 = Conv3D(ndf * 4, 4, strides = (1, 2, 2), padding = 'same', kernel_initializer = 'he_normal')(noise2) - bn3 = BatchNormalization()(conv3) - bn3 = LeakyReLU(0.2)(bn3) - pool3 = MaxPooling3D(pool_size = (1, 2, 2), padding = 'same')(bn3) - - conv4 = Conv3D(1, 4, 1, padding = 'same')(pool3) - - fl = tf.reshape(conv4, [-1, 1]) - drop1 = Dropout(0.3)(fl) - fc1 = Dense(1024, activation = 'relu')(drop1) - drop2 = Dropout(0.3)(fc1) - fc2 = Dense(512, activation = 'relu')(drop2) - out_logit = Dense(1, activation = 'linear')(fc2) - out = tf.nn.sigmoid(out_logit) - return out, out_logit - - - @staticmethod - def lrelu(x, leak=0.2, name='lrelu'): - return tf.maximum(x, leak * x) - @staticmethod - def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): - shape = input_.get_shape().as_list() + def _lrelu(self, x, leak=0.2): + return tf.maximum(x, leak * x) + + def _linear(self, input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): + shape = input_.get_shape().as_list() with tf.variable_scope(scope or "Linear"): matrix = 
tf.get_variable("Matrix", [shape[1], output_size], tf.float32, tf.random_normal_initializer(stddev = stddev)) @@ -356,19 +230,5 @@ class ConvLstmGANVideoPredictionModel(BaseModels): else: return tf.matmul(input_, matrix) + bias -class batch_norm(object): - def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): - with tf.variable_scope(name): - self.epsilon = epsilon - self.momentum = momentum - self.name = name - - def __call__(self, x, train=True): - return tf.contrib.layers.batch_norm(x, - decay=self.momentum, - updates_collections=None, - epsilon=self.epsilon, - scale=True, - is_training=train, - scope=self.name) + diff --git a/video_prediction_tools/model_modules/video_prediction/models/dna_model.py b/video_prediction_tools/model_modules/video_prediction/models/dna_model.py deleted file mode 100644 index 8badf600f62c21d71cd81d8c2bfcde2f75e91d34..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/dna_model.py +++ /dev/null @@ -1,475 +0,0 @@ -# Copyright 2016 The TensorFlow Authors All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Model architecture for predictive model, including CDNA, DNA, and STP.""" - -import itertools - -import numpy as np -import tensorflow as tf -import tensorflow.contrib.slim as slim -from tensorflow.contrib.layers.python import layers as tf_layers -from model_modules.video_prediction.models import VideoPredictionModel -from .sna_model import basic_conv_lstm_cell - - -# Amount to use when lower bounding tensors -RELU_SHIFT = 1e-12 - - -def construct_model(images, - actions=None, - states=None, - iter_num=-1.0, - kernel_size=(5, 5), - k=-1, - use_state=True, - num_masks=10, - stp=False, - cdna=True, - dna=False, - context_frames=2, - pix_distributions=None): - """Build convolutional lstm video predictor using STP, CDNA, or DNA. - - Args: - images: tensor of ground truth image sequences - actions: tensor of action sequences - states: tensor of ground truth state sequences - iter_num: tensor of the current training iteration (for sched. sampling) - k: constant used for scheduled sampling. -1 to feed in own prediction. - use_state: True to include state and action in prediction - num_masks: the number of different pixel motion predictions (and - the number of masks for each of those predictions) - stp: True to use Spatial Transformer Predictor (STP) - cdna: True to use Convoluational Dynamic Neural Advection (CDNA) - dna: True to use Dynamic Neural Advection (DNA) - context_frames: number of ground truth frames to pass in before - feeding in own predictions - Returns: - gen_images: predicted future image frames - gen_states: predicted future states - - Raises: - ValueError: if more than one network option specified or more than 1 mask - specified for DNA model. 
- """ - DNA_KERN_SIZE = kernel_size[0] - - if stp + cdna + dna != 1: - raise ValueError('More than one, or no network option specified.') - batch_size, img_height, img_width, color_channels = images[0].get_shape()[0:4] - lstm_func = basic_conv_lstm_cell - - # Generated robot states and images. - gen_states, gen_images = [], [] - gen_pix_distrib = [] - gen_masks = [] - current_state = states[0] - - if k == -1: - feedself = True - else: - # Scheduled sampling: - # Calculate number of ground-truth frames to pass in. - num_ground_truth = tf.to_int32( - tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k))))) - feedself = False - - # LSTM state sizes and states. - lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32])) - lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None - lstm_state5, lstm_state6, lstm_state7 = None, None, None - - for t, action in enumerate(actions): - # Reuse variables after the first timestep. - reuse = bool(gen_images) - - done_warm_start = len(gen_images) > context_frames - 1 - with slim.arg_scope( - [lstm_func, slim.layers.conv2d, slim.layers.fully_connected, - tf_layers.layer_norm, slim.layers.conv2d_transpose], - reuse=reuse): - - if feedself and done_warm_start: - # Feed in generated image. - prev_image = gen_images[-1] - if pix_distributions is not None: - prev_pix_distrib = gen_pix_distrib[-1] - elif done_warm_start: - # Scheduled sampling - prev_image = scheduled_sample(images[t], gen_images[-1], batch_size, - num_ground_truth) - else: - # Always feed in ground_truth - prev_image = images[t] - if pix_distributions is not None: - prev_pix_distrib = pix_distributions[t] - # prev_pix_distrib = tf.expand_dims(prev_pix_distrib, -1) - - # Predicted state is always fed back in - state_action = tf.concat(axis=1, values=[action, current_state]) - - enc0 = slim.layers.conv2d( - prev_image, - 32, [5, 5], - stride=2, - scope='scale1_conv1', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'layer_norm1'}) - - hidden1, lstm_state1 = lstm_func( - enc0, lstm_state1, lstm_size[0], scope='state1') - hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2') - hidden2, lstm_state2 = lstm_func( - hidden1, lstm_state2, lstm_size[1], scope='state2') - hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3') - enc1 = slim.layers.conv2d( - hidden2, hidden2.get_shape()[3], [3, 3], stride=2, scope='conv2') - - hidden3, lstm_state3 = lstm_func( - enc1, lstm_state3, lstm_size[2], scope='state3') - hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4') - hidden4, lstm_state4 = lstm_func( - hidden3, lstm_state4, lstm_size[3], scope='state4') - hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5') - enc2 = slim.layers.conv2d( - hidden4, hidden4.get_shape()[3], [3, 3], stride=2, scope='conv3') - - # Pass in state and action. 
- smear = tf.reshape( - state_action, - [int(batch_size), 1, 1, int(state_action.get_shape()[1])]) - smear = tf.tile( - smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1]) - if use_state: - enc2 = tf.concat(axis=3, values=[enc2, smear]) - enc3 = slim.layers.conv2d( - enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4') - - hidden5, lstm_state5 = lstm_func( - enc3, lstm_state5, lstm_size[4], scope='state5') # last 8x8 - hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6') - enc4 = slim.layers.conv2d_transpose( - hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1') - - hidden6, lstm_state6 = lstm_func( - enc4, lstm_state6, lstm_size[5], scope='state6') # 16x16 - hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7') - # Skip connection. - hidden6 = tf.concat(axis=3, values=[hidden6, enc1]) # both 16x16 - - enc5 = slim.layers.conv2d_transpose( - hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2') - hidden7, lstm_state7 = lstm_func( - enc5, lstm_state7, lstm_size[6], scope='state7') # 32x32 - hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8') - - # Skip connection. - hidden7 = tf.concat(axis=3, values=[hidden7, enc0]) # both 32x32 - - enc6 = slim.layers.conv2d_transpose( - hidden7, - hidden7.get_shape()[3], 3, stride=2, scope='convt3', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'layer_norm9'}) - - if dna: - # Using largest hidden state for predicting untied conv kernels. - enc7 = slim.layers.conv2d_transpose( - enc6, DNA_KERN_SIZE ** 2, 1, stride=1, scope='convt4') - else: - # Using largest hidden state for predicting a new image layer. - enc7 = slim.layers.conv2d_transpose( - enc6, color_channels, 1, stride=1, scope='convt4') - # This allows the network to also generate one image from scratch, - # which is useful when regions of the image become unoccluded. - transformed = [tf.nn.sigmoid(enc7)] - - if stp: - stp_input0 = tf.reshape(hidden5, [int(batch_size), -1]) - stp_input1 = slim.layers.fully_connected( - stp_input0, 100, scope='fc_stp') - - # disabling capability to generete pixels - reuse_stp = None - if reuse: - reuse_stp = reuse - transformed = stp_transformation(prev_image, stp_input1, num_masks, reuse_stp) - # transformed += stp_transformation(prev_image, stp_input1, num_masks) - - if pix_distributions is not None: - transf_distrib = stp_transformation(prev_pix_distrib, stp_input1, num_masks, reuse=True) - - elif cdna: - cdna_input = tf.reshape(hidden5, [int(batch_size), -1]) - - new_transformed, cdna_kerns = cdna_transformation(prev_image, - cdna_input, - num_masks, - int(color_channels), - kernel_size, - reuse_sc=reuse) - transformed += new_transformed - - if pix_distributions is not None: - if not dna: - transf_distrib = [prev_pix_distrib] - new_transf_distrib, _ = cdna_transformation(prev_pix_distrib, - cdna_input, - num_masks, - prev_pix_distrib.shape[-1].value, - kernel_size, - reuse_sc=True) - transf_distrib += new_transf_distrib - - elif dna: - # Only one mask is supported (more should be unnecessary). 
- if num_masks != 1: - raise ValueError('Only one mask is supported for DNA model.') - transformed = [dna_transformation(prev_image, enc7, DNA_KERN_SIZE)] - - masks = slim.layers.conv2d_transpose( - enc6, num_masks + 1, 1, stride=1, scope='convt7') - masks = tf.reshape( - tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])), - [int(batch_size), int(img_height), int(img_width), num_masks + 1]) - mask_list = tf.split(masks, num_masks + 1, axis=3) - output = mask_list[0] * prev_image - for layer, mask in zip(transformed, mask_list[1:]): - output += layer * mask - gen_images.append(output) - gen_masks.append(mask_list) - - if dna and pix_distributions is not None: - transf_distrib = [dna_transformation(prev_pix_distrib, enc7, DNA_KERN_SIZE)] - - if pix_distributions is not None: - pix_distrib_output = mask_list[0] * prev_pix_distrib - for layer, mask in zip(transf_distrib, mask_list[1:]): - pix_distrib_output += layer * mask - pix_distrib_output /= tf.reduce_sum(pix_distrib_output, axis=(1, 2), keepdims=True) - gen_pix_distrib.append(pix_distrib_output) - - if int(current_state.get_shape()[1]) == 0: - current_state = tf.zeros_like(state_action) - else: - current_state = slim.layers.fully_connected( - state_action, - int(current_state.get_shape()[1]), - scope='state_pred', - activation_fn=None) - gen_states.append(current_state) - - return gen_images, gen_states, gen_masks, gen_pix_distrib - - -## Utility functions -def stp_transformation(prev_image, stp_input, num_masks): - """Apply spatial transformer predictor (STP) to previous image. - - Args: - prev_image: previous image to be transformed. - stp_input: hidden layer to be used for computing STN parameters. - num_masks: number of masks and hence the number of STP transformations. - Returns: - List of images transformed by the predicted STP parameters. - """ - # Only import spatial transformer if needed. - from spatial_transformer import transformer - - identity_params = tf.convert_to_tensor( - np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32)) - transformed = [] - for i in range(num_masks - 1): - params = slim.layers.fully_connected( - stp_input, 6, scope='stp_params' + str(i), - activation_fn=None) + identity_params - transformed.append(transformer(prev_image, params)) - - return transformed - - -def cdna_transformation(prev_image, cdna_input, num_masks, color_channels, kernel_size, reuse_sc=None): - """Apply convolutional dynamic neural advection to previous image. - - Args: - prev_image: previous image to be transformed. - cdna_input: hidden lyaer to be used for computing CDNA kernels. - num_masks: the number of masks and hence the number of CDNA transformations. - color_channels: the number of color channels in the images. - Returns: - List of images transformed by the predicted CDNA kernels. - """ - batch_size = int(cdna_input.get_shape()[0]) - height = int(prev_image.get_shape()[1]) - width = int(prev_image.get_shape()[2]) - - # Predict kernels using linear function of last hidden layer. - cdna_kerns = slim.layers.fully_connected( - cdna_input, - kernel_size[0] * kernel_size[1] * num_masks, - scope='cdna_params', - activation_fn=None, - reuse=reuse_sc) - - # Reshape and normalize. 
- cdna_kerns = tf.reshape( - cdna_kerns, [batch_size, kernel_size[0], kernel_size[1], 1, num_masks]) - cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT - norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True) - cdna_kerns /= norm_factor - - # Treat the color channel dimension as the batch dimension since the same - # transformation is applied to each color channel. - # Treat the batch dimension as the channel dimension so that - # depthwise_conv2d can apply a different transformation to each sample. - cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3]) - cdna_kerns = tf.reshape(cdna_kerns, [kernel_size[0], kernel_size[1], batch_size, num_masks]) - # Swap the batch and channel dimensions. - prev_image = tf.transpose(prev_image, [3, 1, 2, 0]) - - # Transform image. - transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME') - - # Transpose the dimensions to where they belong. - transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks]) - transformed = tf.transpose(transformed, [3, 1, 2, 0, 4]) - transformed = tf.unstack(transformed, axis=-1) - return transformed, cdna_kerns - - -def dna_transformation(prev_image, dna_input, kernel_size): - """Apply dynamic neural advection to previous image. - - Args: - prev_image: previous image to be transformed. - dna_input: hidden lyaer to be used for computing DNA transformation. - Returns: - List of images transformed by the predicted CDNA kernels. - """ - # Construct translated images. - pad_along_height = (kernel_size[0] - 1) - pad_along_width = (kernel_size[1] - 1) - pad_top = pad_along_height // 2 - pad_bottom = pad_along_height - pad_top - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - prev_image_pad = tf.pad(prev_image, [[0, 0], - [pad_top, pad_bottom], - [pad_left, pad_right], - [0, 0]]) - image_height = int(prev_image.get_shape()[1]) - image_width = int(prev_image.get_shape()[2]) - - inputs = [] - for xkern in range(kernel_size[0]): - for ykern in range(kernel_size[1]): - inputs.append( - tf.expand_dims( - tf.slice(prev_image_pad, [0, xkern, ykern, 0], - [-1, image_height, image_width, -1]), [3])) - inputs = tf.concat(axis=3, values=inputs) - - # Normalize channels to 1. - kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT - kernel = tf.expand_dims( - kernel / tf.reduce_sum( - kernel, [3], keepdims=True), [4]) - return tf.reduce_sum(kernel * inputs, [3], keepdims=False) - - -def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth): - """Sample batch with specified mix of ground truth and generated data points. - - Args: - ground_truth_x: tensor of ground-truth data points. - generated_x: tensor of generated data points. - batch_size: batch size - num_ground_truth: number of ground-truth examples to include in batch. - Returns: - New batch with num_ground_truth sampled from ground_truth_x and the rest - from generated_x. 
- """ - idx = tf.random_shuffle(tf.range(int(batch_size))) - ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth)) - generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size))) - - ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx) - generated_examps = tf.gather(generated_x, generated_idx) - return tf.dynamic_stitch([ground_truth_idx, generated_idx], - [ground_truth_examps, generated_examps]) - - -def generator_fn(inputs, hparams=None): - images = tf.unstack(inputs['images'], axis=0) - actions = tf.unstack(inputs['actions'], axis=0) - states = tf.unstack(inputs['states'], axis=0) - pix_distributions = tf.unstack(inputs['pix_distribs'], axis=0) if 'pix_distribs' in inputs else None - iter_num = tf.to_float(tf.train.get_or_create_global_step()) - - gen_images, gen_states, gen_masks, gen_pix_distrib = \ - construct_model(images, - actions, - states, - iter_num=iter_num, - kernel_size=hparams.kernel_size, - k=hparams.schedule_sampling_k, - num_masks=hparams.num_masks, - cdna=hparams.transformation == 'cdna', - dna=hparams.transformation == 'dna', - stp=hparams.transformation == 'stp', - context_frames=hparams.context_frames, - pix_distributions=pix_distributions) - outputs = { - 'gen_images': tf.stack(gen_images, axis=0), - 'gen_states': tf.stack(gen_states, axis=0), - 'masks': tf.stack([tf.stack(gen_mask_list, axis=-1) for gen_mask_list in gen_masks], axis=0), - } - if 'pix_distribs' in inputs: - outputs['gen_pix_distribs'] = tf.stack(gen_pix_distrib, axis=0) - gen_images = outputs['gen_images'][hparams.context_frames - 1:] - return gen_images, outputs - - -class DNAVideoPredictionModel(VideoPredictionModel): - def __init__(self, *args, **kwargs): - super(DNAVideoPredictionModel, self).__init__( - generator_fn, *args, **kwargs) - - def get_default_hparams_dict(self): - default_hparams = super(DNAVideoPredictionModel, self).get_default_hparams_dict() - hparams = dict( - batch_size=32, - l1_weight=0.0, - l2_weight=1.0, - transformation='cdna', - kernel_size=(9, 9), - num_masks=10, - schedule_sampling_k=900.0, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - def parse_hparams(self, hparams_dict, hparams): - hparams = super(DNAVideoPredictionModel, self).parse_hparams(hparams_dict, hparams) - if self.mode == 'test': - def override_hparams_maybe(name, value): - orig_value = hparams.values()[name] - if orig_value != value: - print('Overriding hparams from %s=%r to %r for mode=%s.' 
% - (name, orig_value, value, self.mode)) - hparams.set_hparam(name, value) - override_hparams_maybe('schedule_sampling_k', -1) - return hparams diff --git a/video_prediction_tools/model_modules/video_prediction/models/linear_regression_model.py b/video_prediction_tools/model_modules/video_prediction/models/linear_regression_model.py deleted file mode 100644 index 296ae16a1f70fcd5047481ab46cdd6e14abe1e83..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/linear_regression_model.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong" -__date__ = "2022-04-13" - -from .our_base_model import BaseModels -import tensorflow as tf - - -class VanillaConvLstmVideoPredictionModel(BaseModels): - - def __init__(self, hparams_dict=None, **kwargs): - """ - This is class for building convLSTM architecture by using updated hparameters - args: - hparams_dict : dict, the dictionary contains the hparaemters names and values - """ - super().__init__(hparams_dict) - self.get_hparams() - - def get_hparams(self): - """ - obtain the hparams from the dict to the class variables - """ - method = BaseModels.get_hparams.__name__ - - try: - self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.max_epochs = self.hparams.max_epochs - self.batch_size = self.hparams.batch_size - self.shuffle_on_val = self.hparams.shuffle_on_val - self.opt_var = self.hparams.opt_var - self.learning_rate = self.hparams.lr - - print("The model hparams have been parsed successfully! ") - except Exception as error: - print("Method %{}: error: {}".format(method, error)) - raise("Method %{}: the hparameter dictionary must include the params defined above!".format(method)) - - def build_graph(self, x: tf.Tensor): - - self.is_build_graph = False - self.inputs = x - self.global_step = tf.train.get_or_create_global_step() - original_global_variables = tf.global_variables() - - self.build_model() - - - # This is the loss function (MSE): - # Optimize all target variables/channels - if self.opt_var == "all": - x = self.inputs[:, self.context_frames:, :, :, :] - x_hat = self.x_hat_predict_frames[:, :, :, :, :] - print("The model is optimzied on all the variables in the loss function") - elif self.opt_var != "all" and isinstance(self.opt_var, str): - self.opt_var = int(self.opt_var) - print("The model is optimized on the {} variable in the loss function".format(self.opt_var)) - x = self.inputs[:, self.context_frames:, :, :, self.opt_var] - x_hat = self.x_hat_predict_frames[:, :, :, :, self.opt_var] - else: - raise ValueError( - "The opt var in the hyper-parameters setup should be '0','1','2' indicate the index of target variable to be optimised or 'all' indicating optimize all the variables") - - #loss function is mean squre error - self.total_loss = tf.reduce_mean(tf.square(x - x_hat)) - - self.train_op = tf.train.AdamOptimizer( - learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) - - self.outputs["gen_images"] = self.x_hat - - # Summary op - self.loss_summary = tf.summary.scalar("total_loss", self.total_loss) - self.summary_op = tf.summary.merge_all() - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - self.is_build_graph = True - return self.is_build_graph - - - - 
def build_model(self): - pass diff --git a/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py b/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py deleted file mode 100644 index 80a49160159840ec6219d0d5a488d676c7257d67..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py +++ /dev/null @@ -1,468 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong" -__date__ = "2020-08-22" - - -import itertools -import numpy as np -import tensorflow as tf - -from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames -from model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell -from model_modules.video_prediction.layers.mcnet_ops import * -from model_modules.video_prediction.utils.mcnet_utils import * -import os - -class McNetVideoPredictionModel(BaseVideoPredictionModel): - def __init__(self, mode='train', hparams_dict=None, - hparams=None, **kwargs): - super(McNetVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) - self.mode = mode - self.lr = self.hparams.lr - self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) - self.df_dim = self.hparams.df_dim - self.gf_dim = self.hparams.gf_dim - self.alpha = self.hparams.alpha - self.beta = self.hparams.beta - self.gen_images_enc = None - self.recon_loss = None - self.latent_loss = None - self.total_loss = None - - def get_default_hparams_dict(self): - """ - The keys of this dict define valid hyperparameters for instances of - this class. A class inheriting from this one should override this - method if it has a different set of hyperparameters. - - Returns: - A dict with the following hyperparameters. - - batch_size: batch size for training. - lr: learning rate. if decay steps is non-zero, this is the - learning rate for steps <= decay_step. - - - - - max_steps: number of training steps. - - - context_frames: the number of ground-truth frames to pass in at - start. Must be specified during instantiation. - sequence_length: the number of frames in the video sequence, - including the context frames, so this model predicts - `sequence_length - context_frames` future frames. Must be - specified during instantiation. 
- df_dim: specific parameters for mcnet - gf_dim: specific parameters for menet - alpha: specific parameters for mcnet - beta: specific paramters for mcnet - - """ - default_hparams = super(McNetVideoPredictionModel, self).get_default_hparams_dict() - hparams = dict( - batch_size=16, - lr=0.001, - max_epochs=350000, - context_frames = 10, - sequence_length = 20, - nz = 16, - gf_dim = 64, - df_dim = 64, - alpha = 1, - beta = 0.0 - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - def build_graph(self, x): - - self.x = x["images"] - self.x_shape = self.x.get_shape().as_list() - self.batch_size = self.x_shape[0] - self.image_size = [self.x_shape[2],self.x_shape[3]] - self.c_dim = self.x_shape[4] - self.diff_shape = [self.batch_size, self.context_frames-1, self.image_size[0], - self.image_size[1], self.c_dim] - self.xt_shape = [self.batch_size, self.image_size[0], self.image_size[1],self.c_dim] - self.is_train = True - - - #self.global_step = tf.Variable(0, name='global_step', trainable=False) - self.global_step = tf.train.get_or_create_global_step() - original_global_variables = tf.global_variables() - - # self.xt = tf.placeholder(tf.float32, self.xt_shape, name='xt') - self.xt = self.x[:, self.context_frames - 1, :, :, :] - - self.diff_in = tf.placeholder(tf.float32, self.diff_shape, name='diff_in') - diff_in_all = [] - for t in range(1, self.context_frames): - prev = self.x[:, t-1:t, :, :, :] - next = self.x[:, t:t+1, :, :, :] - #diff_in = tf.reshape(next - prev, [self.batch_size, 1, self.image_size[0], self.image_size[1], -1]) - print("prev:",prev) - print("next:",next) - diff_in = tf.subtract(next,prev) - print("diff_in:",diff_in) - diff_in_all.append(diff_in) - - self.diff_in = tf.concat(axis = 1, values = diff_in_all) - - cell = BasicConvLSTMCell([self.image_size[0] / 8, self.image_size[1] / 8], [3, 3], 256) - - pred = self.forward(self.diff_in, self.xt, cell) - - - self.G = tf.concat(axis=1, values=pred)#[batch_size,context_frames,image1,image2,channels] - print ("1:self.G:",self.G) - if self.is_train: - - true_sim = self.x[:, self.context_frames:, :, :, :] - - # Bing: the following make sure the channel is three dimension, if the channle is 3 then will be duplicated - if self.c_dim == 1: true_sim = tf.tile(true_sim, [1, 1, 1, 1, 3]) - - # Bing: the raw inputs shape is [batch_size, image_size[0],self.image_size[1], num_seq, channel]. 
tf.transpose will transpoe the shape into - # [batch size*num_seq, image_size0, image_size1, channels], for our era5 case, we do not need transpose - # true_sim = tf.reshape(tf.transpose(true_sim,[0,3,1,2,4]), - # [-1, self.image_size[0], - # self.image_size[1], 3]) - true_sim = tf.reshape(true_sim, [-1, self.image_size[0], self.image_size[1], 3]) - - - - - gen_sim = self.G - - #combine groud truth and predict frames - self.x_hat = tf.concat([self.x[:, :self.context_frames, :, :, :], self.G], 1) - print ("self.x_hat:",self.x_hat) - if self.c_dim == 1: gen_sim = tf.tile(gen_sim, [1, 1, 1, 1, 3]) - # gen_sim = tf.reshape(tf.transpose(gen_sim,[0,3,1,2,4]), - # [-1, self.image_size[0], - # self.image_size[1], 3]) - - gen_sim = tf.reshape(gen_sim, [-1, self.image_size[0], self.image_size[1], 3]) - - - binput = tf.reshape(tf.transpose(self.x[:, :self.context_frames, :, :, :], [0, 1, 2, 3, 4]), - [self.batch_size, self.image_size[0], - self.image_size[1], -1]) - - btarget = tf.reshape(tf.transpose(self.x[:, self.context_frames:, :, :, :], [0, 1, 2, 3, 4]), - [self.batch_size, self.image_size[0], - self.image_size[1], -1]) - bgen = tf.reshape(self.G, [self.batch_size, - self.image_size[0], - self.image_size[1], -1]) - - print ("binput:",binput) - print("btarget:",btarget) - print("bgen:",bgen) - - good_data = tf.concat(axis=3, values=[binput, btarget]) - gen_data = tf.concat(axis=3, values=[binput, bgen]) - self.gen_data = gen_data - print ("2:self.gen_data:", self.gen_data) - with tf.variable_scope("DIS", reuse=False): - self.D, self.D_logits = self.discriminator(good_data) - - with tf.variable_scope("DIS", reuse=True): - self.D_, self.D_logits_ = self.discriminator(gen_data) - - self.L_p = tf.reduce_mean( - tf.square(self.G - self.x[:, self.context_frames:, :, :, :])) - - self.L_gdl = gdl(gen_sim, true_sim, 1.) 
- self.L_img = self.L_p + self.L_gdl - - self.d_loss_real = tf.reduce_mean( - tf.nn.sigmoid_cross_entropy_with_logits( - logits = self.D_logits, labels = tf.ones_like(self.D) - )) - self.d_loss_fake = tf.reduce_mean( - tf.nn.sigmoid_cross_entropy_with_logits( - logits = self.D_logits_, labels = tf.zeros_like(self.D_) - )) - self.d_loss = self.d_loss_real + self.d_loss_fake - self.L_GAN = tf.reduce_mean( - tf.nn.sigmoid_cross_entropy_with_logits( - logits = self.D_logits_, labels = tf.ones_like(self.D_) - )) - - self.loss_sum = tf.summary.scalar("L_img", self.L_img) - self.L_p_sum = tf.summary.scalar("L_p", self.L_p) - self.L_gdl_sum = tf.summary.scalar("L_gdl", self.L_gdl) - self.L_GAN_sum = tf.summary.scalar("L_GAN", self.L_GAN) - self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss) - self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real) - self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake) - - self.total_loss = self.alpha * self.L_img + self.beta * self.L_GAN - self._loss_sum = tf.summary.scalar("total_loss", self.total_loss) - self.g_sum = tf.summary.merge([self.L_p_sum, - self.L_gdl_sum, self.loss_sum, - self.L_GAN_sum]) - self.d_sum = tf.summary.merge([self.d_loss_real_sum, self.d_loss_sum, - self.d_loss_fake_sum]) - - - self.t_vars = tf.trainable_variables() - self.g_vars = [var for var in self.t_vars if 'DIS' not in var.name] - self.d_vars = [var for var in self.t_vars if 'DIS' in var.name] - num_param = 0.0 - for var in self.g_vars: - num_param += int(np.prod(var.get_shape())); - print("Number of parameters: %d" % num_param) - - # Training - self.d_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5).minimize( - self.d_loss, var_list = self.d_vars) - self.g_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5).minimize( - self.alpha * self.L_img + self.beta * self.L_GAN, var_list = self.g_vars, global_step=self.global_step) - - self.train_op = [self.d_optim,self.g_optim] - self.outputs = {} - self.outputs["gen_images"] = self.x_hat - - - self.summary_op = tf.summary.merge_all() - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - return - - - def forward(self, diff_in, xt, cell): - # Initial state - state = tf.zeros([self.batch_size, self.image_size[0] / 8, - self.image_size[1] / 8, 512]) - reuse = False - # Encoder - for t in range(self.context_frames - 1): - enc_h, res_m = self.motion_enc(diff_in[:, t, :, :, :], reuse = reuse) - h_dyn, state = cell(enc_h, state, scope = 'lstm', reuse = reuse) - reuse = True - pred = [] - # Decoder - for t in range(self.predict_frames): - if t == 0: - h_cont, res_c = self.content_enc(xt, reuse = False) - h_tp1 = self.comb_layers(h_dyn, h_cont, reuse = False) - res_connect = self.residual(res_m, res_c, reuse = False) - x_hat = self.dec_cnn(h_tp1, res_connect, reuse = False) - - else: - - enc_h, res_m = self.motion_enc(diff_in, reuse = True) - h_dyn, state = cell(enc_h, state, scope = 'lstm', reuse = True) - h_cont, res_c = self.content_enc(xt, reuse = reuse) - h_tp1 = self.comb_layers(h_dyn, h_cont, reuse = True) - res_connect = self.residual(res_m, res_c, reuse = True) - x_hat = self.dec_cnn(h_tp1, res_connect, reuse = True) - print ("x_hat :",x_hat) - if self.c_dim == 3: - # Network outputs are BGR so they need to be reversed to use - # rgb_to_grayscale - #x_hat_gray = tf.concat(axis=3,values=[x_hat[:,:,:,2:3], x_hat[:,:,:,1:2],x_hat[:,:,:,0:1]]) - #xt_gray = 
tf.concat(axis=3,values=[xt[:,:,:,2:3], xt[:,:,:,1:2],xt[:,:,:,0:1]]) - - # x_hat_gray = 1./255.*tf.image.rgb_to_grayscale( - # inverse_transform(x_hat_rgb)*255. - # ) - # xt_gray = 1./255.*tf.image.rgb_to_grayscale( - # inverse_transform(xt_rgb)*255. - # ) - - x_hat_gray = x_hat - xt_gray = xt - else: - x_hat_gray = inverse_transform(x_hat) - xt_gray = inverse_transform(xt) - - diff_in = x_hat_gray - xt_gray - xt = x_hat - - - pred.append(tf.reshape(x_hat, [self.batch_size, 1, self.image_size[0], - self.image_size[1], self.c_dim])) - - return pred - - def motion_enc(self, diff_in, reuse): - res_in = [] - - conv1 = relu(conv2d(diff_in, output_dim = self.gf_dim, k_h = 5, k_w = 5, - d_h = 1, d_w = 1, name = 'dyn1_conv1', reuse = reuse)) - res_in.append(conv1) - pool1 = MaxPooling(conv1, [2, 2]) - - conv2 = relu(conv2d(pool1, output_dim = self.gf_dim * 2, k_h = 5, k_w = 5, - d_h = 1, d_w = 1, name = 'dyn_conv2', reuse = reuse)) - res_in.append(conv2) - pool2 = MaxPooling(conv2, [2, 2]) - - conv3 = relu(conv2d(pool2, output_dim = self.gf_dim * 4, k_h = 7, k_w = 7, - d_h = 1, d_w = 1, name = 'dyn_conv3', reuse = reuse)) - res_in.append(conv3) - pool3 = MaxPooling(conv3, [2, 2]) - return pool3, res_in - - def content_enc(self, xt, reuse): - res_in = [] - conv1_1 = relu(conv2d(xt, output_dim = self.gf_dim, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv1_1', reuse = reuse)) - conv1_2 = relu(conv2d(conv1_1, output_dim = self.gf_dim, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv1_2', reuse = reuse)) - res_in.append(conv1_2) - pool1 = MaxPooling(conv1_2, [2, 2]) - - conv2_1 = relu(conv2d(pool1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv2_1', reuse = reuse)) - conv2_2 = relu(conv2d(conv2_1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv2_2', reuse = reuse)) - res_in.append(conv2_2) - pool2 = MaxPooling(conv2_2, [2, 2]) - - conv3_1 = relu(conv2d(pool2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv3_1', reuse = reuse)) - conv3_2 = relu(conv2d(conv3_1, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv3_2', reuse = reuse)) - conv3_3 = relu(conv2d(conv3_2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv3_3', reuse = reuse)) - res_in.append(conv3_3) - pool3 = MaxPooling(conv3_3, [2, 2]) - return pool3, res_in - - def comb_layers(self, h_dyn, h_cont, reuse=False): - comb1 = relu(conv2d(tf.concat(axis = 3, values = [h_dyn, h_cont]), - output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'comb1', reuse = reuse)) - comb2 = relu(conv2d(comb1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'comb2', reuse = reuse)) - h_comb = relu(conv2d(comb2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'h_comb', reuse = reuse)) - return h_comb - - def residual(self, input_dyn, input_cont, reuse=False): - n_layers = len(input_dyn) - res_out = [] - for l in range(n_layers): - input_ = tf.concat(axis = 3, values = [input_dyn[l], input_cont[l]]) - out_dim = input_cont[l].get_shape()[3] - res1 = relu(conv2d(input_, output_dim = out_dim, - k_h = 3, k_w = 3, d_h = 1, d_w = 1, - name = 'res' + str(l) + '_1', reuse = reuse)) - res2 = conv2d(res1, output_dim = out_dim, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'res' + str(l) + '_2', reuse = reuse) - res_out.append(res2) - return res_out - - def dec_cnn(self, h_comb, res_connect, reuse=False): - 
- shapel3 = [self.batch_size, int(self.image_size[0] / 4), - int(self.image_size[1] / 4), self.gf_dim * 4] - shapeout3 = [self.batch_size, int(self.image_size[0] / 4), - int(self.image_size[1] / 4), self.gf_dim * 2] - depool3 = FixedUnPooling(h_comb, [2, 2]) - deconv3_3 = relu(deconv2d(relu(tf.add(depool3, res_connect[2])), - output_shape = shapel3, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv3_3', reuse = reuse)) - deconv3_2 = relu(deconv2d(deconv3_3, output_shape = shapel3, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv3_2', reuse = reuse)) - deconv3_1 = relu(deconv2d(deconv3_2, output_shape = shapeout3, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv3_1', reuse = reuse)) - - shapel2 = [self.batch_size, int(self.image_size[0] / 2), - int(self.image_size[1] / 2), self.gf_dim * 2] - shapeout3 = [self.batch_size, int(self.image_size[0] / 2), - int(self.image_size[1] / 2), self.gf_dim] - depool2 = FixedUnPooling(deconv3_1, [2, 2]) - deconv2_2 = relu(deconv2d(relu(tf.add(depool2, res_connect[1])), - output_shape = shapel2, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv2_2', reuse = reuse)) - deconv2_1 = relu(deconv2d(deconv2_2, output_shape = shapeout3, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv2_1', reuse = reuse)) - - shapel1 = [self.batch_size, self.image_size[0], - self.image_size[1], self.gf_dim] - shapeout1 = [self.batch_size, self.image_size[0], - self.image_size[1], self.c_dim] - depool1 = FixedUnPooling(deconv2_1, [2, 2]) - deconv1_2 = relu(deconv2d(relu(tf.add(depool1, res_connect[0])), - output_shape = shapel1, k_h = 3, k_w = 3, d_h = 1, d_w = 1, - name = 'dec_deconv1_2', reuse = reuse)) - xtp1 = tanh(deconv2d(deconv1_2, output_shape = shapeout1, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv1_1', reuse = reuse)) - return xtp1 - - def discriminator(self, image): - h0 = lrelu(conv2d(image, self.df_dim, name = 'dis_h0_conv')) - h1 = lrelu(batch_norm(conv2d(h0, self.df_dim * 2, name = 'dis_h1_conv'), - "bn1")) - h2 = lrelu(batch_norm(conv2d(h1, self.df_dim * 4, name = 'dis_h2_conv'), - "bn2")) - h3 = lrelu(batch_norm(conv2d(h2, self.df_dim * 8, name = 'dis_h3_conv'), - "bn3")) - h = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'dis_h3_lin') - - return tf.nn.sigmoid(h), h - - def save(self, sess, checkpoint_dir, step): - model_name = "MCNET.model" - - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir) - - self.saver.save(sess, - os.path.join(checkpoint_dir, model_name), - global_step = step) - - def load(self, sess, checkpoint_dir, model_name=None): - print(" [*] Reading checkpoints...") - ckpt = tf.train.get_checkpoint_state(checkpoint_dir) - if ckpt and ckpt.model_checkpoint_path: - ckpt_name = os.path.basename(ckpt.model_checkpoint_path) - if model_name is None: model_name = ckpt_name - self.saver.restore(sess, os.path.join(checkpoint_dir, model_name)) - print(" Loaded model: " + str(model_name)) - return True, model_name - else: - return False, None - - # Execute the forward and the backward pass - - def run_single_step(self, global_step): - print("global_step:", global_step) - try: - train_batch = self.sess.run(self.train_iterator.get_next()) - # z=np.random.uniform(-1,1,size=(self.batch_size,self.nz)) - x = self.sess.run([self.x], feed_dict = {self.x: train_batch["images"]}) - _, g_sum = self.sess.run([self.g_optim, self.g_sum], feed_dict = {self.x: train_batch["images"]}) - _, d_sum = self.sess.run([self.d_optim, self.d_sum], feed_dict = {self.x: train_batch["images"]}) - - gen_data, 
train_loss = self.sess.run([self.gen_data, self.total_loss], - feed_dict = {self.x: train_batch["images"]}) - - except tf.errors.OutOfRangeError: - print("train out of range error") - - try: - val_batch = self.sess.run(self.val_iterator.get_next()) - val_loss = self.sess.run([self.total_loss], feed_dict = {self.x: val_batch["images"]}) - # self.val_writer.add_summary(val_summary, global_step) - except tf.errors.OutOfRangeError: - print("train out of range error") - - return train_loss, val_total_loss - - - diff --git a/video_prediction_tools/model_modules/video_prediction/models/non_trainable_model.py b/video_prediction_tools/model_modules/video_prediction/models/non_trainable_model.py deleted file mode 100644 index aba90339317dafcf114442ca37cb62338e32d8cd..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/non_trainable_model.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -from tensorflow.python.util import nest -from model_modules.video_prediction.utils.tf_utils import transpose_batch_time - -import tensorflow as tf - -from .base_model import BaseVideoPredictionModel - - -class NonTrainableVideoPredictionModel(BaseVideoPredictionModel): - pass - - -class GroundTruthVideoPredictionModel(NonTrainableVideoPredictionModel): - def build_graph(self, inputs): - super(GroundTruthVideoPredictionModel, self).build_graph(inputs) - - self.outputs = OrderedDict() - self.outputs['gen_images'] = self.inputs['images'][:, 1:] - if 'pix_distribs' in self.inputs: - self.outputs['gen_pix_distribs'] = self.inputs['pix_distribs'][:, 1:] - - inputs, outputs = nest.map_structure(transpose_batch_time, (self.inputs, self.outputs)) - with tf.name_scope("metrics"): - metrics = self.metrics_fn(inputs, outputs) - with tf.name_scope("eval_outputs_and_metrics"): - eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs) - self.metrics, self.eval_outputs, self.eval_metrics = nest.map_structure( - transpose_batch_time, (metrics, eval_outputs, eval_metrics)) - - -class RepeatVideoPredictionModel(NonTrainableVideoPredictionModel): - def build_graph(self, inputs): - super(RepeatVideoPredictionModel, self).build_graph(inputs) - - self.outputs = OrderedDict() - tile_pattern = [1, self.hparams.sequence_length - self.hparams.context_frames, 1, 1, 1] - last_context_images = self.inputs['images'][:, self.hparams.context_frames - 1] - self.outputs['gen_images'] = tf.concat([ - self.inputs['images'][:, 1:self.hparams.context_frames - 1], - tf.tile(last_context_images[:, None], tile_pattern)], axis=-1) - if 'pix_distribs' in self.inputs: - last_context_pix_distrib = self.inputs['pix_distribs'][:, self.hparams.context_frames - 1] - self.outputs['gen_pix_distribs'] = tf.concat([ - self.inputs['pix_distribs'][:, 1:self.hparams.context_frames - 1], - tf.tile(last_context_pix_distrib[:, None], tile_pattern)], axis=-1) - - inputs, outputs = nest.map_structure(transpose_batch_time, (self.inputs, self.outputs)) - with tf.name_scope("metrics"): - metrics = self.metrics_fn(inputs, outputs) - with tf.name_scope("eval_outputs_and_metrics"): - eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs) - self.metrics, self.eval_outputs, self.eval_metrics = nest.map_structure( - transpose_batch_time, (metrics, eval_outputs, eval_metrics)) diff --git a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py 
b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py index 3219bda2de305cf36ad178ebcc192ac9a5a37b78..589a6b32b7fe9c80646e53310d54272f696cb88f 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py @@ -14,69 +14,146 @@ import tensorflow as tf class BaseModels(ABC): - def __init__(self, hparams_dict_config=None): - self.hparams_dict_config = hparams_dict_config - self.hparams_dict = self.get_model_hparams_dict() - self.hparams = self.parse_hparams() - # Attributes set during runtime - self.total_loss = None - self.loss_summary = None + def __init__(self, hparams_dict_config=None, mode="train"): + _modes = ["train","val"] + + if mode not in _modes: + raise ValueError("The mode must be 'train' or 'val'") + else: + self.mode = mode + + self.__model = None + self.hparams = self.hparams_options(hparams_dict_config) + self.parse_hparams(self.hparams) + + # Compile options, must be customized in the sub-class + self.inputs = None + self.train_op = None self.total_loss = None self.outputs = {} - self.train_op = None + self.loss_summary = None self.summary_op = None - self.inputs = None - self.global_step = None + self.global_step = tf.train.get_or_create_global_step() self.saveable_variables = None - self.is_build_graph = None - self.x_hat = None - self.x_hat_predict_frames = None + self._is_build_graph_set = False - - def get_model_hparams_dict(self): - """ - Get model_hparams_dict from json file - """ - if self.hparams_dict_config: - with open(self.hparams_dict_config, 'r') as f: + + def hparams_options(self, hparams_dict_config:str): + if hparams_dict_config: + with open(hparams_dict_config, 'r') as f: hparams_dict = json.loads(f.read()) else: - raise FileNotFoundError("hyper-parameter directory doesn't exist! please check {}!".format(self.hparams_dict_config)) + raise FileNotFoundError("hyper-parameter file doesn't exist! Please check {}!".format(hparams_dict_config)) + return dotdict(hparams_dict) - return hparams_dict - def parse_hparams(self): + @abstractmethod + def parse_hparams(self, hparams)->None: """ - Obtain the parameters from directory + Parse the hyper-parameters into class attributes. + Examples: + ... code-block:: python + def parse_hparams(self, hparams): + try: + self.context_frames = hparams.context_frames + self.max_epochs = hparams.max_epochs + self.batch_size = hparams.batch_size + self.shuffle_on_val = hparams.shuffle_on_val + self.loss_fun = hparams.loss_fun + + except Exception as e: + raise ValueError(f"missing hyperparameter: {e.args[0]}") """ + pass - hparams = dotdict(self.hparams_dict) - return hparams - - @abstractmethod + @property def get_hparams(self): + return self.hparams + + + def build_graph(self, x: tf.Tensor)->bool: """ - obtain the hparams from the dict to the class variables + Build the graph and attach the optimizer to it using TensorFlow operations. + + Example: + ... 
code-block:: python + def build_graph(self, x): + original_global_variables = tf.global_variables() + x_hat = self.build_model(x) + self.train_loss = self.get_loss(x,x_hat) + self.train_op = self.optimizer(self.train_loss) + self.outputs["gen_images"] = x_hat + self.summary_op = self.summary() #This is optional + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self._is_build_graph_set=True + return self._is_build_graph_set + """ - method = BaseModels.get_hparams.__name__ + self.inputs = x + original_global_variables = tf.global_variables() + x_hat = self.build_model(x) + self.total_loss = self.get_loss(x, x_hat) + self.train_op = self.optimizer(self.total_loss) + self.outputs["gen_images"] = x_hat + self.summary_op = self.summary(total_loss = self.total_loss) + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self._is_build_graph_set = True + return self._is_build_graph_set + + + def optimizer(self, total_loss): + """ + Define the optimizer + Example: + ... code-block:: python + def optimizer(self, total_loss): + train_op = tf.train.AdamOptimizer( + learning_rate = self.lr).minimize(total_loss, global_step = self.global_step) + return train_op + """ + train_op = tf.train.AdamOptimizer( + learning_rate = self.lr).minimize(total_loss, global_step = self.global_step) + return train_op + - try: - self.context_frames = self.hparams.context_frames - self.max_epochs = self.hparams.max_epochs - self.batch_size = self.hparams.batch_size - self.shuffle_on_val = self.hparams.shuffle_on_val - self.loss_fun = self.hparams.loss_fun - except Exception as error: - print("Method %{}: error: {}".format(method,error)) - raise("Method %{}: the hparameter dictionary must include " - "'context_frames','max_epochs','batch_size','shuffle_on_val' 'loss_fun'".format(method)) @abstractmethod - def build_graph(self, x: tf.Tensor): + def get_loss(self, x:tf.Tensor, x_hat:tf.Tensor)->tf.Tensor: + """ + :param x : Input tensors + :param x_hat: Prediction/output tensors + :return : the loss tensor + """ pass + def summary(self, **kwargs): + """ + Return the summary operation that can be used for TensorBoard + """ + for key, value in kwargs.items(): + tf.summary.scalar(key, value) + summary_op = tf.summary.merge_all() + return summary_op + + + @abstractmethod - def build_model(self): + def build_model(self, x)->tf.Tensor: + """ + Create the network. + Example: see vanilla_convLSTM_model.py; it must return the predicted frames and save them to self.outputs, + which is used for calculating the loss + """ pass + + + + + + + + diff --git a/video_prediction_tools/model_modules/video_prediction/models/savp_model.py b/video_prediction_tools/model_modules/video_prediction/models/savp_model.py deleted file mode 100644 index 17d72563e5dffac8bc1ff1b863d083d603d0a69b..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/savp_model.py +++ /dev/null @@ -1,996 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import collections -import functools -import itertools -from collections import OrderedDict -import numpy as np -import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops -from model_modules.video_prediction.models 
import VideoPredictionModel -from model_modules.video_prediction.models import networks -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils - -# Amount to use when lower bounding tensors -RELU_SHIFT = 1e-12 - - -def posterior_fn(inputs, hparams): - images = inputs['images'] - image_pairs = tf.concat([images[:-1], images[1:]], axis=-1) - if 'actions' in inputs: - image_pairs = tile_concat( - [image_pairs, inputs['actions'][..., None, None, :]], axis=-1) - - h = tf_utils.with_flat_batch(networks.encoder)( - image_pairs, nef=hparams.nef, n_layers=hparams.n_layers, norm_layer=hparams.norm_layer) - - if hparams.use_e_rnn: - with tf.variable_scope('layer_%d' % (hparams.n_layers + 1)): - h = tf_utils.with_flat_batch(dense, 2)(h, hparams.nef * 4) - - if hparams.rnn == 'lstm': - RNNCell = tf.contrib.rnn.BasicLSTMCell - elif hparams.rnn == 'gru': - RNNCell = tf.contrib.rnn.GRUCell - else: - raise NotImplementedError - with tf.variable_scope('%s' % hparams.rnn): - rnn_cell = RNNCell(hparams.nef * 4) - h, _ = tf_utils.unroll_rnn(rnn_cell, h) - - with tf.variable_scope('z_mu'): - z_mu = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) - with tf.variable_scope('z_log_sigma_sq'): - z_log_sigma_sq = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) - z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10) - outputs = {'zs_mu': z_mu, 'zs_log_sigma_sq': z_log_sigma_sq} - return outputs - - -def prior_fn(inputs, hparams): - images = inputs['images'] - image_pairs = tf.concat([images[:hparams.context_frames - 1], images[1:hparams.context_frames]], axis=-1) - if 'actions' in inputs: - image_pairs = tile_concat( - [image_pairs, inputs['actions'][..., None, None, :]], axis=-1) - - h = tf_utils.with_flat_batch(networks.encoder)( - image_pairs, nef=hparams.nef, n_layers=hparams.n_layers, norm_layer=hparams.norm_layer) - h_zeros = tf.zeros(tf.concat([[hparams.sequence_length - hparams.context_frames], tf.shape(h)[1:]], axis=0)) - h = tf.concat([h, h_zeros], axis=0) - - with tf.variable_scope('layer_%d' % (hparams.n_layers + 1)): - h = tf_utils.with_flat_batch(dense, 2)(h, hparams.nef * 4) - - if hparams.rnn == 'lstm': - RNNCell = tf.contrib.rnn.BasicLSTMCell - elif hparams.rnn == 'gru': - RNNCell = tf.contrib.rnn.GRUCell - else: - raise NotImplementedError - with tf.variable_scope('%s' % hparams.rnn): - rnn_cell = RNNCell(hparams.nef * 4) - h, _ = tf_utils.unroll_rnn(rnn_cell, h) - - with tf.variable_scope('z_mu'): - z_mu = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) - with tf.variable_scope('z_log_sigma_sq'): - z_log_sigma_sq = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) - z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10) - outputs = {'zs_mu': z_mu, 'zs_log_sigma_sq': z_log_sigma_sq} - return outputs - - -def discriminator_given_video_fn(targets, hparams): - sequence_length, batch_size = targets.shape.as_list()[:2] - clip_length = hparams.clip_length - - # sample an image and apply the image distriminator on that frame - t_sample = tf.random_uniform([batch_size], minval=0, maxval=sequence_length, dtype=tf.int32) - image_sample = tf.gather_nd(targets, tf.stack([t_sample, tf.range(batch_size)], axis=1)) - - # sample a subsequence of length clip_length and apply the images/video discriminators on those frames - t_start = tf.random_uniform([batch_size], minval=0, maxval=sequence_length - clip_length + 1, 
dtype=tf.int32) - t_start_indices = tf.stack([t_start, tf.range(batch_size)], axis=1) - t_offset_indices = tf.stack([tf.range(clip_length), tf.zeros(clip_length, dtype=tf.int32)], axis=1) - indices = t_start_indices[None] + t_offset_indices[:, None] - clip_sample = tf.gather_nd(targets, flatten(indices, 0, 1)) - clip_sample = tf.reshape(clip_sample, [clip_length] + targets.shape.as_list()[1:]) - - outputs = {} - if hparams.image_sn_gan_weight or hparams.image_sn_vae_gan_weight: - with tf.variable_scope('image'): - image_features = networks.image_sn_discriminator(image_sample, ndf=hparams.ndf) - image_features, image_logits = image_features[:-1], image_features[-1] - outputs['discrim_image_sn_logits'] = image_logits - for i, image_feature in enumerate(image_features): - outputs['discrim_image_sn_feature%d' % i] = image_feature - if hparams.video_sn_gan_weight or hparams.video_sn_vae_gan_weight: - with tf.variable_scope('video'): - video_features = networks.video_sn_discriminator(clip_sample, ndf=hparams.ndf) - video_features, video_logits = video_features[:-1], video_features[-1] - outputs['discrim_video_sn_logits'] = video_logits - for i, video_feature in enumerate(video_features): - outputs['discrim_video_sn_feature%d' % i] = video_feature - if hparams.images_sn_gan_weight or hparams.images_sn_vae_gan_weight: - with tf.variable_scope('images'): - images_features = tf_utils.with_flat_batch(networks.image_sn_discriminator)(clip_sample, ndf=hparams.ndf) - images_features, images_logits = images_features[:-1], images_features[-1] - outputs['discrim_images_sn_logits'] = images_logits - for i, images_feature in enumerate(images_features): - outputs['discrim_images_sn_feature%d' % i] = images_feature - return outputs - - -def discriminator_fn(inputs, outputs, mode, hparams): - # do the encoder version first so that it isn't affected by the reuse_variables() call - if hparams.nz == 0: - discrim_outputs_enc_real = collections.OrderedDict() - discrim_outputs_enc_fake = collections.OrderedDict() - else: - images_enc_real = inputs['images'][1:] - images_enc_fake = outputs['gen_images_enc'] - if hparams.use_same_discriminator: - with tf.name_scope("real"): - discrim_outputs_enc_real = discriminator_given_video_fn(images_enc_real, hparams) - tf.get_variable_scope().reuse_variables() - with tf.name_scope("fake"): - discrim_outputs_enc_fake = discriminator_given_video_fn(images_enc_fake, hparams) - else: - with tf.variable_scope('encoder'), tf.name_scope("real"): - discrim_outputs_enc_real = discriminator_given_video_fn(images_enc_real, hparams) - with tf.variable_scope('encoder', reuse=True), tf.name_scope("fake"): - discrim_outputs_enc_fake = discriminator_given_video_fn(images_enc_fake, hparams) - - images_real = inputs['images'][1:] - images_fake = outputs['gen_images'] - with tf.name_scope("real"): - discrim_outputs_real = discriminator_given_video_fn(images_real, hparams) - tf.get_variable_scope().reuse_variables() - with tf.name_scope("fake"): - discrim_outputs_fake = discriminator_given_video_fn(images_fake, hparams) - - discrim_outputs_real = OrderedDict([(k + '_real', v) for k, v in discrim_outputs_real.items()]) - discrim_outputs_fake = OrderedDict([(k + '_fake', v) for k, v in discrim_outputs_fake.items()]) - discrim_outputs_enc_real = OrderedDict([(k + '_enc_real', v) for k, v in discrim_outputs_enc_real.items()]) - discrim_outputs_enc_fake = OrderedDict([(k + '_enc_fake', v) for k, v in discrim_outputs_enc_fake.items()]) - outputs = [discrim_outputs_real, discrim_outputs_fake, - 
discrim_outputs_enc_real, discrim_outputs_enc_fake] - total_num_outputs = sum([len(output) for output in outputs]) - outputs = collections.OrderedDict(itertools.chain(*[output.items() for output in outputs])) - assert len(outputs) == total_num_outputs # ensure no output is lost because of repeated keys - return outputs - - -class SAVPCell(tf.nn.rnn_cell.RNNCell): - def __init__(self, inputs, mode, hparams, reuse=None): - super(SAVPCell, self).__init__(_reuse=reuse) - self.inputs = inputs - self.mode = mode - self.hparams = hparams - - if self.hparams.where_add not in ('input', 'all', 'middle'): - raise ValueError('Invalid where_add %s' % self.hparams.where_add) - - batch_size = inputs['images'].shape[1].value - image_shape = inputs['images'].shape.as_list()[2:] - height, width, _ = image_shape - scale_size = min(height, width) - if scale_size >= 256: - self.encoder_layer_specs = [ - (self.hparams.ngf, False), - (self.hparams.ngf * 2, False), - (self.hparams.ngf * 4, True), - (self.hparams.ngf * 8, True), - (self.hparams.ngf * 8, True), - ] - self.decoder_layer_specs = [ - (self.hparams.ngf * 8, True), - (self.hparams.ngf * 4, True), - (self.hparams.ngf * 2, False), - (self.hparams.ngf, False), - (self.hparams.ngf, False), - ] - elif scale_size >= 128: - self.encoder_layer_specs = [ - (self.hparams.ngf, False), - (self.hparams.ngf * 2, True), - (self.hparams.ngf * 4, True), - (self.hparams.ngf * 8, True), - ] - self.decoder_layer_specs = [ - (self.hparams.ngf * 8, True), - (self.hparams.ngf * 4, True), - (self.hparams.ngf * 2, False), - (self.hparams.ngf, False), - ] - elif scale_size >= 64: - self.encoder_layer_specs = [ - (self.hparams.ngf, True), - (self.hparams.ngf * 2, True), - (self.hparams.ngf * 4, True), - ] - self.decoder_layer_specs = [ - (self.hparams.ngf * 2, True), - (self.hparams.ngf, True), - (self.hparams.ngf, False), - ] - elif scale_size >= 32: - self.encoder_layer_specs = [ - (self.hparams.ngf, True), - (self.hparams.ngf * 2, True), - ] - self.decoder_layer_specs = [ - (self.hparams.ngf, True), - (self.hparams.ngf, False), - ] - else: - print("The minimum of image size is 32") - raise NotImplementedError - assert len(self.encoder_layer_specs) == len(self.decoder_layer_specs) - total_stride = 2 ** len(self.encoder_layer_specs) - if (height % total_stride) or (width % total_stride): - raise ValueError("The image has dimension (%d, %d), but it should be divisible " - "by the total stride, which is %d." 
% (height, width, total_stride)) - - # output_size - num_masks = self.hparams.last_frames * self.hparams.num_transformed_images + \ - int(bool(self.hparams.prev_image_background)) + \ - int(bool(self.hparams.first_image_background and not self.hparams.context_images_background)) + \ - int(bool(self.hparams.last_image_background and not self.hparams.context_images_background)) + \ - int(bool(self.hparams.last_context_image_background and not self.hparams.context_images_background)) + \ - (self.hparams.context_frames if self.hparams.context_images_background else 0) + \ - int(bool(self.hparams.generate_scratch_image)) - output_size = { - 'gen_images': tf.TensorShape(image_shape), - 'transformed_images': tf.TensorShape(image_shape + [num_masks]), - 'masks': tf.TensorShape([height, width, 1, num_masks]), - } - if 'pix_distribs' in inputs: - num_motions = inputs['pix_distribs'].shape[-1].value - output_size['gen_pix_distribs'] = tf.TensorShape([height, width, num_motions]) - output_size['transformed_pix_distribs'] = tf.TensorShape([height, width, num_motions, num_masks]) - if 'states' in inputs: - output_size['gen_states'] = inputs['states'].shape[2:] - if self.hparams.transformation == 'flow': - output_size['gen_flows'] = tf.TensorShape([height, width, 2, self.hparams.last_frames * self.hparams.num_transformed_images]) - output_size['gen_flows_rgb'] = tf.TensorShape([height, width, 3, self.hparams.last_frames * self.hparams.num_transformed_images]) - self._output_size = output_size - - # state_size - conv_rnn_state_sizes = [] - conv_rnn_height, conv_rnn_width = height, width - for out_channels, use_conv_rnn in self.encoder_layer_specs: - conv_rnn_height //= 2 - conv_rnn_width //= 2 - if use_conv_rnn and not self.hparams.ablation_rnn: - conv_rnn_state_sizes.append(tf.TensorShape([conv_rnn_height, conv_rnn_width, out_channels])) - for out_channels, use_conv_rnn in self.decoder_layer_specs: - conv_rnn_height *= 2 - conv_rnn_width *= 2 - if use_conv_rnn and not self.hparams.ablation_rnn: - conv_rnn_state_sizes.append(tf.TensorShape([conv_rnn_height, conv_rnn_width, out_channels])) - if self.hparams.conv_rnn == 'lstm': - conv_rnn_state_sizes = [tf.nn.rnn_cell.LSTMStateTuple(conv_rnn_state_size, conv_rnn_state_size) - for conv_rnn_state_size in conv_rnn_state_sizes] - state_size = {'time': tf.TensorShape([]), - 'gen_image': tf.TensorShape(image_shape), - 'last_images': [tf.TensorShape(image_shape)] * self.hparams.last_frames, - 'conv_rnn_states': conv_rnn_state_sizes} - if 'zs' in inputs and self.hparams.use_rnn_z and not self.hparams.ablation_rnn: - rnn_z_state_size = tf.TensorShape([self.hparams.nz]) - if self.hparams.rnn == 'lstm': - rnn_z_state_size = tf.nn.rnn_cell.LSTMStateTuple(rnn_z_state_size, rnn_z_state_size) - state_size['rnn_z_state'] = rnn_z_state_size - if 'pix_distribs' in inputs: - state_size['gen_pix_distrib'] = tf.TensorShape([height, width, num_motions]) - state_size['last_pix_distribs'] = [tf.TensorShape([height, width, num_motions])] * self.hparams.last_frames - if 'states' in inputs: - state_size['gen_state'] = inputs['states'].shape[2:] - self._state_size = state_size - - if self.hparams.learn_initial_state: - learnable_initial_state_size = {k: v for k, v in state_size.items() - if k in ('conv_rnn_states', 'rnn_z_state')} - else: - learnable_initial_state_size = {} - learnable_initial_state_flat = [] - for i, size in enumerate(nest.flatten(learnable_initial_state_size)): - with tf.variable_scope('initial_state_%d' % i): - state = tf.get_variable('initial_state', size, - 
dtype=tf.float32, initializer=tf.zeros_initializer()) - learnable_initial_state_flat.append(state) - self._learnable_initial_state = nest.pack_sequence_as( - learnable_initial_state_size, learnable_initial_state_flat) - - ground_truth_sampling_shape = [self.hparams.sequence_length - 1 - self.hparams.context_frames, batch_size] - if self.hparams.schedule_sampling == 'none' or self.mode != 'train': - ground_truth_sampling = tf.constant(False, dtype=tf.bool, shape=ground_truth_sampling_shape) - elif self.hparams.schedule_sampling in ('inverse_sigmoid', 'linear'): - if self.hparams.schedule_sampling == 'inverse_sigmoid': - k = self.hparams.schedule_sampling_k - start_step = self.hparams.schedule_sampling_steps[0] - iter_num = tf.to_float(tf.train.get_or_create_global_step()) - prob = (k / (k + tf.exp((iter_num - start_step) / k))) - prob = tf.cond(tf.less(iter_num, start_step), lambda: 1.0, lambda: prob) - elif self.hparams.schedule_sampling == 'linear': - start_step, end_step = self.hparams.schedule_sampling_steps - step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step) - prob = 1.0 - tf.to_float(step - start_step) / tf.to_float(end_step - start_step) - log_probs = tf.log([1 - prob, prob]) - ground_truth_sampling = tf.multinomial([log_probs] * batch_size, ground_truth_sampling_shape[0]) - ground_truth_sampling = tf.cast(tf.transpose(ground_truth_sampling, [1, 0]), dtype=tf.bool) - # Ensure that eventually, the model is deterministically - # autoregressive (as opposed to autoregressive with very high probability). - ground_truth_sampling = tf.cond(tf.less(prob, 0.001), - lambda: tf.constant(False, dtype=tf.bool, shape=ground_truth_sampling_shape), - lambda: ground_truth_sampling) - else: - raise NotImplementedError - ground_truth_context = tf.constant(True, dtype=tf.bool, shape=[self.hparams.context_frames, batch_size]) - self.ground_truth = tf.concat([ground_truth_context, ground_truth_sampling], axis=0) - - @property - def output_size(self): - return self._output_size - - @property - def state_size(self): - return self._state_size - - def zero_state(self, batch_size, dtype): - init_state = super(SAVPCell, self).zero_state(batch_size, dtype) - learnable_init_state = nest.map_structure( - lambda x: tf.tile(x[None], [batch_size] + [1] * x.shape.ndims), self._learnable_initial_state) - init_state.update(learnable_init_state) - init_state['last_images'] = [self.inputs['images'][0]] * self.hparams.last_frames - if 'pix_distribs' in self.inputs: - init_state['last_pix_distribs'] = [self.inputs['pix_distribs'][0]] * self.hparams.last_frames - return init_state - - def _rnn_func(self, inputs, state, num_units): - if self.hparams.rnn == 'lstm': - RNNCell = functools.partial(tf.nn.rnn_cell.LSTMCell, name='basic_lstm_cell') - elif self.hparams.rnn == 'gru': - RNNCell = tf.contrib.rnn.GRUCell - else: - raise NotImplementedError - rnn_cell = RNNCell(num_units, reuse=tf.get_variable_scope().reuse) - return rnn_cell(inputs, state) - - def _conv_rnn_func(self, inputs, state, filters): - if isinstance(inputs, (list, tuple)): - inputs_shape = inputs[0].shape.as_list() - else: - inputs_shape = inputs.shape.as_list() - input_shape = inputs_shape[1:] - if self.hparams.conv_rnn_norm_layer == 'none': - normalizer_fn = None - else: - normalizer_fn = ops.get_norm_layer(self.hparams.conv_rnn_norm_layer) - if self.hparams.conv_rnn == 'lstm': - Conv2DRNNCell = BasicConv2DLSTMCell - elif self.hparams.conv_rnn == 'gru': - Conv2DRNNCell = Conv2DGRUCell - else: - raise NotImplementedError - if 
self.hparams.ablation_conv_rnn_norm: - conv_rnn_cell = Conv2DRNNCell(input_shape, filters, kernel_size=(5, 5), - reuse=tf.get_variable_scope().reuse) - h, state = conv_rnn_cell(inputs, state) - outputs = (normalizer_fn(h), state) - else: - conv_rnn_cell = Conv2DRNNCell(input_shape, filters, kernel_size=(5, 5), - normalizer_fn=normalizer_fn, - separate_norms=self.hparams.conv_rnn_norm_layer == 'layer', - reuse=tf.get_variable_scope().reuse) - outputs = conv_rnn_cell(inputs, state) - return outputs - - def call(self, inputs, states): - norm_layer = ops.get_norm_layer(self.hparams.norm_layer) - downsample_layer = ops.get_downsample_layer(self.hparams.downsample_layer) - upsample_layer = ops.get_upsample_layer(self.hparams.upsample_layer) - activation_layer = ops.get_activation_layer(self.hparams.activation_layer) - image_shape = inputs['images'].get_shape().as_list() - batch_size, height, width, color_channels = image_shape - conv_rnn_states = states['conv_rnn_states'] - - time = states['time'] - with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]): - t = tf.to_int32(tf.identity(time[0])) - - image = tf.where(self.ground_truth[t], inputs['images'], states['gen_image']) # schedule sampling (if any) - last_images = states['last_images'][1:] + [image] - if 'pix_distribs' in inputs: - pix_distrib = tf.where(self.ground_truth[t], inputs['pix_distribs'], states['gen_pix_distrib']) - last_pix_distribs = states['last_pix_distribs'][1:] + [pix_distrib] - if 'states' in inputs: - state = tf.where(self.ground_truth[t], inputs['states'], states['gen_state']) - - state_action = [] - state_action_z = [] - if 'actions' in inputs: - state_action.append(inputs['actions']) - state_action_z.append(inputs['actions']) - if 'states' in inputs: - state_action.append(state) - # don't backpropagate the convnet through the state dynamics - state_action_z.append(tf.stop_gradient(state)) - - if 'zs' in inputs: - if self.hparams.use_rnn_z: - with tf.variable_scope('%s_z' % ('fc' if self.hparams.ablation_rnn else self.hparams.rnn)): - if self.hparams.ablation_rnn: - rnn_z = dense(inputs['zs'], self.hparams.nz) - rnn_z = tf.nn.tanh(rnn_z) - else: - rnn_z, rnn_z_state = self._rnn_func(inputs['zs'], states['rnn_z_state'], self.hparams.nz) - state_action_z.append(rnn_z) - else: - state_action_z.append(inputs['zs']) - - def concat(tensors, axis): - if len(tensors) == 0: - return tf.zeros([batch_size, 0]) - elif len(tensors) == 1: - return tensors[0] - else: - return tf.concat(tensors, axis=axis) - state_action = concat(state_action, axis=-1) - state_action_z = concat(state_action_z, axis=-1) - - layers = [] - new_conv_rnn_states = [] - for i, (out_channels, use_conv_rnn) in enumerate(self.encoder_layer_specs): - with tf.variable_scope('h%d' % i): - if i == 0: - h = tf.concat([image, self.inputs['images'][0]], axis=-1) - kernel_size = (5, 5) - else: - h = layers[-1][-1] - kernel_size = (3, 3) - if self.hparams.where_add == 'all' or (self.hparams.where_add == 'input' and i == 0): - if self.hparams.use_tile_concat: - h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1) - else: - h = [h, state_action_z] - h = _maybe_tile_concat_layer(downsample_layer)( - h, out_channels, kernel_size=kernel_size, strides=(2, 2)) - h = norm_layer(h) - h = activation_layer(h) - if use_conv_rnn: - with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, i)): - if self.hparams.where_add == 'all': - if self.hparams.use_tile_concat: - conv_rnn_h = tile_concat([h, state_action_z[:, None, 
None, :]], axis=-1) - else: - conv_rnn_h = [h, state_action_z] - else: - conv_rnn_h = h - if self.hparams.ablation_rnn: - conv_rnn_h = _maybe_tile_concat_layer(conv2d)( - conv_rnn_h, out_channels, kernel_size=(5, 5)) - conv_rnn_h = norm_layer(conv_rnn_h) - conv_rnn_h = activation_layer(conv_rnn_h) - else: - conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)] - conv_rnn_h, conv_rnn_state = self._conv_rnn_func(conv_rnn_h, conv_rnn_state, out_channels) - new_conv_rnn_states.append(conv_rnn_state) - layers.append((h, conv_rnn_h) if use_conv_rnn else (h,)) - - num_encoder_layers = len(layers) - for i, (out_channels, use_conv_rnn) in enumerate(self.decoder_layer_specs): - with tf.variable_scope('h%d' % len(layers)): - if i == 0: - h = layers[-1][-1] - else: - h = tf.concat([layers[-1][-1], layers[num_encoder_layers - i - 1][-1]], axis=-1) - if self.hparams.where_add == 'all' or (self.hparams.where_add == 'middle' and i == 0): - if self.hparams.use_tile_concat: - h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1) - else: - h = [h, state_action_z] - h = _maybe_tile_concat_layer(upsample_layer)( - h, out_channels, kernel_size=(3, 3), strides=(2, 2)) - h = norm_layer(h) - h = activation_layer(h) - if use_conv_rnn: - with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, len(layers))): - if self.hparams.where_add == 'all': - if self.hparams.use_tile_concat: - conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1) - else: - conv_rnn_h = [h, state_action_z] - else: - conv_rnn_h = h - if self.hparams.ablation_rnn: - conv_rnn_h = _maybe_tile_concat_layer(conv2d)(conv_rnn_h, out_channels, kernel_size=(5, 5)) - conv_rnn_h = norm_layer(conv_rnn_h) - conv_rnn_h = activation_layer(conv_rnn_h) - else: - conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)] - conv_rnn_h, conv_rnn_state = self._conv_rnn_func(conv_rnn_h, conv_rnn_state, out_channels) - new_conv_rnn_states.append(conv_rnn_state) - layers.append((h, conv_rnn_h) if use_conv_rnn else (h,)) - assert len(new_conv_rnn_states) == len(conv_rnn_states) - - if self.hparams.last_frames and self.hparams.num_transformed_images: - if self.hparams.transformation == 'flow': - with tf.variable_scope('h%d_flow' % len(layers)): - h_flow = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) - h_flow = norm_layer(h_flow) - h_flow = activation_layer(h_flow) - - with tf.variable_scope('flows'): - flows = conv2d(h_flow, 2 * self.hparams.last_frames * self.hparams.num_transformed_images, kernel_size=(3, 3), strides=(1, 1)) - flows = tf.reshape(flows, [batch_size, height, width, 2, self.hparams.last_frames * self.hparams.num_transformed_images]) - else: - assert len(self.hparams.kernel_size) == 2 - kernel_shape = list(self.hparams.kernel_size) + [self.hparams.last_frames * self.hparams.num_transformed_images] - if self.hparams.transformation == 'dna': - with tf.variable_scope('h%d_dna_kernel' % len(layers)): - h_dna_kernel = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) - h_dna_kernel = norm_layer(h_dna_kernel) - h_dna_kernel = activation_layer(h_dna_kernel) - - # Using largest hidden state for predicting untied conv kernels. 
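# [Editor's note] A minimal NumPy sketch, illustrative only and not part of the original
# model, of the kernel normalization applied further below (in the 'kernel_normalization'
# name scope): the predicted kernels are passed through a shifted ReLU and renormalized
# over their spatial axes, so every kernel is non-negative and sums to one (the small
# shift avoids division by zero). The constant and shapes here are assumptions for the demo.
import numpy as np

RELU_SHIFT_DEMO = 1e-12
raw_kernels = np.random.randn(2, 5, 5, 4)           # [batch, kernel_h, kernel_w, num_transformed_images]
k = np.maximum(raw_kernels - RELU_SHIFT_DEMO, 0.0) + RELU_SHIFT_DEMO
k /= k.sum(axis=(1, 2), keepdims=True)              # normalize over the spatial kernel axes
assert np.allclose(k.sum(axis=(1, 2)), 1.0)         # every kernel is a convex combination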
- with tf.variable_scope('dna_kernels'): - kernels = conv2d(h_dna_kernel, np.prod(kernel_shape), kernel_size=(3, 3), strides=(1, 1)) - kernels = tf.reshape(kernels, [batch_size, height, width] + kernel_shape) - kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, None, None, :, :, None] - kernel_spatial_axes = [3, 4] - elif self.hparams.transformation == 'cdna': - with tf.variable_scope('cdna_kernels'): - smallest_layer = layers[num_encoder_layers - 1][-1] - kernels = dense(flatten(smallest_layer), np.prod(kernel_shape)) - kernels = tf.reshape(kernels, [batch_size] + kernel_shape) - kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, :, :, None] - kernel_spatial_axes = [1, 2] - else: - raise ValueError('Invalid transformation %s' % self.hparams.transformation) - - if self.hparams.transformation != 'flow': - with tf.name_scope('kernel_normalization'): - kernels = tf.nn.relu(kernels - RELU_SHIFT) + RELU_SHIFT - kernels /= tf.reduce_sum(kernels, axis=kernel_spatial_axes, keepdims=True) - - if self.hparams.generate_scratch_image: - with tf.variable_scope('h%d_scratch' % len(layers)): - h_scratch = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) - h_scratch = norm_layer(h_scratch) - h_scratch = activation_layer(h_scratch) - - # Using largest hidden state for predicting a new image layer. - # This allows the network to also generate one image from scratch, - # which is useful when regions of the image become unoccluded. - with tf.variable_scope('scratch_image'): - scratch_image = conv2d(h_scratch, color_channels, kernel_size=(3, 3), strides=(1, 1)) - scratch_image = tf.nn.sigmoid(scratch_image) - - with tf.name_scope('transformed_images'): - transformed_images = [] - if self.hparams.last_frames and self.hparams.num_transformed_images: - if self.hparams.transformation == 'flow': - transformed_images.extend(apply_flows(last_images, flows)) - else: - transformed_images.extend(apply_kernels(last_images, kernels, self.hparams.dilation_rate)) - if self.hparams.prev_image_background: - transformed_images.append(image) - if self.hparams.first_image_background and not self.hparams.context_images_background: - transformed_images.append(self.inputs['images'][0]) - if self.hparams.last_image_background and not self.hparams.context_images_background: - transformed_images.append(self.inputs['images'][self.hparams.context_frames - 1]) - if self.hparams.last_context_image_background and not self.hparams.context_images_background: - last_context_image = tf.cond( - tf.less(t, self.hparams.context_frames), - lambda: self.inputs['images'][t], - lambda: self.inputs['images'][self.hparams.context_frames - 1]) - transformed_images.append(last_context_image) - if self.hparams.context_images_background: - transformed_images.extend(tf.unstack(self.inputs['images'][:self.hparams.context_frames])) - if self.hparams.generate_scratch_image: - transformed_images.append(scratch_image) - - if 'pix_distribs' in inputs: - with tf.name_scope('transformed_pix_distribs'): - transformed_pix_distribs = [] - if self.hparams.last_frames and self.hparams.num_transformed_images: - if self.hparams.transformation == 'flow': - transformed_pix_distribs.extend(apply_flows(last_pix_distribs, flows)) - else: - transformed_pix_distribs.extend(apply_kernels(last_pix_distribs, kernels, self.hparams.dilation_rate)) - if self.hparams.prev_image_background: - transformed_pix_distribs.append(pix_distrib) - if self.hparams.first_image_background and not self.hparams.context_images_background: - 
transformed_pix_distribs.append(self.inputs['pix_distribs'][0]) - if self.hparams.last_image_background and not self.hparams.context_images_background: - transformed_pix_distribs.append(self.inputs['pix_distribs'][self.hparams.context_frames - 1]) - if self.hparams.last_context_image_background and not self.hparams.context_images_background: - last_context_pix_distrib = tf.cond( - tf.less(t, self.hparams.context_frames), - lambda: self.inputs['pix_distribs'][t], - lambda: self.inputs['pix_distribs'][self.hparams.context_frames - 1]) - transformed_pix_distribs.append(last_context_pix_distrib) - if self.hparams.context_images_background: - transformed_pix_distribs.extend(tf.unstack(self.inputs['pix_distribs'][:self.hparams.context_frames])) - if self.hparams.generate_scratch_image: - transformed_pix_distribs.append(pix_distrib) - - with tf.name_scope('masks'): - if len(transformed_images) > 1: - with tf.variable_scope('h%d_masks' % len(layers)): - h_masks = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) - h_masks = norm_layer(h_masks) - h_masks = activation_layer(h_masks) - - with tf.variable_scope('masks'): - if self.hparams.dependent_mask: - h_masks = tf.concat([h_masks] + transformed_images, axis=-1) - masks = conv2d(h_masks, len(transformed_images), kernel_size=(3, 3), strides=(1, 1)) - masks = tf.nn.softmax(masks) - masks = tf.split(masks, len(transformed_images), axis=-1) - elif len(transformed_images) == 1: - masks = [tf.ones([batch_size, height, width, 1])] - else: - raise ValueError("Either one of the following should be true: " - "last_frames and num_transformed_images, first_image_background, " - "prev_image_background, generate_scratch_image") - - with tf.name_scope('gen_images'): - assert len(transformed_images) == len(masks) - gen_image = tf.add_n([transformed_image * mask - for transformed_image, mask in zip(transformed_images, masks)]) - - if 'pix_distribs' in inputs: - with tf.name_scope('gen_pix_distribs'): - assert len(transformed_pix_distribs) == len(masks) - gen_pix_distrib = tf.add_n([transformed_pix_distrib * mask - for transformed_pix_distrib, mask in zip(transformed_pix_distribs, masks)]) - gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib, axis=(1, 2), keepdims=True) - - if 'states' in inputs: - with tf.name_scope('gen_states'): - with tf.variable_scope('state_pred'): - gen_state = dense(state_action, inputs['states'].shape[-1].value) - - outputs = {'gen_images': gen_image, - 'transformed_images': tf.stack(transformed_images, axis=-1), - 'masks': tf.stack(masks, axis=-1)} - if 'pix_distribs' in inputs: - outputs['gen_pix_distribs'] = gen_pix_distrib - outputs['transformed_pix_distribs'] = tf.stack(transformed_pix_distribs, axis=-1) - if 'states' in inputs: - outputs['gen_states'] = gen_state - if self.hparams.transformation == 'flow': - outputs['gen_flows'] = flows - flows_transposed = tf.transpose(flows, [0, 1, 2, 4, 3]) - flows_rgb_transposed = tf_utils.flow_to_rgb(flows_transposed) - flows_rgb = tf.transpose(flows_rgb_transposed, [0, 1, 2, 4, 3]) - outputs['gen_flows_rgb'] = flows_rgb - - new_states = {'time': time + 1, - 'gen_image': gen_image, - 'last_images': last_images, - 'conv_rnn_states': new_conv_rnn_states} - if 'zs' in inputs and self.hparams.use_rnn_z and not self.hparams.ablation_rnn: - new_states['rnn_z_state'] = rnn_z_state - if 'pix_distribs' in inputs: - new_states['gen_pix_distrib'] = gen_pix_distrib - new_states['last_pix_distribs'] = last_pix_distribs - if 'states' in inputs: - new_states['gen_state'] = gen_state - 
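# [Editor's note] Hedged NumPy sketch, not part of the original cell, of the compositing
# performed above: per-pixel softmax masks weight the candidate (transformed) images, so the
# generated frame is a pixel-wise convex combination of them. Shapes and names are assumptions.
import numpy as np

def composite_demo(transformed, mask_logits):
    """transformed: [num_masks, batch, h, w, c]; mask_logits: [batch, h, w, num_masks]."""
    masks = np.exp(mask_logits - mask_logits.max(axis=-1, keepdims=True))
    masks /= masks.sum(axis=-1, keepdims=True)                     # softmax over the candidates
    return sum(transformed[k] * masks[..., k:k + 1]                # weight and sum the candidates
               for k in range(transformed.shape[0]))

candidates = np.random.rand(3, 2, 8, 8, 1)                         # 3 candidates, batch 2, 8x8x1
gen_image_demo = composite_demo(candidates, np.random.randn(2, 8, 8, 3))   # -> [2, 8, 8, 1]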
return outputs, new_states - - -def generator_given_z_fn(inputs, mode, hparams): - # all the inputs needs to have the same length for unrolling the rnn - #20200822 bing - inputs ={"images":inputs["images"]} - #20200822 - inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1) - for name, input in inputs.items()} - cell = SAVPCell(inputs, mode, hparams) - outputs, _ = tf_utils.unroll_rnn(cell, inputs) - outputs['ground_truth_sampling_mean'] = tf.reduce_mean(tf.to_float(cell.ground_truth[hparams.context_frames:])) - return outputs - - -def generator_fn(inputs, mode, hparams): - batch_size = tf.shape(inputs['images'])[1] - - if hparams.nz == 0: - # no zs is given in inputs - outputs = generator_given_z_fn(inputs, mode, hparams) - else: - zs_shape = [hparams.sequence_length - 1, batch_size, hparams.nz] - - # posterior - with tf.variable_scope('encoder'): - outputs_posterior = posterior_fn(inputs, hparams) - eps = tf.random_normal(zs_shape, 0, 1) - zs_posterior = outputs_posterior['zs_mu'] + tf.sqrt(tf.exp(outputs_posterior['zs_log_sigma_sq'])) * eps - inputs_posterior = dict(inputs) - inputs_posterior['zs'] = zs_posterior - - # prior - if hparams.learn_prior: - with tf.variable_scope('prior'): - outputs_prior = prior_fn(inputs, hparams) - eps = tf.random_normal(zs_shape, 0, 1) - zs_prior = outputs_prior['zs_mu'] + tf.sqrt(tf.exp(outputs_prior['zs_log_sigma_sq'])) * eps - else: - outputs_prior = {} - zs_prior = tf.random_normal([hparams.sequence_length - hparams.context_frames] + zs_shape[1:], 0, 1) - zs_prior = tf.concat([zs_posterior[:hparams.context_frames - 1], zs_prior], axis=0) - inputs_prior = dict(inputs) - inputs_prior['zs'] = zs_prior - - # generate - gen_outputs_posterior = generator_given_z_fn(inputs_posterior, mode, hparams) - tf.get_variable_scope().reuse_variables() - gen_outputs = generator_given_z_fn(inputs_prior, mode, hparams) - - # rename tensors to avoid name collisions - output_prior = collections.OrderedDict([(k + '_prior', v) for k, v in outputs_prior.items()]) - outputs_posterior = collections.OrderedDict([(k + '_enc', v) for k, v in outputs_posterior.items()]) - gen_outputs_posterior = collections.OrderedDict([(k + '_enc', v) for k, v in gen_outputs_posterior.items()]) - - outputs = [output_prior, gen_outputs, outputs_posterior, gen_outputs_posterior] - total_num_outputs = sum([len(output) for output in outputs]) - outputs = collections.OrderedDict(itertools.chain(*[output.items() for output in outputs])) - assert len(outputs) == total_num_outputs # ensure no output is lost because of repeated keys - - # generate multiple samples from the prior for visualization purposes - inputs_samples = { - name: tf.tile(input[:, None], [1, hparams.num_samples] + [1] * (input.shape.ndims - 1)) - for name, input in inputs.items()} - zs_samples_shape = [hparams.sequence_length - 1, hparams.num_samples, batch_size, hparams.nz] - if hparams.learn_prior: - eps = tf.random_normal(zs_samples_shape, 0, 1) - zs_prior_samples = (outputs_prior['zs_mu'][:, None] + - tf.sqrt(tf.exp(outputs_prior['zs_log_sigma_sq']))[:, None] * eps) - else: - zs_prior_samples = tf.random_normal( - [hparams.sequence_length - hparams.context_frames] + zs_samples_shape[1:], 0, 1) - zs_prior_samples = tf.concat( - [tf.tile(zs_posterior[:hparams.context_frames - 1][:, None], [1, hparams.num_samples, 1, 1]), - zs_prior_samples], axis=0) - inputs_prior_samples = dict(inputs_samples) - inputs_prior_samples['zs'] = zs_prior_samples - inputs_prior_samples = {name: flatten(input, 1, 2) for name, 
input in inputs_prior_samples.items()} - gen_outputs_samples = generator_given_z_fn(inputs_prior_samples, mode, hparams) - gen_images_samples = gen_outputs_samples['gen_images'] - gen_images_samples = tf.stack(tf.split(gen_images_samples, hparams.num_samples, axis=1), axis=-1) - gen_images_samples_avg = tf.reduce_mean(gen_images_samples, axis=-1) - outputs['gen_images_samples'] = gen_images_samples - outputs['gen_images_samples_avg'] = gen_images_samples_avg - return outputs - - -class SAVPVideoPredictionModel(VideoPredictionModel): - def __init__(self, *args, **kwargs): - super(SAVPVideoPredictionModel, self).__init__( - generator_fn, discriminator_fn, *args, **kwargs) - if self.mode != 'train': - self.discriminator_fn = None - self.deterministic = not self.hparams.nz - - def get_default_hparams_dict(self): - default_hparams = super(SAVPVideoPredictionModel, self).get_default_hparams_dict() - hparams = dict( - l1_weight=1.0, - l2_weight=0.0, - n_layers=3, - ndf=32, - norm_layer='instance', - use_same_discriminator=False, - ngf=32, - downsample_layer='conv_pool2d', - upsample_layer='upsample_conv2d', - activation_layer='relu', # for generator only - transformation='cdna', - kernel_size=(5, 5), - dilation_rate=(1, 1), - where_add='all', - use_tile_concat=True, - learn_initial_state=False, - rnn='lstm', - conv_rnn='lstm', - conv_rnn_norm_layer='instance', - num_transformed_images=4, - last_frames=1, - prev_image_background=True, - first_image_background=True, - last_image_background=False, - last_context_image_background=False, - context_images_background=False, - generate_scratch_image=True, - dependent_mask=True, - schedule_sampling='inverse_sigmoid', - schedule_sampling_k=900.0, - schedule_sampling_steps=(0, 100000), - use_e_rnn=False, - learn_prior=False, - nz=8, - num_samples=8, - nef=64, - use_rnn_z=True, - ablation_conv_rnn_norm=False, - ablation_rnn=False, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - def parse_hparams(self, hparams_dict, hparams): - # backwards compatibility - deprecated_hparams_keys = [ - 'num_gpus', - 'e_net', - 'd_conditional', - 'd_downsample_layer', - 'd_net', - 'd_use_gt_inputs', - 'acvideo_gan_weight', - 'acvideo_vae_gan_weight', - 'image_gan_weight', - 'image_vae_gan_weight', - 'tuple_gan_weight', - 'tuple_vae_gan_weight', - 'gan_weight', - 'vae_gan_weight', - 'video_gan_weight', - 'video_vae_gan_weight', - ] - for deprecated_hparams_key in deprecated_hparams_keys: - hparams_dict.pop(deprecated_hparams_key, None) - return super(SAVPVideoPredictionModel, self).parse_hparams(hparams_dict, hparams) - - def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): - def restore_to_checkpoint_mapping(restore_name, checkpoint_var_names): - restore_name = restore_name.split(':')[0] - if restore_name not in checkpoint_var_names: - restore_name = restore_name.replace('savp_cell', 'dna_cell') - return restore_name - - super(SAVPVideoPredictionModel, self).restore(sess, checkpoints, restore_to_checkpoint_mapping) - - -def apply_dna_kernels(image, kernels, dilation_rate=(1, 1)): - """ - Args: - image: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - kernels: A 6-D of shape - `[batch, in_height, in_width, kernel_size[0], kernel_size[1], num_transformed_images]`. - Returns: - A list of `num_transformed_images` 4-D tensors, each of shape - `[batch, in_height, in_width, in_channels]`. 
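    Example (editor's note): an illustrative NumPy reference of the per-pixel "DNA"
    transformation, not part of the original code; it assumes odd kernel sizes and
    kernels that already sum to one:

        import numpy as np

        def apply_dna_kernels_reference(image, kernels):
            # image: [b, h, w, c]; kernels: [b, h, w, kh, kw, n]
            b, h, w, c = image.shape
            _, _, _, kh, kw, n = kernels.shape
            padded = np.pad(image, ((0, 0), (kh // 2, kh // 2), (kw // 2, kw // 2), (0, 0)),
                            mode='symmetric')
            out = np.zeros((n, b, h, w, c))
            for y in range(h):
                for x in range(w):
                    patch = padded[:, y:y + kh, x:x + kw, :]        # local patch at (y, x)
                    # weight the patch with the kernel predicted for this output pixel
                    out[:, :, y, x, :] = np.einsum('bijc,bijn->nbc', patch, kernels[:, y, x])
            return list(out)                                        # n tensors of [b, h, w, c]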
- """ - dilation_rate = list(dilation_rate) if isinstance(dilation_rate, (tuple, list)) else [dilation_rate] * 2 - batch_size, height, width, color_channels = image.get_shape().as_list() - batch_size, height, width, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list() - kernel_size = [kernel_height, kernel_width] - - # Flatten the spatial dimensions. - kernels_reshaped = tf.reshape(kernels, [batch_size, height, width, - kernel_size[0] * kernel_size[1], num_transformed_images]) - image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC') - # Combine channel and batch dimensions into the first dimension. - image_transposed = tf.transpose(image_padded, [3, 0, 1, 2]) - image_reshaped = flatten(image_transposed, 0, 1)[..., None] - patches_reshaped = tf.extract_image_patches(image_reshaped, ksizes=[1] + kernel_size + [1], - strides=[1] * 4, rates=[1] + dilation_rate + [1], padding='VALID') - # Separate channel and batch dimensions, and move channel dimension. - patches_transposed = tf.reshape(patches_reshaped, [color_channels, batch_size, height, width, kernel_size[0] * kernel_size[1]]) - patches = tf.transpose(patches_transposed, [1, 2, 3, 0, 4]) - # Reduce along the spatial dimensions of the kernel. - outputs = tf.matmul(patches, kernels_reshaped) - outputs = tf.unstack(outputs, axis=-1) - return outputs - - -def apply_cdna_kernels(image, kernels, dilation_rate=(1, 1)): - """ - Args: - image: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - kernels: A 4-D of shape - `[batch, kernel_size[0], kernel_size[1], num_transformed_images]`. - Returns: - A list of `num_transformed_images` 4-D tensors, each of shape - `[batch, in_height, in_width, in_channels]`. - """ - batch_size, height, width, color_channels = image.get_shape().as_list() - batch_size, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list() - kernel_size = [kernel_height, kernel_width] - image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC') - # Treat the color channel dimension as the batch dimension since the same - # transformation is applied to each color channel. - # Treat the batch dimension as the channel dimension so that - # depthwise_conv2d can apply a different transformation to each sample. - kernels = tf.transpose(kernels, [1, 2, 0, 3]) - kernels = tf.reshape(kernels, [kernel_size[0], kernel_size[1], batch_size, num_transformed_images]) - # Swap the batch and channel dimensions. - image_transposed = tf.transpose(image_padded, [3, 1, 2, 0]) - # Transform image. - outputs = tf.nn.depthwise_conv2d(image_transposed, kernels, [1, 1, 1, 1], padding='VALID', rate=dilation_rate) - # Transpose the dimensions to where they belong. - outputs = tf.reshape(outputs, [color_channels, height, width, batch_size, num_transformed_images]) - outputs = tf.transpose(outputs, [4, 3, 1, 2, 0]) - outputs = tf.unstack(outputs, axis=0) - return outputs - - -def apply_kernels(image, kernels, dilation_rate=(1, 1)): - """ - Args: - image: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - kernels: A 4-D or 6-D tensor of shape - `[batch, kernel_size[0], kernel_size[1], num_transformed_images]` or - `[batch, in_height, in_width, kernel_size[0], kernel_size[1], num_transformed_images]`. - Returns: - A list of `num_transformed_images` 4-D tensors, each of shape - `[batch, in_height, in_width, in_channels]`. 
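    Example (editor's note): "CDNA" kernels are the spatially-shared special case of the
    "DNA" kernels above -- one kernel per (sample, transformed image) applied at every
    pixel. An illustrative sketch, not part of the original code, reusing the hypothetical
    apply_dna_kernels_reference helper sketched in the previous docstring:

        import numpy as np

        def apply_cdna_kernels_reference(image, kernels):
            # image: [b, h, w, c]; kernels: [b, kh, kw, n]
            b, h, w, _ = image.shape
            tiled = np.broadcast_to(kernels[:, None, None], (b, h, w) + kernels.shape[1:])
            return apply_dna_kernels_reference(image, tiled)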
- """ - if isinstance(image, list): - image_list = image - kernels_list = tf.split(kernels, len(image_list), axis=-1) - outputs = [] - for image, kernels in zip(image_list, kernels_list): - outputs.extend(apply_kernels(image, kernels)) - else: - if len(kernels.get_shape()) == 4: - outputs = apply_cdna_kernels(image, kernels, dilation_rate=dilation_rate) - elif len(kernels.get_shape()) == 6: - outputs = apply_dna_kernels(image, kernels, dilation_rate=dilation_rate) - else: - raise ValueError - return outputs - - -def apply_flows(image, flows): - if isinstance(image, list): - image_list = image - flows_list = tf.split(flows, len(image_list), axis=-1) - outputs = [] - for image, flows in zip(image_list, flows_list): - outputs.extend(apply_flows(image, flows)) - else: - flows = tf.unstack(flows, axis=-1) - outputs = [flow_ops.image_warp(image, flow) for flow in flows] - return outputs - - -def identity_kernel(kernel_size): - kh, kw = kernel_size - kernel = np.zeros(kernel_size) - - def center_slice(k): - if k % 2 == 0: - return slice(k // 2 - 1, k // 2 + 1) - else: - return slice(k // 2, k // 2 + 1) - - kernel[center_slice(kh), center_slice(kw)] = 1.0 - kernel /= np.sum(kernel) - return kernel - - -def _maybe_tile_concat_layer(conv2d_layer): - def layer(inputs, out_channels, *args, **kwargs): - if isinstance(inputs, (list, tuple)): - inputs_spatial, inputs_non_spatial = inputs - outputs = (conv2d_layer(inputs_spatial, out_channels, *args, **kwargs) + - dense(inputs_non_spatial, out_channels, use_bias=False)[:, None, None, :]) - else: - outputs = conv2d_layer(inputs, out_channels, *args, **kwargs) - return outputs - - return layer diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py index dc58cba3638309cd6b7054b91e45363f2aa42fce..70c050f53efa81f1c482fa06f3cd1a5318b6e987 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MIT __email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong, Scarlet Stadtler,Michael Langguth" +__author__ = "Bing Gong" __date__ = "2020-11-05" from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames @@ -11,119 +11,137 @@ import tensorflow as tf from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell from .our_base_model import BaseModels +from hparams_utils import * class VanillaConvLstmVideoPredictionModel(BaseModels): - def __init__(self, hparams_dict=None, **kwargs): + + def __init__(self, hparams_dict_config=None, **kwargs): """ This is class for building convLSTM architecture by using updated hparameters args: - hparams_dict : dict, the dictionary contains the hparaemters names and values + hparams_dict : dict, the dictionary contains the hparaemters names and values """ - super().__init__(hparams_dict) - self.get_hparams() + super().__init__(hparams_dict_config) - def get_hparams(self): + def parse_hparams(self, hparams): """ - obtain the hparams from the dict to the class variables + obtain the hyper-parameters from the dictionary """ - method = BaseModels.get_hparams.__name__ try: - self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.max_epochs = 
self.hparams.max_epochs - self.batch_size = self.hparams.batch_size - self.shuffle_on_val = self.hparams.shuffle_on_val - self.loss_fun = self.hparams.loss_fun - self.opt_var = self.hparams.opt_var - self.learning_rate = self.hparams.lr + self.context_frames = hparams.context_frames + self.sequence_length = hparams.sequence_length + self.max_epochs = hparams.max_epochs + self.batch_size = hparams.batch_size + self.shuffle_on_val = hparams.shuffle_on_val + self.loss_fun = hparams.loss_fun + self.opt_var = hparams.opt_var + self.lr = hparams.lr self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) print("The model hparams have been parsed successfully! ") - except Exception as error: - print("Method %{}: error: {}".format(method, error)) - raise("Method %{}: the hparameter dictionary must include " - "'context_frames','max_epochs','batch_size','shuffle_on_val' 'loss_fun'," - "'opt_var', 'lr', 'opt_var'".format(method)) + except Exception as e: + raise ValueError(f"missing hyperparameter: {e.args[0]}") + - def build_graph(self, x): - self.is_build_graph = False + def build_graph(self, x:tf.Tensor): self.inputs = x - self.global_step = tf.train.get_or_create_global_step() original_global_variables = tf.global_variables() + x_hat = self.build_model(x) + self.total_loss = self.get_loss(x, x_hat) + self.train_op = self.optimizer(self.total_loss) + self.outputs["gen_images"] = x_hat + self.summary_op = self.summary(total_loss = self.total_loss) + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self._is_build_graph_set = True + return self._is_build_graph_set - self.build_model() - #This is the loss function (MSE): + def get_loss(self, x:tf.Tensor, x_hat:tf.Tensor)->tf.Tensor: + """ + :param x : Input tensors + :param x_hat: Prediction/output tensors + :return : the loss function + """ + #This is the loss function (MSE): #Optimize all target variables/channels if self.opt_var == "all": - x = self.inputs[:, self.context_frames:, :, :, :] - x_hat = self.x_hat_predict_frames[:, :, :, :, :] - print ("The model is optimzied on all the variables in the loss function") + x = x[:, self.context_frames:, :, :, :] + print("The model is optimzied on all the variables in the loss function") elif self.opt_var != "all" and isinstance(self.opt_var, str): self.opt_var = int(self.opt_var) - print ("The model is optimized on the {} variable in the loss function".format(self.opt_var)) - x = self.inputs[:, self.context_frames:, :, :, self.opt_var] - x_hat = self.x_hat_predict_frames[:, :, :, :, self.opt_var] + print("The model is optimized on the {} variable in the loss function".format(self.opt_var)) + x = x[:, self.context_frames:, :, :, self.opt_var] + x_hat = x_hat[:, :, :, :, self.opt_var] else: raise ValueError("The opt var in the hyperparameters setup should be '0','1','2' indicate the index of target variable to be optimised or 'all' indicating optimize all the variables") if self.loss_fun == "mse": - self.total_loss = tf.reduce_mean(tf.square(x - x_hat)) + total_loss = tf.reduce_mean(tf.square(x - x_hat)) elif self.loss_fun == "cross_entropy": x_flatten = tf.reshape(x, [-1]) x_hat_predict_frames_flatten = tf.reshape(x_hat, [-1]) bce = tf.keras.losses.BinaryCrossentropy() - self.total_loss = bce(x_flatten, x_hat_predict_frames_flatten) + total_loss = bce(x_flatten, x_hat_predict_frames_flatten) else: raise ValueError("Loss function is not selected properly, 
you should chose either 'mse' or 'cross_entropy'") + return total_loss - #This is the loss for only all the channels(temperature, geo500, pressure) - #self.total_loss = tf.reduce_mean( - # tf.square(self.x[:, self.context_frames:,:,:,:] - self.x_hat_predict_frames[:,:,:,:,:])) - - self.train_op = tf.train.AdamOptimizer( - learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) - - self.outputs["gen_images"] = self.x_hat - # Summary op - self.loss_summary = tf.summary.scalar("total_loss", self.total_loss) - self.summary_op = tf.summary.merge_all() - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - self.is_build_graph = True - return self.is_build_graph - def build_model(self): + + def summary(self, total_loss)->None: + """ + return the summary operation can be used for TensorBoard + """ + tf.summary.scalar("total_loss", total_loss) + summary_op = tf.summary.merge_all() + return summary_op + + def build_model(self, x: tf.Tensor): network_template = tf.make_template('network', VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables + + x_hat = VanillaConvLstmVideoPredictionModel.convLSTM_network(x, + self.sequence_length, + self.context_frames, + network_template) + return x_hat + + + @staticmethod + def convLSTM_network(x:tf.Tensor, sequence_length:int, context_frames:int, network_template:tf.make_template)->tf.Tensor: + # create network x_hat = [] - - #This is for training (optimization of convLSTM layer) + + # This is for training (optimization of convLSTM layer) hidden_g = None - for i in range(self.sequence_length-1): - if i < self.context_frames: - x_1_g, hidden_g = network_template(self.inputs[:, i, :, :, :], hidden_g) + for i in range(sequence_length - 1): + if i < context_frames: + x_1_g, hidden_g = network_template(x[:, i, :, :, :], hidden_g) else: x_1_g, hidden_g = network_template(x_1_g, hidden_g) x_hat.append(x_1_g) # pack them all together x_hat = tf.stack(x_hat) - self.x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim - self.x_hat_predict_frames = self.x_hat[:, self.context_frames-1:, :, :, :] + x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim + x_hat = x_hat[:, context_frames - 1:, :, :, :] + return x_hat + @staticmethod - def convLSTM_cell(inputs, hidden): + + def convLSTM_cell(inputs:tf.Tensor, hidden:tf.Tensor): """ SPDX-FileCopyrightText: loliverhennigh SPDX-License-Identifier: Apache-2.0 The following function was revised based on the github https://github.com/loliverhennigh/Convolutional-LSTM-in-Tensorflow """ y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset - channels = inputs.get_shape()[-1] + # conv lstm cell cell_shape = y_0.get_shape().as_list() channels = cell_shape[-1] @@ -137,24 +155,3 @@ class VanillaConvLstmVideoPredictionModel(BaseModels): #we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") return x_hat, hidden - - @staticmethod - def set_and_check_pred_frames(seq_length, context_frames): - """ - Checks if sequence length and context_frames are set properly and returns number of frames to be predicted. 
- :param seq_length: number of frames/images per sequences - :param context_frames: number of context frames/images - :return: number of predicted frames - """ - - method = VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames.__name__ - - # sanity checks - assert isinstance(seq_length, int), "%{0}: Sequence length (seq_length) must be an integer".format(method) - assert isinstance(context_frames, int), "%{0}: Number of context frames must be an integer".format(method) - - if seq_length > context_frames: - return seq_length-context_frames - else: - raise ValueError("%{0}: Sequence length ({1}) must be larger than context frames ({2})." - .format(method, seq_length, context_frames)) diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py deleted file mode 100644 index 5365bcb8c3c6e739bb2ba6ae04813e361e297921..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong" -__date__ = "2020-09-01" - -from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames -import tensorflow as tf -from model_modules.video_prediction.layers import layer_def as ld -from tensorflow.contrib.training import HParams - - -class VanillaVAEVideoPredictionModel(object): - def __init__(self, hparams_dict=None, **kwargs): - """ - This is class for building convLSTM architecture by using updated hparameters - args: - hparams_dict: dict, the dictionary contains the hparaemters names and values - """ - - self.hparams_dict = hparams_dict - self.hparams = self.parse_hparams() - self.learning_rate = self.hparams.lr - self.total_loss = None - self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) - self.max_epochs = self.hparams.max_epochs - self.nz = self.hparams.nz - self.loss_fun = self.hparams.loss_fun - self.batch_size = self.hparams.batch_size - self.shuffle_on_val = self.hparams.shuffle_on_val - self.weight_recon = self.hparams.weight_recon - - def get_default_hparams(self): - return HParams(**self.get_default_hparams_dict()) - - def parse_hparams(self): - """ - Parse the hparams setting to ovoerride the default ones - """ - parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) - return parsed_hparams - - - def get_default_hparams_dict(self): - """ - The function that contains default hparams - Returns: - A dict with the following hyperparameters. - context_frames : the number of ground-truth frames to pass in at start. 
- sequence_length : the number of frames in the video sequence - max_epochs : the number of epochs to train model - lr : learning rate - loss_fun : the loss function - """ - hparams = dict( - context_frames=10, - sequence_length=24, - max_epochs = 20, - batch_size = 4, - lr = 0.001, - nz = 16, - loss_fun = "cross_entropy", - weight_recon = 1, - shuffle_on_val= True, - ) - return hparams - - - def build_graph(self,x): - self.x = x["images"] - self.global_step = tf.train.get_or_create_global_step() - original_global_variables = tf.global_variables() - self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all() - #This is the loss function (RMSE): - #This is loss function only for 1 channel (temperature RMSE) - if self.loss_fun == "rmse": - self.recon_loss = tf.reduce_mean(tf.square(self.x[:,self.context_frames:,:,:,0] - self.x_hat[:,self.context_frames-1:,:,:,0])) - elif self.loss_fun == "cross_entropy": - x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0],[-1]) - x_hat_predict_frames_flatten = tf.reshape(self.x_hat[:,self.context_frames-1:,:,:,0],[-1]) - bce = tf.keras.losses.BinaryCrossentropy() - self.recon_loss = bce(x_flatten,x_hat_predict_frames_flatten) - else: - raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'") - - latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq - tf.square(self.z_mu) -tf.exp(self.z_log_sigma_sq), axis=1) - self.latent_loss = tf.reduce_mean(latent_loss) - self.total_loss = self.weight_recon * self.recon_loss + self.latent_loss - self.train_op = tf.train.AdamOptimizer( - learning_rate = self.learning_rate).minimize(self.total_loss, global_step=self.global_step) - # Build a saver - self.losses = { - 'recon_loss': self.recon_loss, - 'latent_loss': self.latent_loss, - 'total_loss': self.total_loss, - } - - # Summary op - self.loss_summary = tf.summary.scalar("recon_loss", self.recon_loss) - self.loss_summary = tf.summary.scalar("latent_loss", self.latent_loss) - self.loss_summary = tf.summary.scalar("total_loss", self.latent_loss) - self.summary_op = tf.summary.merge_all() - self.outputs = {} - self.outputs["gen_images"] = self.x_hat - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - return None - - - @staticmethod - def vae_arc3(x,l_name=0,nz=16): - """ - VAE for one timestamp of sequence - args: - x : input tensor, shape is [batch_size,height, width, channel] - l_name : int, default is 0, the sequence index - nz : int, default is 16, the latent space - return - x_hat : tensor, is the predicted value - z_mu : tensor, means values of latent space - z_log_sigma_sq: sensor, the variances of latent space - z : tensor, the normal distribution with z_mu, z-log_sigma_sq - - """ - input_shape = x.get_shape().as_list() - input_width = input_shape[2] - input_height = input_shape[1] - print("input_heights:",input_height) - seq_name = "sq_" + str(l_name) + "_" - conv1 = ld.conv_layer(inputs=x, kernel_size=3, stride=2, num_features=8, idx=seq_name + "encode_1") - conv1_shape = conv1.get_shape().as_list() - print("conv1_shape:",conv1_shape) - assert conv1_shape[3] == 8 #Features check - assert conv1_shape[1] == int((input_height - 3 + 1)/2) + 1 #[(Input_volumn - kernel_size + padding)/stride] + 1 - conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2") - conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3") - conv4 = tf.layers.Flatten()(conv3) - conv3_shape = 
conv3.get_shape().as_list() - z_mu = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m") - z_log_sigma_sq = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m"'enc_fc4_sigma') - eps = tf.random_normal(shape = tf.shape(z_log_sigma_sq), mean = 0, stddev = 1, dtype = tf.float32) - z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps - z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1") - z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]]) - conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, seq_name + "decode_5") - conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8,seq_name + "decode_6") - x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8") - x_hat_shape = x_hat.get_shape().as_list() - pred_height = x_hat_shape[1] - pred_width = x_hat_shape[2] - assert pred_height == input_height - assert pred_width == input_width - return x_hat, z_mu, z_log_sigma_sq, z - - def vae_arc_all(self): - """ - Build architecture for all the sequences - """ - X = [] - z_log_sigma_sq_all = [] - z_mu_all = [] - for i in range(self.sequence_length-1): - q, z_mu, z_log_sigma_sq, z = VanillaVAEVideoPredictionModel.vae_arc3(self.x[:, i, :, :, :], l_name=i, nz=self.nz) - X.append(q) - z_log_sigma_sq_all.append(z_log_sigma_sq) - z_mu_all.append(z_mu) - x_hat = tf.stack(X, axis = 1) - x_hat_shape = x_hat.get_shape().as_list() - print("x-ha-shape:",x_hat_shape) - z_log_sigma_sq_all = tf.stack(z_log_sigma_sq_all, axis = 1) - z_mu_all = tf.stack(z_mu_all, axis = 1) - return x_hat, z_log_sigma_sq_all, z_mu_all - - - diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..033bd499ef5717391d9fab8cced03b268135b4c9 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# +# SPDX-License-Identifier: MIT +# Weather Bench models +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong" +__date__ = "2021-04-13" + +import tensorflow as tf +from model_modules.video_prediction.layers import layer_def as ld +from model_modules.video_prediction.losses import l1_loss +from .our_base_model import BaseModels + +class WeatherBenchModel(BaseModels): + + def __init__(self, hparams_dict_config: dict=None, mode:str="train", **kwargs): + """ + This is class for building weatherBench architecture by using updated hparameters + args: + mode :"train" or "val", side note: mode may not be used in the convLSTM, + this will be a useful argument for the GAN-based model + hparams_dict :The dictionary contains the hyper-parameters names and values + """ + super().__init__(hparams_dict_config, mode) + + + def parse_hparams(self, hparams): + """ + Obtain the hyper-parameters from the dict to the class variables + """ + try: + self.context_frames = self.hparams.context_frames + self.max_epochs = self.hparams.max_epochs + self.batch_size = self.hparams.batch_size + self.shuffle_on_val = self.hparams.shuffle_on_val + self.loss_fun = self.hparams.loss_fun + self.learning_rate = self.hparams.lr + self.sequence_length = self.hparams.sequence_length + except Exception as error: + print("error: {}".format(error)) + raise ValueError("Method %{}: the hyper-parameter dictionary " + "must include parameters 
above") + + def get_loss(self, x: tf.Tensor, x_hat: tf.Tensor): + # Loss + total_loss = l1_loss(x[:, 1, :, :, :], x_hat[:, :, :, :]) + return total_loss + + def optimizer(self, total_loss): + return tf.train.AdamOptimizer( + learning_rate = self.learning_rate).minimize(total_loss, + global_step = + self.global_step) + + def build_model(self, x): + """Fully convolutional network""" + x = x[:, 0, :, :, :] + _idx = 0 + filters = [64, 64, 64, 64, 2] + kernels = [5, 5, 5, 5, 5] + + for f, k in zip(filters[:-1], kernels[:-1]): + with tf.variable_scope("conv_layer_"+str(_idx), reuse=tf.AUTO_REUSE): + x = ld.conv_layer(x, kernel_size=k, stride=1, + num_features=f, + idx="conv_layer_"+str(_idx), + activate="leaky_relu") + _idx += 1 + with tf.variable_scope("conv_last_layer", reuse=tf.AUTO_REUSE): + output = ld.conv_layer(x, kernel_size=kernels[-1], + stride=1, num_features=filters[-1], + idx="conv_last_layer", activate="linear") + + return output + + + def forecast(self, x, forecast_time): + x_hat = [] + + for i in range(forecast_time): + if i == 0: + x_pred = self.build_model(x[:, i, :, :, :], filters, kernels) + else: + x_pred = self.build_model(x_pred, filters, kernels) + x_hat.append(x_pred) + + x_hat = tf.stack(x_hat) + x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4]) + return x_hat + + diff --git a/video_prediction_tools/model_modules/video_prediction/ops.py b/video_prediction_tools/model_modules/video_prediction/ops.py deleted file mode 100644 index e33f559ebfcf1e1d286136feb6b8ef95a8410564..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/ops.py +++ /dev/null @@ -1,1102 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import numpy as np -import tensorflow as tf - - -def dense(inputs, units, use_spectral_norm=False, use_bias=True): - with tf.variable_scope('dense'): - input_shape = inputs.get_shape().as_list() - kernel_shape = [input_shape[1], units] - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) - if use_spectral_norm: - kernel = spectral_normed_weight(kernel) - outputs = tf.matmul(inputs, kernel) - if use_bias: - bias = tf.get_variable('bias', [units], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def pad1d(inputs, size, strides=(1,), padding='SAME', mode='CONSTANT'): - size = list(size) if isinstance(size, (tuple, list)) else [size] - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] - input_shape = inputs.get_shape().as_list() - assert len(input_shape) == 3 - in_width = input_shape[1] - if padding in ('SAME', 'FULL'): - if in_width % strides[0] == 0: - pad_along_width = max(size[0] - strides[0], 0) - else: - pad_along_width = max(size[0] - (in_width % strides[0]), 0) - if padding == 'SAME': - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - else: - pad_left = pad_along_width - pad_right = pad_along_width - padding_pattern = [[0, 0], - [pad_left, pad_right], - [0, 0]] - outputs = tf.pad(inputs, padding_pattern, mode=mode) - elif padding == 'VALID': - outputs = inputs - else: - raise ValueError("Invalid padding scheme %s" % padding) - return outputs - - -def conv1d(inputs, filters, kernel_size, strides=(1,), padding='SAME', kernel=None, use_bias=True): - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] - strides = list(strides) if isinstance(strides, (tuple, 
list)) else [strides] - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('conv1d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - if padding == 'FULL': - inputs = pad1d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - stride, = strides - outputs = tf.nn.conv1d(inputs, kernel, stride, padding=padding) - if use_bias: - with tf.variable_scope('conv1d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def pad2d_paddings(inputs, size, strides=(1, 1), rate=(1, 1), padding='SAME'): - """ - Computes the paddings for a 4-D tensor according to the convolution padding algorithm. - - See pad2d. - - Reference: - https://www.tensorflow.org/api_guides/python/nn#convolution - https://www.tensorflow.org/api_docs/python/tf/nn/with_space_to_batch - """ - size = np.array(size) if isinstance(size, (tuple, list)) else np.array([size] * 2) - strides = np.array(strides) if isinstance(strides, (tuple, list)) else np.array([strides] * 2) - rate = np.array(rate) if isinstance(rate, (tuple, list)) else np.array([rate] * 2) - if np.any(strides > 1) and np.any(rate > 1): - raise ValueError("strides > 1 not supported in conjunction with rate > 1") - input_shape = inputs.get_shape().as_list() - assert len(input_shape) == 4 - input_size = np.array(input_shape[1:3]) - if padding in ('SAME', 'FULL'): - if np.any(rate > 1): - # We have two padding contributions. The first is used for converting "SAME" - # to "VALID". The second is required so that the height and width of the - # zero-padded value tensor are multiples of rate. - - # Spatial dimensions of the filters and the upsampled filters in which we - # introduce (rate - 1) zeros between consecutive filter values. - dilated_size = size + (size - 1) * (rate - 1) - pad = dilated_size - 1 - else: - pad = np.where(input_size % strides == 0, - np.maximum(size - strides, 0), - np.maximum(size - (input_size % strides), 0)) - if padding == 'SAME': - # When full_padding_shape is odd, we pad more at end, following the same - # convention as conv2d. - pad_start = pad // 2 - pad_end = pad - pad_start - else: - pad_start = pad - pad_end = pad - if np.any(rate > 1): - # More padding so that rate divides the height and width of the input. - # TODO: not sure if this is correct when padding == 'FULL' - orig_pad_end = pad_end - full_input_size = input_size + pad_start + orig_pad_end - pad_end_extra = (rate - full_input_size % rate) % rate - pad_end = orig_pad_end + pad_end_extra - paddings = [[0, 0], - [pad_start[0], pad_end[0]], - [pad_start[1], pad_end[1]], - [0, 0]] - elif padding == 'VALID': - paddings = [[0, 0]] * 4 - else: - raise ValueError("Invalid padding scheme %s" % padding) - return paddings - - -def pad2d(inputs, size, strides=(1, 1), rate=(1, 1), padding='SAME', mode='CONSTANT'): - """ - Pads a 4-D tensor according to the convolution padding algorithm. 
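    Worked example (editor's note, illustrative only): for rate == 1, the amount of 'SAME'
    padding per spatial dimension follows the rule implemented in pad2d_paddings above,
    with any odd leftover pixel padded at the end:

        import numpy as np

        def same_paddings_demo(in_size, kernel, stride):
            in_size, kernel, stride = map(np.asarray, (in_size, kernel, stride))
            pad = np.where(in_size % stride == 0,
                           np.maximum(kernel - stride, 0),
                           np.maximum(kernel - (in_size % stride), 0))
            pad_start = pad // 2
            return [(int(s), int(e)) for s, e in zip(pad_start, pad - pad_start)]

        print(same_paddings_demo(in_size=(4, 4), kernel=(3, 3), stride=(2, 2)))  # [(0, 1), (0, 1)]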
- - Convolution with a padding scheme - conv2d(..., padding=padding) - is equivalent to zero-padding of the input with such scheme, followed by - convolution with 'VALID' padding - padded = pad2d(..., padding=padding, mode='CONSTANT') - conv2d(padded, ..., padding='VALID') - - Args: - inputs: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - padding: A string, either 'VALID', 'SAME', or 'FULL'. The padding algorithm. - mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive). - - Returns: - A 4-D tensor. - - Reference: - https://www.tensorflow.org/api_guides/python/nn#convolution - """ - paddings = pad2d_paddings(inputs, size, strides=strides, rate=rate, padding=padding) - if paddings == [[0, 0]] * 4: - outputs = inputs - else: - outputs = tf.pad(inputs, paddings, mode=mode) - return outputs - - -def local2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', - kernel=None, flip_filters=False, - use_bias=True, channelwise=False): - """ - 2-D locally connected operation. - - Works similarly to 2-D convolution except that the weights are unshared, that is, a different set of filters is - applied at each different patch of the input. - - Args: - inputs: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - kernel: A 6-D or 7-D tensor of shape - `[in_height, in_width, kernel_size[0], kernel_size[1], in_channels, filters]` or - `[batch, in_height, in_width, kernel_size[0], kernel_size[1], in_channels, filters]`. - - Returns: - A 4-D tensor. - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if strides != [1, 1]: - raise NotImplementedError - if padding == 'FULL': - inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - input_shape = inputs.get_shape().as_list() - if padding == 'SAME': - output_shape = input_shape[:3] + [filters] - elif padding == 'VALID': - output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters] - else: - raise ValueError("Invalid padding scheme %s" % padding) - - if channelwise: - if filters not in (input_shape[-1], 1): - raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise " - "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1])) - kernel_shape = output_shape[1:3] + kernel_size + [filters] - else: - kernel_shape = output_shape[1:3] + kernel_size + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('local2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape): - raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s" - % (tuple(kernel_shape), tuple([input_shape[0]] + kernel_shape), tuple(kernel.get_shape().as_list()))) - - outputs = [] - for i in range(kernel_size[0]): - filter_h_ind = -i-1 if flip_filters else i - if padding == 'VALID': - ii = i - else: - ii = i - (kernel_size[0] // 2) - input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1])) - output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii) - assert 0 <= output_h_slice.start < output_shape[1] - assert 0 < output_h_slice.stop <= output_shape[1] - - for j in 
range(kernel_size[1]): - filter_w_ind = -j-1 if flip_filters else j - if padding == 'VALID': - jj = j - else: - jj = j - (kernel_size[1] // 2) - input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2])) - output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj) - assert 0 <= output_w_slice.start < output_shape[2] - assert 0 < output_w_slice.stop <= output_shape[2] - if channelwise: - inc = inputs[:, input_h_slice, input_w_slice, :] * \ - kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :] - else: - inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] * - kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :, :], axis=-2) - # equivalent to this - # outputs[:, output_h_slice, output_w_slice, :] += inc - paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop], - [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]] - outputs.append(tf.pad(inc, paddings)) - outputs = tf.add_n(outputs) - if use_bias: - with tf.variable_scope('local2d'): - bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def separable_local2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', - vertical_kernel=None, horizontal_kernel=None, flip_filters=False, - use_bias=True, channelwise=False): - """ - 2-D locally connected operation with separable filters. - - Note that, unlike tf.nn.separable_conv2d, this is spatial separability between dimensions 1 and 2. - - Args: - inputs: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - vertical_kernel: A 5-D or 6-D tensor of shape - `[in_height, in_width, kernel_size[0], in_channels, filters]` or - `[batch, in_height, in_width, kernel_size[0], in_channels, filters]`. - horizontal_kernel: A 5-D or 6-D tensor of shape - `[in_height, in_width, kernel_size[1], in_channels, filters]` or - `[batch, in_height, in_width, kernel_size[1], in_channels, filters]`. - - Returns: - A 4-D tensor. 
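# Illustrative sketch (shapes assumed, not taken from the diff): the "separable" in
# separable_local2d is spatial, i.e. each local 2-D filter is the outer product of a
# vertical and a horizontal 1-D factor, cutting the parameters per output position
# from kh*kw to kh + kw (per channel/filter).
import numpy as np
kh, kw = 3, 3
v = np.random.randn(kh).astype(np.float32)   # vertical factor at one output position
h = np.random.randn(kw).astype(np.float32)   # horizontal factor at the same position
full_kernel = np.outer(v, h)                 # the filter an unfactored local2d would store
assert full_kernel.shape == (kh, kw)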
- """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if strides != [1, 1]: - raise NotImplementedError - if padding == 'FULL': - inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - input_shape = inputs.get_shape().as_list() - if padding == 'SAME': - output_shape = input_shape[:3] + [filters] - elif padding == 'VALID': - output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters] - else: - raise ValueError("Invalid padding scheme %s" % padding) - - kernels = [vertical_kernel, horizontal_kernel] - for i, (kernel_type, kernel_length, kernel) in enumerate(zip(['vertical', 'horizontal'], kernel_size, kernels)): - if channelwise: - if filters not in (input_shape[-1], 1): - raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise " - "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1])) - kernel_shape = output_shape[1:3] + [kernel_length, filters] - else: - kernel_shape = output_shape[1:3] + [kernel_length, input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('separable_local2d'): - kernel = tf.get_variable('%s_kernel' % kernel_type, kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - kernels[i] = kernel - else: - if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape): - raise ValueError("Expecting %s kernel with shape %s or %s but instead got kernel with shape %s" - % (kernel_type, - tuple(kernel_shape), tuple([input_shape[0]] +kernel_shape), - tuple(kernel.get_shape().as_list()))) - - outputs = [] - for i in range(kernel_size[0]): - filter_h_ind = -i-1 if flip_filters else i - if padding == 'VALID': - ii = i - else: - ii = i - (kernel_size[0] // 2) - input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1])) - output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii) - assert 0 <= output_h_slice.start < output_shape[1] - assert 0 < output_h_slice.stop <= output_shape[1] - - for j in range(kernel_size[1]): - filter_w_ind = -j-1 if flip_filters else j - if padding == 'VALID': - jj = j - else: - jj = j - (kernel_size[1] // 2) - input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2])) - output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj) - assert 0 <= output_w_slice.start < output_shape[2] - assert 0 < output_w_slice.stop <= output_shape[2] - if channelwise: - inc = inputs[:, input_h_slice, input_w_slice, :] * \ - kernels[0][..., output_h_slice, output_w_slice, filter_h_ind, :] * \ - kernels[1][..., output_h_slice, output_w_slice, filter_w_ind, :] - else: - inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] * - kernels[0][..., output_h_slice, output_w_slice, filter_h_ind, :, :] * - kernels[1][..., output_h_slice, output_w_slice, filter_w_ind, :, :], - axis=-2) - # equivalent to this - # outputs[:, output_h_slice, output_w_slice, :] += inc - paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop], - [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]] - outputs.append(tf.pad(inc, paddings)) - outputs = tf.add_n(outputs) - if use_bias: - with tf.variable_scope('separable_local2d'): - bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, 
initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def kronecker_local2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', - kernels=None, flip_filters=False, use_bias=True, channelwise=False): - """ - 2-D locally connected operation with filters represented as a kronecker product of smaller filters - - Args: - inputs: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - kernel: A list of 6-D or 7-D tensors of shape - `[in_height, in_width, kernel_size[i][0], kernel_size[i][1], in_channels, filters]` or - `[batch, in_height, in_width, kernel_size[i][0], kernel_size[i][1], in_channels, filters]`. - - Returns: - A 4-D tensor. - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if strides != [1, 1]: - raise NotImplementedError - if padding == 'FULL': - inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - input_shape = inputs.get_shape().as_list() - if padding == 'SAME': - output_shape = input_shape[:3] + [filters] - elif padding == 'VALID': - output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters] - else: - raise ValueError("Invalid padding scheme %s" % padding) - - if channelwise: - if filters not in (input_shape[-1], 1): - raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise " - "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1])) - kernel_shape = output_shape[1:3] + kernel_size + [filters] - factor_kernel_shape = output_shape[1:3] + [None, None, filters] - else: - kernel_shape = output_shape[1:3] + kernel_size + [input_shape[-1], filters] - factor_kernel_shape = output_shape[1:3] + [None, None, input_shape[-1], filters] - if kernels is None: - with tf.variable_scope('kronecker_local2d'): - kernels = [tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02))] - filter_h_lengths = [kernel_size[0]] - filter_w_lengths = [kernel_size[1]] - else: - for kernel in kernels: - if not ((len(kernel.shape) == len(factor_kernel_shape) and - all(((k == f) or f is None) for k, f in zip(kernel.get_shape().as_list(), factor_kernel_shape))) or - (len(kernel.shape) == (len(factor_kernel_shape) + 1) and - all(((k == f) or f is None) for k, f in zip(kernel.get_shape().as_list(), [input_shape[0]] +factor_kernel_shape)))): - raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s" - % (tuple(factor_kernel_shape), tuple([input_shape[0]] + factor_kernel_shape), - tuple(kernel.get_shape().as_list()))) - if channelwise: - filter_h_lengths, filter_w_lengths = zip(*[kernel.get_shape().as_list()[-3:-1] for kernel in kernels]) - else: - filter_h_lengths, filter_w_lengths = zip(*[kernel.get_shape().as_list()[-4:-2] for kernel in kernels]) - if [np.prod(filter_h_lengths), np.prod(filter_w_lengths)] != kernel_size: - raise ValueError("Expecting kernel size %s but instead got kernel size %s" - % (tuple(kernel_size), tuple([np.prod(filter_h_lengths), np.prod(filter_w_lengths)]))) - - def get_inds(ind, lengths): - inds = [] - for i in range(len(lengths)): - curr_ind = int(ind) - for j in range(len(lengths) - 1, i, -1): - curr_ind //= lengths[j] - curr_ind %= lengths[i] - inds.append(curr_ind) - return inds - - outputs = [] - for 
i in range(kernel_size[0]): - if padding == 'VALID': - ii = i - else: - ii = i - (kernel_size[0] // 2) - input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1])) - output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii) - assert 0 <= output_h_slice.start < output_shape[1] - assert 0 < output_h_slice.stop <= output_shape[1] - - for j in range(kernel_size[1]): - if padding == 'VALID': - jj = j - else: - jj = j - (kernel_size[1] // 2) - input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2])) - output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj) - assert 0 <= output_w_slice.start < output_shape[2] - assert 0 < output_w_slice.stop <= output_shape[2] - kernel_slice = 1.0 - for filter_h_ind, filter_w_ind, kernel in zip(get_inds(i, filter_h_lengths), get_inds(j, filter_w_lengths), kernels): - if flip_filters: - filter_h_ind = -filter_h_ind-1 - filter_w_ind = -filter_w_ind-1 - if channelwise: - kernel_slice *= kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :] - else: - kernel_slice *= kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :, :] - if channelwise: - inc = inputs[:, input_h_slice, input_w_slice, :] * kernel_slice - else: - inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] * kernel_slice, axis=-2) - # equivalent to this - # outputs[:, output_h_slice, output_w_slice, :] += inc - paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop], - [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]] - outputs.append(tf.pad(inc, paddings)) - outputs = tf.add_n(outputs) - if use_bias: - with tf.variable_scope('kronecker_local2d'): - bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def depthwise_conv2d(inputs, channel_multiplier, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True): - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - input_shape = inputs.get_shape().as_list() - kernel_shape = kernel_size + [input_shape[-1], channel_multiplier] - if kernel is None: - with tf.variable_scope('depthwise_conv2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" - % (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - if padding == 'FULL': - inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - outputs = tf.nn.depthwise_conv2d(inputs, kernel, [1] + strides + [1], padding=padding) - if use_bias: - with tf.variable_scope('depthwise_conv2d'): - bias = tf.get_variable('bias', [input_shape[-1] * channel_multiplier], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def conv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, use_spectral_norm=False): - """ - 2-D convolution. - - Args: - inputs: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. 
- kernel: A 4-D or 5-D tensor of shape - `[kernel_size[0], kernel_size[1], in_channels, filters]` or - `[batch, kernel_size[0], kernel_size[1], in_channels, filters]`. - bias: A 1-D or 2-D tensor of shape - `[filters]` or `[batch, filters]`. - - Returns: - A 4-D tensor. - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('conv2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - if use_spectral_norm: - kernel = spectral_normed_weight(kernel) - else: - if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape): - raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s" - % (tuple(kernel_shape), tuple([input_shape[0]] + kernel_shape), tuple(kernel.get_shape().as_list()))) - if padding == 'FULL': - inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - if kernel.get_shape().ndims == 4: - outputs = tf.nn.conv2d(inputs, kernel, [1] + strides + [1], padding=padding) - else: - def conv2d_single_fn(args): - input_, kernel_ = args - input_ = tf.expand_dims(input_, axis=0) - output = tf.nn.conv2d(input_, kernel_, [1] + strides + [1], padding=padding) - output = tf.squeeze(output, axis=0) - return output - outputs = tf.map_fn(conv2d_single_fn, [inputs, kernel], dtype=tf.float32) - if use_bias: - bias_shape = [filters] - if bias is None: - with tf.variable_scope('conv2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - else: - if bias.get_shape().as_list() not in (bias_shape, [input_shape[0]] + bias_shape): - raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" - % (tuple(bias_shape), tuple(bias.get_shape().as_list()))) - if bias.get_shape().ndims == 1: - outputs = tf.nn.bias_add(outputs, bias) - else: - outputs = tf.add(outputs, bias[:, None, None, :]) - return outputs - - -def deconv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True): - """ - 2-D transposed convolution. - - Notes on padding: - The equivalent of transposed convolution with full padding is a convolution with valid padding, and - the equivalent of transposed convolution with valid padding is a convolution with full padding. 
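# Worked example (numbers assumed) of the output-size rules deconv2d applies below:
#   FULL:  s*(i+1) - k      SAME:  s*i      VALID:  s*(i-1) + k
i, k, s = 8, 3, 2             # input size, kernel size, stride
assert s * (i + 1) - k == 15  # FULL
assert s * i == 16            # SAME
assert s * (i - 1) + k == 17  # VALID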
- - Reference: - http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [filters, input_shape[-1]] - if kernel is None: - with tf.variable_scope('deconv2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - if padding == 'FULL': - output_h, output_w = [s * (i + 1) - k for (i, k, s) in zip(input_shape[1:3], kernel_size, strides)] - elif padding == 'SAME': - output_h, output_w = [s * i for (i, s) in zip(input_shape[1:3], strides)] - elif padding == 'VALID': - output_h, output_w = [s * (i - 1) + k for (i, k, s) in zip(input_shape[1:3], kernel_size, strides)] - else: - raise ValueError("Invalid padding scheme %s" % padding) - output_shape = [input_shape[0], output_h, output_w, filters] - outputs = tf.nn.conv2d_transpose(inputs, kernel, output_shape, [1] + strides + [1], padding=padding) - if use_bias: - with tf.variable_scope('deconv2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def get_bilinear_kernel(strides): - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - strides = np.array(strides) - kernel_size = 2 * strides - strides % 2 - center = strides - (kernel_size % 2 == 1) - 0.5 * (kernel_size % 2 != 1) - vertical_kernel = 1 - abs(np.arange(kernel_size[0]) - center[0]) / strides[0] - horizontal_kernel = 1 - abs(np.arange(kernel_size[1]) - center[1]) / strides[1] - kernel = vertical_kernel[:, None] * horizontal_kernel[None, :] - return kernel - - -def upsample2d(inputs, strides, padding='SAME', upsample_mode='bilinear'): - if upsample_mode == 'bilinear': - single_bilinear_kernel = get_bilinear_kernel(strides).astype(np.float32) - input_shape = inputs.get_shape().as_list() - bilinear_kernel = tf.matrix_diag(tf.tile(tf.constant(single_bilinear_kernel)[..., None], (1, 1, input_shape[-1]))) - outputs = deconv2d(inputs, input_shape[-1], kernel_size=single_bilinear_kernel.shape, - strides=strides, kernel=bilinear_kernel, padding=padding, use_bias=False) - elif upsample_mode == 'nearest': - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - input_shape = inputs.get_shape().as_list() - inputs_tiled = tf.tile(inputs[:, :, None, :, None, :], [1, 1, strides[0], 1, strides[1], 1]) - outputs = tf.reshape(inputs_tiled, [input_shape[0], input_shape[1] * strides[0], - input_shape[2] * strides[1], input_shape[3]]) - else: - raise ValueError("Unknown upsample mode %s" % upsample_mode) - return outputs - - -def upsample2d_v2(inputs, strides, padding='SAME', upsample_mode='bilinear'): - """ - Possibly less computationally efficient but more memory efficent than upsampled2d. 
- """ - if upsample_mode == 'bilinear': - single_kernel = get_bilinear_kernel(strides).astype(np.float32) - elif upsample_mode == 'nearest': - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - single_kernel = np.ones(strides, dtype=np.float32) - else: - raise ValueError("Unknown upsample mode %s" % upsample_mode) - input_shape = inputs.get_shape().as_list() - kernel = tf.constant(single_kernel)[:, :, None, None] - inputs = tf.transpose(inputs, [3, 0, 1, 2])[..., None] - outputs = tf.map_fn(lambda input: deconv2d(input, 1, kernel_size=single_kernel.shape, - strides=strides, kernel=kernel, - padding=padding, use_bias=False), - inputs, parallel_iterations=input_shape[-1]) - outputs = tf.transpose(tf.squeeze(outputs, axis=-1), [1, 2, 3, 0]) - return outputs - - -def upsample_conv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', - kernel=None, use_bias=True, bias=None, upsample_mode='bilinear'): - """ - Upsamples the inputs by a factor using bilinear interpolation and the performs conv2d on the upsampled input. This - function is more computationally and memory efficient than a naive implementation. Unlike a naive implementation - that would upsample the input first, this implementation first convolves the bilinear kernel with the given kernel, - and then performs the convolution (actually a deconv2d) with the combined kernel. As opposed to just using deconv2d - directly, this function is less prone to checkerboard artifacts thanks to the implicit bilinear upsampling. - - Example: - >>> import numpy as np - >>> import tensorflow as tf - >>> from video_prediction.ops import upsample_conv2d, upsample2d, conv2d, pad2d_paddings - >>> inputs_shape = [4, 8, 8, 64] - >>> kernel_size = [3, 3] # for convolution - >>> filters = 32 # for convolution - >>> strides = [2, 2] # for upsampling - >>> inputs = tf.get_variable("inputs", inputs_shape) - >>> kernel = tf.get_variable("kernel", (kernel_size[0], kernel_size[1], inputs_shape[-1], filters)) - >>> bias = tf.get_variable("bias", (filters,)) - >>> outputs = upsample_conv2d(inputs, filters, kernel_size=kernel_size, strides=strides, \ - kernel=kernel, bias=bias) - >>> # upsample with bilinear interpolation - >>> inputs_up = upsample2d(inputs, strides=strides, padding='VALID') - >>> # convolve upsampled input with kernel - >>> outputs_up = conv2d(inputs_up, filters, kernel_size=kernel_size, strides=(1, 1), \ - kernel=kernel, bias=bias, padding='FULL') - >>> # crop appropriately - >>> same_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='SAME') - >>> full_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='FULL') - >>> crop_top = (strides[0] - strides[0] % 2) // 2 + full_paddings[1][1] - same_paddings[1][1] - >>> crop_left = (strides[1] - strides[1] % 2) // 2 + full_paddings[2][1] - same_paddings[2][1] - >>> outputs_up = outputs_up[:, crop_top:crop_top + strides[0] * inputs_shape[1], \ - crop_left:crop_left + strides[1] * inputs_shape[2], :] - >>> sess = tf.Session() - >>> sess.run(tf.global_variables_initializer()) - >>> assert np.allclose(*sess.run([outputs, outputs_up]), atol=1e-5) - - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if padding != 'SAME' or upsample_mode != 'bilinear': - raise NotImplementedError - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - if kernel 
is None: - with tf.variable_scope('upsample_conv2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % - (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - - # convolve bilinear kernel with kernel - single_bilinear_kernel = get_bilinear_kernel(strides).astype(np.float32) - kernel_transposed = tf.transpose(kernel, (0, 1, 3, 2)) - kernel_reshaped = tf.reshape(kernel_transposed, kernel_size + [1, input_shape[-1] * filters]) - kernel_up_reshaped = conv2d(tf.constant(single_bilinear_kernel)[None, :, :, None], input_shape[-1] * filters, - kernel_size=kernel_size, kernel=kernel_reshaped, padding='FULL', use_bias=False) - kernel_up = tf.reshape(kernel_up_reshaped, - kernel_up_reshaped.get_shape().as_list()[1:3] + [filters, input_shape[-1]]) - - # deconvolve with the bilinearly convolved kernel - outputs = deconv2d(inputs, filters, kernel_size=kernel_up.get_shape().as_list()[:2], strides=strides, - kernel=kernel_up, padding='SAME', use_bias=False) - if use_bias: - if bias is None: - with tf.variable_scope('upsample_conv2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - else: - bias_shape = [filters] - if bias_shape != bias.get_shape().as_list(): - raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % - (tuple(bias_shape), tuple(bias.get_shape().as_list()))) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def upsample_conv2d_v2(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', - kernel=None, use_bias=True, bias=None, upsample_mode='bilinear'): - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if padding != 'SAME': - raise NotImplementedError - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('upsample_conv2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % - (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - - inputs_up = upsample2d_v2(inputs, strides=strides, padding='VALID', upsample_mode=upsample_mode) - # convolve upsampled input with kernel - outputs = conv2d(inputs_up, filters, kernel_size=kernel_size, strides=(1, 1), - kernel=kernel, bias=None, padding='FULL', use_bias=False) - # crop appropriately - same_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='SAME') - full_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='FULL') - crop_top = (strides[0] - strides[0] % 2) // 2 + full_paddings[1][1] - same_paddings[1][1] - crop_left = (strides[1] - strides[1] % 2) // 2 + full_paddings[2][1] - same_paddings[2][1] - outputs = outputs[:, crop_top:crop_top + strides[0] * input_shape[1], - crop_left:crop_left + strides[1] * input_shape[2], :] - - if use_bias: - if bias is None: - with tf.variable_scope('upsample_conv2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - else: - bias_shape = 
[filters] - if bias_shape != bias.get_shape().as_list(): - raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % - (tuple(bias_shape), tuple(bias.get_shape().as_list()))) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def conv3d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', use_bias=True, use_spectral_norm=False): - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 3 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 3 - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - with tf.variable_scope('conv3d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) - if use_spectral_norm: - kernel = spectral_normed_weight(kernel) - outputs = tf.nn.conv3d(inputs, kernel, [1] + strides + [1], padding=padding) - if use_bias: - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def pool2d(inputs, pool_size, strides=(1, 1), padding='SAME', pool_mode='avg'): - pool_size = list(pool_size) if isinstance(pool_size, (tuple, list)) else [pool_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if padding == 'FULL': - inputs = pad2d(inputs, pool_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - if pool_mode == 'max': - outputs = tf.nn.max_pool(inputs, [1] + pool_size + [1], [1] + strides + [1], padding=padding) - elif pool_mode == 'avg': - outputs = tf.nn.avg_pool(inputs, [1] + pool_size + [1], [1] + strides + [1], padding=padding) - else: - raise ValueError('Invalid pooling mode:', pool_mode) - return outputs - - -def conv_pool2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, pool_mode='avg'): - """ - Similar optimization as in upsample_conv2d - - Example: - >>> import numpy as np - >>> import tensorflow as tf - >>> from video_prediction.ops import conv_pool2d, conv2d, pool2d - >>> inputs_shape = [4, 16, 16, 32] - >>> kernel_size = [3, 3] # for convolution - >>> filters = 64 # for convolution - >>> strides = [2, 2] # for pooling - >>> inputs = tf.get_variable("inputs", inputs_shape) - >>> kernel = tf.get_variable("kernel", (kernel_size[0], kernel_size[1], inputs_shape[-1], filters)) - >>> bias = tf.get_variable("bias", (filters,)) - >>> outputs = conv_pool2d(inputs, filters, kernel_size=kernel_size, strides=strides, - kernel=kernel, bias=bias, pool_mode='avg') - >>> inputs_conv = conv2d(inputs, filters, kernel_size=kernel_size, strides=(1, 1), - kernel=kernel, bias=bias) - >>> outputs_pool = pool2d(inputs_conv, pool_size=strides, strides=strides, pool_mode='avg') - >>> sess = tf.Session() - >>> sess.run(tf.global_variables_initializer()) - >>> assert np.allclose(*sess.run([outputs, outputs_pool]), atol=1e-5) - - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if padding != 'SAME' or pool_mode != 'avg': - raise NotImplementedError - input_shape = inputs.get_shape().as_list() - if input_shape[1] % strides[0] or input_shape[2] % strides[1]: - raise NotImplementedError("The height and width of the input should be " - "an integer multiple of the respective stride.") - kernel_shape 
= list(kernel_size) + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('conv_pool2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % - (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - - # pool kernel - kernel_reshaped = tf.reshape(kernel, [1] + kernel_size + [input_shape[-1] * filters]) - kernel_pool_reshaped = pool2d(kernel_reshaped, pool_size=strides, padding='FULL', pool_mode='avg') - kernel_pool = tf.reshape(kernel_pool_reshaped, - kernel_pool_reshaped.get_shape().as_list()[1:3] + [input_shape[-1], filters]) - - outputs = conv2d(inputs, filters, kernel_size=kernel_pool.get_shape().as_list()[:2], strides=strides, - kernel=kernel_pool, padding='SAME', use_bias=False) - if use_bias: - if bias is None: - with tf.variable_scope('conv_pool2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - else: - bias_shape = [filters] - if bias_shape != bias.get_shape().as_list(): - raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % - (tuple(bias_shape), tuple(bias.get_shape().as_list()))) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def conv_pool2d_v2(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, pool_mode='avg'): - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if padding != 'SAME' or pool_mode != 'avg': - raise NotImplementedError - input_shape = inputs.get_shape().as_list() - if input_shape[1] % strides[0] or input_shape[2] % strides[1]: - raise NotImplementedError("The height and width of the input should be " - "an integer multiple of the respective stride.") - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('conv_pool2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % - (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - - inputs_conv = conv2d(inputs, filters, kernel_size=kernel_size, strides=(1, 1), - kernel=kernel, bias=None, use_bias=False) - outputs = pool2d(inputs_conv, pool_size=strides, strides=strides, pool_mode='avg') - - if use_bias: - if bias is None: - with tf.variable_scope('conv_pool2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - else: - bias_shape = [filters] - if bias_shape != bias.get_shape().as_list(): - raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % - (tuple(bias_shape), tuple(bias.get_shape().as_list()))) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def lrelu(x, alpha): - """ - Leaky ReLU activation function - - Reference: - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_ops.py - """ - with tf.name_scope("lrelu"): - return tf.maximum(alpha * x, x) - - -def batchnorm(input): - with tf.variable_scope("batchnorm"): - # this block looks like it has 3 inputs on the graph unless we do this 
- input = tf.identity(input) - - channels = input.get_shape()[-1] - offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer()) - scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.truncated_normal_initializer(1.0, 0.02)) - mean, variance = tf.nn.moments(input, axes=list(range(len(input.get_shape()) - 1)), keepdims=False) - variance_epsilon = 1e-5 - normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon) - return normalized - - -def instancenorm(input): - with tf.variable_scope("instancenorm"): - # this block looks like it has 3 inputs on the graph unless we do this - input = tf.identity(input) - - channels = input.get_shape()[-1] - offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer()) - scale = tf.get_variable("scale", [channels], dtype=tf.float32, - initializer=tf.truncated_normal_initializer(1.0, 0.02)) - mean, variance = tf.nn.moments(input, axes=list(range(1, len(input.get_shape()) - 1)), keepdims=True) - variance_epsilon = 1e-5 - normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, - variance_epsilon=variance_epsilon) - return normalized - - -def flatten(input, axis=1, end_axis=-1): - """ - Caffe-style flatten. - - Args: - inputs: An N-D tensor. - axis: The first axis to flatten: all preceding axes are retained in the output. - May be negative to index from the end (e.g., -1 for the last axis). - end_axis: The last axis to flatten: all following axes are retained in the output. - May be negative to index from the end (e.g., the default -1 for the last - axis) - - Returns: - A M-D tensor where M = N - (end_axis - axis) - """ - input_shape = tf.shape(input) - input_rank = tf.shape(input_shape)[0] - if axis < 0: - axis = input_rank + axis - if end_axis < 0: - end_axis = input_rank + end_axis - output_shape = [] - if axis != 0: - output_shape.append(input_shape[:axis]) - output_shape.append([tf.reduce_prod(input_shape[axis:end_axis + 1])]) - if end_axis + 1 != input_rank: - output_shape.append(input_shape[end_axis + 1:]) - output_shape = tf.concat(output_shape, axis=0) - output = tf.reshape(input, output_shape) - return output - - -def tile_concat(values, axis): - """ - Like concat except that first tiles the broadcastable dimensions if necessary - """ - shapes = [value.get_shape() for value in values] - # convert axis to positive form - ndims = shapes[0].ndims - for shape in shapes[1:]: - assert ndims == shape.ndims - if -ndims < axis < 0: - axis += ndims - # remove axis dimension - shapes = [shape.as_list() for shape in shapes] - dims = [shape.pop(axis) for shape in shapes] - shapes = [tf.TensorShape(shape) for shape in shapes] - # compute broadcasted shape - b_shape = shapes[0] - for shape in shapes[1:]: - b_shape = tf.broadcast_static_shape(b_shape, shape) - # add back axis dimension - b_shapes = [b_shape.as_list() for _ in dims] - for b_shape, dim in zip(b_shapes, dims): - b_shape.insert(axis, dim) - # tile values to match broadcasted shape, if necessary - b_values = [] - for value, b_shape in zip(values, b_shapes): - multiples = [] - for dim, b_dim in zip(value.get_shape().as_list(), b_shape): - if dim == b_dim: - multiples.append(1) - else: - assert dim == 1 - multiples.append(b_dim) - if any(multiple != 1 for multiple in multiples): - b_value = tf.tile(value, multiples) - else: - b_value = value - b_values.append(b_value) - return tf.concat(b_values, axis=axis) - - -def 
sigmoid_kl_with_logits(logits, targets): - # broadcasts the same target value across the whole batch - # this is implemented so awkwardly because tensorflow lacks an x log x op - assert isinstance(targets, float) - if targets in [0., 1.]: - entropy = 0. - else: - entropy = - targets * np.log(targets) - (1. - targets) * np.log(1. - targets) - return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits) * targets) - entropy - - -def spectral_normed_weight(W, u=None, num_iters=1): - SPECTRAL_NORMALIZATION_VARIABLES = 'spectral_normalization_variables' - - # Usually num_iters = 1 will be enough - W_shape = W.shape.as_list() - W_reshaped = tf.reshape(W, [-1, W_shape[-1]]) - if u is None: - u = tf.get_variable("u", [1, W_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False) - - def l2normalize(v, eps=1e-12): - return v / (tf.norm(v) + eps) - - def power_iteration(i, u_i, v_i): - v_ip1 = l2normalize(tf.matmul(u_i, tf.transpose(W_reshaped))) - u_ip1 = l2normalize(tf.matmul(v_ip1, W_reshaped)) - return i + 1, u_ip1, v_ip1 - _, u_final, v_final = tf.while_loop( - cond=lambda i, _1, _2: i < num_iters, - body=power_iteration, - loop_vars=(tf.constant(0, dtype=tf.int32), - u, tf.zeros(dtype=tf.float32, shape=[1, W_reshaped.shape.as_list()[0]])) - ) - sigma = tf.squeeze(tf.matmul(tf.matmul(v_final, W_reshaped), tf.transpose(u_final))) - W_bar_reshaped = W_reshaped / sigma - W_bar = tf.reshape(W_bar_reshaped, W_shape) - - if u not in tf.get_collection(SPECTRAL_NORMALIZATION_VARIABLES): - tf.add_to_collection(SPECTRAL_NORMALIZATION_VARIABLES, u) - tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u.assign(u_final)) - return W_bar - - -def get_activation_layer(layer_type): - if layer_type == 'relu': - layer = tf.nn.relu - elif layer_type == 'elu': - layer = tf.nn.elu - else: - raise ValueError('Invalid activation layer %s' % layer_type) - return layer - - -def get_norm_layer(layer_type): - if layer_type == 'batch': - layer = tf.layers.batch_normalization - elif layer_type == 'layer': - layer = tf.contrib.layers.layer_norm - elif layer_type == 'instance': - from model_modules.video_prediction.layers import fused_instance_norm - layer = fused_instance_norm - elif layer_type == 'none': - layer = tf.identity - else: - raise ValueError('Invalid normalization layer %s' % layer_type) - return layer - - -def get_upsample_layer(layer_type): - if layer_type == 'deconv2d': - layer = deconv2d - elif layer_type == 'upsample_conv2d': - layer = upsample_conv2d - elif layer_type == 'upsample_conv2d_v2': - layer = upsample_conv2d_v2 - else: - raise ValueError('Invalid upsampling layer %s' % layer_type) - return layer - - -def get_downsample_layer(layer_type): - if layer_type == 'conv2d': - layer = conv2d - elif layer_type == 'conv_pool2d': - layer = conv_pool2d - elif layer_type == 'conv_pool2d_v2': - layer = conv_pool2d_v2 - else: - raise ValueError('Invalid downsampling layer %s' % layer_type) - return layer diff --git a/video_prediction_tools/model_modules/video_prediction/rnn_ops.py b/video_prediction_tools/model_modules/video_prediction/rnn_ops.py deleted file mode 100644 index 970fae3ec656dd3e677e906602f0d29d6650b2c7..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/rnn_ops.py +++ /dev/null @@ -1,267 +0,0 @@ -# Copyright 2016 The TensorFlow Authors All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
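# NumPy sketch (synthetic weights, assumed shapes) of the power iteration performed by
# spectral_normed_weight in ops.py above: estimate the largest singular value of the
# kernel reshaped to [-1, out_channels] and divide the weights by it.
import numpy as np
rng = np.random.default_rng(0)
W = rng.standard_normal((3 * 3 * 64, 128)).astype(np.float32)  # reshaped conv kernel
u = rng.standard_normal((1, 128)).astype(np.float32)
for _ in range(1):                                  # one iteration is usually enough
    v = u @ W.T
    v /= np.linalg.norm(v) + 1e-12
    u = v @ W
    u /= np.linalg.norm(u) + 1e-12
sigma = (v @ W @ u.T).item()                        # approximate top singular value
W_bar = W / sigma                                   # spectrally normalised weights
# Compare with the exact value: np.linalg.svd(W, compute_uv=False)[0]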
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Convolutional LSTM implementation.""" - -import tensorflow as tf -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import rnn_cell_impl -from tensorflow.python.ops import variable_scope as vs - - -class BasicConv2DLSTMCell(rnn_cell_impl.RNNCell): - """2D Convolutional LSTM cell with (optional) normalization and recurrent dropout. - - The implementation is based on: tf.contrib.rnn.LayerNormBasicLSTMCell. - - It does not allow cell clipping, a projection layer, and does not - use peep-hole connections: it is the basic baseline. - """ - def __init__(self, input_shape, filters, kernel_size, - forget_bias=1.0, activation_fn=math_ops.tanh, - normalizer_fn=None, separate_norms=True, - norm_gain=1.0, norm_shift=0.0, - dropout_keep_prob=1.0, dropout_prob_seed=None, - skip_connection=False, reuse=None): - """Initializes the basic convolutional LSTM cell. - - Args: - input_shape: int tuple, Shape of the input, excluding the batch size. - filters: int, The number of filters of the conv LSTM cell. - kernel_size: int tuple, The kernel size of the conv LSTM cell. - forget_bias: float, The bias added to forget gates (see above). - activation_fn: Activation function of the inner states. - normalizer_fn: If specified, this normalization will be applied before the - internal nonlinearities. - separate_norms: If set to `False`, the normalizer_fn is applied to the - concatenated tensor that follows the convolution, i.e. before splitting - the tensor. This case is slightly faster but it might be functionally - different, depending on the normalizer_fn (it's functionally the same - for instance norm but not for layer norm). Default: `True`. - norm_gain: float, The layer normalization gain initial value. If - `normalizer_fn` is `None`, this argument will be ignored. - norm_shift: float, The layer normalization shift initial value. If - `normalizer_fn` is `None`, this argument will be ignored. - dropout_keep_prob: unit Tensor or float between 0 and 1 representing the - recurrent dropout probability value. If float and 1.0, no dropout will - be applied. - dropout_prob_seed: (optional) integer, the randomness seed. - skip_connection: If set to `True`, concatenate the input to the - output of the conv LSTM. Default: `False`. - reuse: (optional) Python boolean describing whether to reuse variables - in an existing scope. If not `True`, and the existing scope already has - the given variables, an error is raised. 
- """ - super(BasicConv2DLSTMCell, self).__init__(_reuse=reuse) - - self._input_shape = input_shape - self._filters = filters - self._kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - self._forget_bias = forget_bias - self._activation_fn = activation_fn - self._normalizer_fn = normalizer_fn - self._separate_norms = separate_norms - self._g = norm_gain - self._b = norm_shift - self._keep_prob = dropout_keep_prob - self._seed = dropout_prob_seed - self._skip_connection = skip_connection - self._reuse = reuse - - if self._skip_connection: - output_channels = self._filters + self._input_shape[-1] - else: - output_channels = self._filters - cell_size = tensor_shape.TensorShape(self._input_shape[:-1] + [self._filters]) - self._output_size = tensor_shape.TensorShape(self._input_shape[:-1] + [output_channels]) - self._state_size = rnn_cell_impl.LSTMStateTuple(cell_size, self._output_size) - - @property - def output_size(self): - return self._output_size - - @property - def state_size(self): - return self._state_size - - def _norm(self, inputs, scope): - shape = inputs.get_shape()[-1:] - gamma_init = init_ops.constant_initializer(self._g) - beta_init = init_ops.constant_initializer(self._b) - with vs.variable_scope(scope): - # Initialize beta and gamma for use by normalizer. - vs.get_variable("gamma", shape=shape, initializer=gamma_init) - vs.get_variable("beta", shape=shape, initializer=beta_init) - normalized = self._normalizer_fn(inputs, reuse=True, scope=scope) - return normalized - - def _conv2d(self, inputs): - output_filters = 4 * self._filters - input_shape = inputs.get_shape().as_list() - kernel_shape = list(self._kernel_size) + [input_shape[-1], output_filters] - kernel = vs.get_variable("kernel", kernel_shape, dtype=dtypes.float32, - initializer=init_ops.truncated_normal_initializer(stddev=0.02)) - outputs = nn_ops.conv2d(inputs, kernel, [1] * 4, padding='SAME') - if not self._normalizer_fn: - bias = vs.get_variable('bias', [output_filters], dtype=dtypes.float32, - initializer=init_ops.zeros_initializer()) - outputs = nn_ops.bias_add(outputs, bias) - return outputs - - def _dense(self, inputs): - num_units = 4 * self._filters - input_shape = inputs.shape.as_list() - kernel_shape = [input_shape[-1], num_units] - kernel = vs.get_variable("weights", kernel_shape, dtype=dtypes.float32, - initializer=init_ops.truncated_normal_initializer(stddev=0.02)) - outputs = tf.matmul(inputs, kernel) - return outputs - - def call(self, inputs, state): - """2D Convolutional LSTM cell with (optional) normalization and recurrent dropout.""" - c, h = state - tile_concat = isinstance(inputs, (list, tuple)) - if tile_concat: - inputs, inputs_non_spatial = inputs - args = array_ops.concat([inputs, h], -1) - concat = self._conv2d(args) - if tile_concat: - concat = concat + self._dense(inputs_non_spatial)[:, None, None, :] - - if self._normalizer_fn and not self._separate_norms: - concat = self._norm(concat, "input_transform_forget_output") - i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=-1) - if self._normalizer_fn and self._separate_norms: - i = self._norm(i, "input") - j = self._norm(j, "transform") - f = self._norm(f, "forget") - o = self._norm(o, "output") - - g = self._activation_fn(j) - if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: - g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) - - new_c = (c * math_ops.sigmoid(f + self._forget_bias) - + math_ops.sigmoid(i) * g) - if self._normalizer_fn: - new_c = 
self._norm(new_c, "state") - new_h = self._activation_fn(new_c) * math_ops.sigmoid(o) - - if self._skip_connection: - new_h = array_ops.concat([new_h, inputs], axis=-1) - - new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) - return new_h, new_state - - -class Conv2DGRUCell(tf.nn.rnn_cell.RNNCell): - """2D Convolutional GRU cell with (optional) normalization. - - Modified from these: - https://github.com/carlthome/tensorflow-convlstm-cell/blob/master/cell.py - https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn_cell_impl.py - """ - def __init__(self, input_shape, filters, kernel_size, - activation_fn=tf.tanh, - normalizer_fn=None, separate_norms=True, - bias_initializer=None, reuse=None): - super(Conv2DGRUCell, self).__init__(_reuse=reuse) - self._input_shape = input_shape - self._filters = filters - self._kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - self._activation_fn = activation_fn - self._normalizer_fn = normalizer_fn - self._separate_norms = separate_norms - self._bias_initializer = bias_initializer - self._size = tensor_shape.TensorShape(self._input_shape[:-1] + [self._filters]) - - @property - def state_size(self): - return self._size - - @property - def output_size(self): - return self._size - - def _norm(self, inputs, scope, bias_initializer): - shape = inputs.get_shape()[-1:] - gamma_init = init_ops.ones_initializer() - beta_init = bias_initializer - with vs.variable_scope(scope): - # Initialize beta and gamma for use by normalizer. - vs.get_variable("gamma", shape=shape, initializer=gamma_init) - vs.get_variable("beta", shape=shape, initializer=beta_init) - normalized = self._normalizer_fn(inputs, reuse=True, scope=scope) - return normalized - - def _conv2d(self, inputs, output_filters, bias_initializer): - input_shape = inputs.get_shape().as_list() - kernel_shape = list(self._kernel_size) + [input_shape[-1], output_filters] - kernel = vs.get_variable("kernel", kernel_shape, dtype=dtypes.float32, - initializer=init_ops.truncated_normal_initializer(stddev=0.02)) - outputs = nn_ops.conv2d(inputs, kernel, [1] * 4, padding='SAME') - if not self._normalizer_fn: - bias = vs.get_variable('bias', [output_filters], dtype=dtypes.float32, - initializer=bias_initializer) - outputs = nn_ops.bias_add(outputs, bias) - return outputs - - def _dense(self, inputs, num_units): - input_shape = inputs.shape.as_list() - kernel_shape = [input_shape[-1], num_units] - kernel = vs.get_variable("weights", kernel_shape, dtype=dtypes.float32, - initializer=init_ops.truncated_normal_initializer(stddev=0.02)) - outputs = tf.matmul(inputs, kernel) - return outputs - - def call(self, inputs, state): - bias_ones = self._bias_initializer - if self._bias_initializer is None: - bias_ones = init_ops.ones_initializer() - tile_concat = isinstance(inputs, (list, tuple)) - if tile_concat: - inputs, inputs_non_spatial = inputs - with vs.variable_scope('gates'): - inputs = array_ops.concat([inputs, state], axis=-1) - concat = self._conv2d(inputs, 2 * self._filters, bias_ones) - if tile_concat: - concat = concat + self._dense(inputs_non_spatial, concat.shape[-1].value)[:, None, None, :] - if self._normalizer_fn and not self._separate_norms: - concat = self._norm(concat, "reset_update", bias_ones) - r, u = array_ops.split(concat, 2, axis=-1) - if self._normalizer_fn and self._separate_norms: - r = self._norm(r, "reset", bias_ones) - u = self._norm(u, "update", bias_ones) - r, u = math_ops.sigmoid(r), math_ops.sigmoid(u) - - bias_zeros = 
self._bias_initializer - if self._bias_initializer is None: - bias_zeros = init_ops.zeros_initializer() - with vs.variable_scope('candidate'): - inputs = array_ops.concat([inputs, r * state], axis=-1) - candidate = self._conv2d(inputs, self._filters, bias_zeros) - if tile_concat: - candidate = candidate + self._dense(inputs_non_spatial, candidate.shape[-1].value)[:, None, None, :] - if self._normalizer_fn: - candidate = self._norm(candidate, "state", bias_zeros) - - c = self._activation_fn(candidate) - new_h = u * state + (1 - u) * c - return new_h, new_h diff --git a/video_prediction_tools/model_modules/video_prediction/utils/__init__.py b/video_prediction_tools/model_modules/video_prediction/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/video_prediction_tools/model_modules/video_prediction/utils/ffmpeg_gif.py b/video_prediction_tools/model_modules/video_prediction/utils/ffmpeg_gif.py deleted file mode 100644 index 8aaeea0f09506d56b0c556168b6d9864498f2b69..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/utils/ffmpeg_gif.py +++ /dev/null @@ -1,99 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import os - -import numpy as np - - -def save_gif(gif_fname, images, fps): - """ - To generate a gif from image files, first generate palette from images - and then generate the gif from the images and the palette. - ffmpeg -i input_%02d.jpg -vf palettegen -y palette.png - ffmpeg -i input_%02d.jpg -i palette.png -lavfi paletteuse -y output.gif - - Alternatively, use a filter to map the input images to both the palette - and gif commands, while also passing the palette to the gif command. - ffmpeg -i input_%02d.jpg -filter_complex "[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse" -y output.gif - - To directly pass in numpy images, use rawvideo format and `-i -` option. - """ - from subprocess import Popen, PIPE - head, tail = os.path.split(gif_fname) - if head and not os.path.exists(head): - os.makedirs(head) - h, w, c = images[0].shape - cmd = ['ffmpeg', '-y', - '-f', 'rawvideo', - '-vcodec', 'rawvideo', - '-r', '%.02f' % fps, - '-s', '%dx%d' % (w, h), - '-pix_fmt', {1: 'gray', 3: 'rgb24', 4: 'rgba'}[c], - '-i', '-', - '-filter_complex', '[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse', - '-r', '%.02f' % fps, - '%s' % gif_fname] - proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) - for image in images: - proc.stdin.write(image.tostring()) - out, err = proc.communicate() - if proc.returncode: - err = '\n'.join([' '.join(cmd), err.decode('utf8')]) - raise IOError(err) - del proc - - -def encode_gif(images, fps): - """Encodes numpy images into gif string. - Args: - images: A 5-D `uint8` `np.array` (or a list of 4-D images) of shape - `[batch_size, time, height, width, channels]` where `channels` is 1 or 3. - fps: frames per second of the animation - Returns: - The encoded gif string. - Raises: - IOError: If the ffmpeg command returns an error. 
- """ - from subprocess import Popen, PIPE - h, w, c = images[0].shape - cmd = ['ffmpeg', '-y', - '-f', 'rawvideo', - '-vcodec', 'rawvideo', - '-r', '%.02f' % fps, - '-s', '%dx%d' % (w, h), - '-pix_fmt', {1: 'gray', 3: 'rgb24'}[c], - '-i', '-', - '-filter_complex', '[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse', - '-r', '%.02f' % fps, - '-f', 'gif', - '-'] - proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) - for image in images: - proc.stdin.write(image.tostring()) - out, err = proc.communicate() - if proc.returncode: - err = '\n'.join([' '.join(cmd), err.decode('utf8')]) - raise IOError(err) - del proc - return out - - -def main(): - images_shape = (12, 64, 64, 3) # num_frames, height, width, channels - images = np.random.randint(256, size=images_shape).astype(np.uint8) - - save_gif('output_save.gif', images, 4) - with open('output_save.gif', 'rb') as f: - string_save = f.read() - - string_encode = encode_gif(images, 4) - with open('output_encode.gif', 'wb') as f: - f.write(string_encode) - - print(np.all(string_save == string_encode)) - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/model_modules/video_prediction/utils/gif_summary.py b/video_prediction_tools/model_modules/video_prediction/utils/gif_summary.py deleted file mode 100644 index 7f9ce616951a66dee17c1081a4961b4af1d6a57f..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/utils/gif_summary.py +++ /dev/null @@ -1,114 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Tensor2Tensor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -from tensorflow.python.ops import summary_op_util -#from tensorflow.python.distribute.summary_op_util import skip_summary TODO: IMPORT ERRORS IN juwels -from model_modules.video_prediction.utils import ffmpeg_gif - - -def py_gif_summary(tag, images, max_outputs, fps): - """Outputs a `Summary` protocol buffer with gif animations. - Args: - tag: Name of the summary. - images: A 5-D `uint8` `np.array` of shape `[batch_size, time, height, width, - channels]` where `channels` is 1 or 3. - max_outputs: Max number of batch elements to generate gifs for. - fps: frames per second of the animation - Returns: - The serialized `Summary` protocol buffer. - Raises: - ValueError: If `images` is not a 5-D `uint8` array with 1 or 3 channels. 
- """ - is_bytes = isinstance(tag, bytes) - if is_bytes: - tag = tag.decode("utf-8") - images = np.asarray(images) - if images.dtype != np.uint8: - raise ValueError("Tensor must have dtype uint8 for gif summary.") - if images.ndim != 5: - raise ValueError("Tensor must be 5-D for gif summary.") - batch_size, _, height, width, channels = images.shape - if channels not in (1, 3): - raise ValueError("Tensors must have 1 or 3 channels for gif summary.") - - summ = tf.Summary() - num_outputs = min(batch_size, max_outputs) - for i in range(num_outputs): - image_summ = tf.Summary.Image() - image_summ.height = height - image_summ.width = width - image_summ.colorspace = channels # 1: grayscale, 3: RGB - try: - image_summ.encoded_image_string = ffmpeg_gif.encode_gif(images[i], fps) - except (IOError, OSError) as e: - tf.logging.warning( - "Unable to encode images to a gif string because either ffmpeg is " - "not installed or ffmpeg returned an error: %s. Falling back to an " - "image summary of the first frame in the sequence.", e) - try: - from PIL import Image # pylint: disable=g-import-not-at-top - import io # pylint: disable=g-import-not-at-top - with io.BytesIO() as output: - Image.fromarray(images[i][0]).save(output, "PNG") - image_summ.encoded_image_string = output.getvalue() - except: - tf.logging.warning( - "Gif summaries requires ffmpeg or PIL to be installed: %s", e) - image_summ.encoded_image_string = "".encode('utf-8') if is_bytes else "" - if num_outputs == 1: - summ_tag = "{}/gif".format(tag) - else: - summ_tag = "{}/gif/{}".format(tag, i) - summ.value.add(tag=summ_tag, image=image_summ) - summ_str = summ.SerializeToString() - return summ_str - - -def gif_summary(name, tensor, max_outputs=3, fps=10, collections=None, - family=None): - """Outputs a `Summary` protocol buffer with gif animations. - Args: - name: Name of the summary. - tensor: A 5-D `uint8` `Tensor` of shape `[batch_size, time, height, width, - channels]` where `channels` is 1 or 3. - max_outputs: Max number of batch elements to generate gifs for. - fps: frames per second of the animation - collections: Optional list of tf.GraphKeys. The collections to add the - summary to. Defaults to [tf.GraphKeys.SUMMARIES] - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - Returns: - A scalar `Tensor` of type `string`. The serialized `Summary` protocol - buffer. 
- """ - tensor = tf.convert_to_tensor(tensor) - # if skip_summary(): TODO: skipo summary errors happend in JUEWLS - # return tf.constant("") - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - val = tf.py_func( - py_gif_summary, - [tag, tensor, max_outputs, fps], - tf.string, - stateful=False, - name=scope) - summary_op_util.collect(val, collections, [tf.GraphKeys.SUMMARIES]) - return val diff --git a/video_prediction_tools/model_modules/video_prediction/utils/html.py b/video_prediction_tools/model_modules/video_prediction/utils/html.py deleted file mode 100755 index ee60f60021588c3d3140caf4c0ba4cb339a616bb..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/utils/html.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import os - -import dominate -from dominate.tags import * - - -class HTML: - def __init__(self, web_dir, title, reflesh=0): - self.title = title - self.web_dir = web_dir - self.img_dir = os.path.join(self.web_dir, 'images') - if not os.path.exists(self.web_dir): - os.makedirs(self.web_dir) - if not os.path.exists(self.img_dir): - os.makedirs(self.img_dir) - # print(self.img_dir) - - self.doc = dominate.document(title=title) - if reflesh > 0: - with self.doc.head: - meta(http_equiv="reflesh", content=str(reflesh)) - self.t = None - - def get_image_dir(self): - return self.img_dir - - def add_header1(self, str): - with self.doc: - h1(str) - - def add_header2(self, str): - with self.doc: - h2(str) - - def add_header3(self, str): - with self.doc: - h3(str) - - def add_table(self, border=1): - self.t = table(border=border, style="table-layout: fixed;") - self.doc.add(self.t) - - def add_row(self, txts, colspans=None): - if self.t is None: - self.add_table() - with self.t: - with tr(): - if colspans: - assert len(txts) == len(colspans) - colspans = [dict(colspan=str(colspan)) for colspan in colspans] - else: - colspans = [dict()] * len(txts) - for txt, colspan in zip(txts, colspans): - style = "word-break: break-all;" if len(str(txt)) > 80 else "word-wrap: break-word;" - with td(style=style, halign="center", valign="top", **colspan): - with p(): - if txt is not None: - p(txt) - - def add_images(self, ims, txts, links, colspans=None, height=None, width=400): - image_style = '' - if height is not None: - image_style += "height:%dpx;" % height - if width is not None: - image_style += "width:%dpx;" % width - if self.t is None: - self.add_table() - with self.t: - with tr(): - if colspans: - assert len(txts) == len(colspans) - colspans = [dict(colspan=str(colspan)) for colspan in colspans] - else: - colspans = [dict()] * len(txts) - for im, txt, link, colspan in zip(ims, txts, links, colspans): - with td(style="word-wrap: break-word;", halign="center", valign="top", **colspan): - with p(): - if im is not None and link is not None: - with a(href=os.path.join('images', link)): - img(style=image_style, src=os.path.join('images', im)) - if im is not None and link is not None and txt is not None: - br() - if txt is not None: - p(txt) - - def save(self): - html_file = '%s/index.html' % self.web_dir - f = open(html_file, 'wt') - f.write(self.doc.render()) - f.close() - - -if __name__ == '__main__': - html = HTML('web/', 'test_html') - html.add_header('hello world') - - ims = [] - txts = [] - links = [] - for n in range(4): - ims.append('image_%d.jpg' % n) - txts.append('text_%d' % n) - links.append('image_%d.jpg' % n) - html.add_images(ims, 
txts, links) - html.save() diff --git a/video_prediction_tools/model_modules/video_prediction/utils/mcnet_utils.py b/video_prediction_tools/model_modules/video_prediction/utils/mcnet_utils.py deleted file mode 100644 index e4f2129eadaf5973130c1aa2f1262ab9cf96a036..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/utils/mcnet_utils.py +++ /dev/null @@ -1,157 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# SPDX-FileCopyrightText: 2015 Alec Radford -# -# SPDX-License-Identifier: MIT - -import cv2 -import random -import imageio -import scipy.misc -import numpy as np - - -def transform(image): - return image/127.5 - 1. - - -def inverse_transform(images): - return (images+1.)/2. - - -def save_images(images, size, image_path): - return imsave(inverse_transform(images)*255., size, image_path) - - -def merge(images, size): - h, w = images.shape[1], images.shape[2] - img = np.zeros((h * size[0], w * size[1], 3)) - - for idx, image in enumerate(images): - i = idx % size[1] - j = idx / size[1] - img[j*h:j*h+h, i*w:i*w+w, :] = image - - return img - - -def imsave(images, size, path): - return scipy.misc.imsave(path, merge(images, size)) - - -def get_minibatches_idx(n, minibatch_size, shuffle=False): - """ - Used to shuffle the dataset at each iteration. - """ - idx_list = np.arange(n, dtype="int32") - - if shuffle: - random.shuffle(idx_list) - - minibatches = [] - minibatch_start = 0 - for i in range(n // minibatch_size): - minibatches.append(idx_list[minibatch_start:minibatch_start + minibatch_size]) - minibatch_start += minibatch_size - - if (minibatch_start != n): - # Make a minibatch out of what is left - minibatches.append(idx_list[minibatch_start:]) - - return zip(range(len(minibatches)), minibatches) - - -def draw_frame(img, is_input): - if img.shape[2] == 1: - img = np.repeat(img, [3], axis=2) - if is_input: - img[:2,:,0] = img[:2,:,2] = 0 - img[:,:2,0] = img[:,:2,2] = 0 - img[-2:,:,0] = img[-2:,:,2] = 0 - img[:,-2:,0] = img[:,-2:,2] = 0 - img[:2,:,1] = 255 - img[:,:2,1] = 255 - img[-2:,:,1] = 255 - img[:,-2:,1] = 255 - else: - img[:2,:,0] = img[:2,:,1] = 0 - img[:,:2,0] = img[:,:2,2] = 0 - img[-2:,:,0] = img[-2:,:,1] = 0 - img[:,-2:,0] = img[:,-2:,1] = 0 - img[:2,:,2] = 255 - img[:,:2,2] = 255 - img[-2:,:,2] = 255 - img[:,-2:,2] = 255 - - return img - - -def load_kth_data(f_name, data_path, image_size, K, T): - flip = np.random.binomial(1,.5,1)[0] - tokens = f_name.split() - vid_path = data_path + tokens[0] + "_uncomp.avi" - vid = imageio.get_reader(vid_path,"ffmpeg") - low = int(tokens[1]) - high = np.min([int(tokens[2]),vid.get_length()])-K-T+1 - if low == high: - stidx = 0 - else: - if low >= high: print(vid_path) - stidx = np.random.randint(low=low, high=high) - seq = np.zeros((image_size, image_size, K+T, 1), dtype="float32") - for t in xrange(K+T): - img = cv2.cvtColor(cv2.resize(vid.get_data(stidx+t), - (image_size,image_size)), - cv2.COLOR_RGB2GRAY) - seq[:,:,t] = transform(img[:,:,None]) - - if flip == 1: - seq = seq[:,::-1] - - diff = np.zeros((image_size, image_size, K-1, 1), dtype="float32") - for t in xrange(1,K): - prev = inverse_transform(seq[:,:,t-1]) - next = inverse_transform(seq[:,:,t]) - diff[:,:,t-1] = next.astype("float32")-prev.astype("float32") - - return seq, diff - - -def load_s1m_data(f_name, data_path, trainlist, K, T): - flip = np.random.binomial(1,.5,1)[0] - vid_path = data_path + f_name - img_size = [240,320] - - while True: - try: - vid = 
imageio.get_reader(vid_path,"ffmpeg") - low = 1 - high = vid.get_length()-K-T+1 - if low == high: - stidx = 0 - else: - stidx = np.random.randint(low=low, high=high) - seq = np.zeros((img_size[0], img_size[1], K+T, 3), - dtype="float32") - for t in xrange(K+T): - img = cv2.resize(vid.get_data(stidx+t), - (img_size[1],img_size[0]))[:,:,::-1] - seq[:,:,t] = transform(img) - - if flip == 1:seq = seq[:,::-1] - - diff = np.zeros((img_size[0], img_size[1], K-1, 1), - dtype="float32") - for t in xrange(1,K): - prev = inverse_transform(seq[:,:,t-1])*255 - prev = cv2.cvtColor(prev.astype("uint8"),cv2.COLOR_BGR2GRAY) - next = inverse_transform(seq[:,:,t])*255 - next = cv2.cvtColor(next.astype("uint8"),cv2.COLOR_BGR2GRAY) - diff[:,:,t-1,0] = (next.astype("float32")-prev.astype("float32"))/255. - break - except Exception: - # In case the current video is bad load a random one - rep_idx = np.random.randint(low=0, high=len(trainlist)) - f_name = trainlist[rep_idx] - vid_path = data_path + f_name - - return seq, diff diff --git a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py index 5b63bb77dea1fc60197d0116705d8237718fda05..b2a1534e43012d56677259eba57bb60049bad6aa 100644 --- a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py +++ b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - import itertools import os from collections import OrderedDict @@ -14,518 +10,7 @@ from tensorflow.core.framework import node_def_pb2 from tensorflow.python.framework import device as pydev from tensorflow.python.training import device_setter from tensorflow.python.util import nest -from model_modules.video_prediction.utils import ffmpeg_gif -from model_modules.video_prediction.utils import gif_summary - -IMAGE_SUMMARIES = "image_summaries" -EVAL_SUMMARIES = "eval_summaries" - - -def local_device_setter(num_devices=1, - ps_device_type='cpu', - worker_device='/cpu:0', - ps_ops=None, - ps_strategy=None): - if ps_ops == None: - ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] - - if ps_strategy is None: - ps_strategy = device_setter._RoundRobinStrategy(num_devices) - if not six.callable(ps_strategy): - raise TypeError("ps_strategy must be callable") - - def _local_device_chooser(op): - current_device = pydev.DeviceSpec.from_string(op.device or "") - - node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def - if node_def.op in ps_ops: - ps_device_spec = pydev.DeviceSpec.from_string( - '/{}:{}'.format(ps_device_type, ps_strategy(op))) - - ps_device_spec.merge_from(current_device) - return ps_device_spec.to_string() - else: - worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "") - worker_device_spec.merge_from(current_device) - return worker_device_spec.to_string() - - return _local_device_chooser - - -def replace_read_ops(loss_or_losses, var_list): - """ - Replaces read ops of each variable in `vars` with new read ops obtained - from `read_value()`, thus forcing to read the most up-to-date values of - the variables (which might incur copies across devices). - The graph is seeded from the tensor(s) `loss_or_losses`. 
- """ - # ops between var ops and the loss - ops = set(ge.get_walks_intersection_ops([var.op for var in var_list], loss_or_losses)) - if not ops: # loss_or_losses doesn't depend on any var in var_list, so there is nothiing to replace - return - - # filter out variables that are not involved in computing the loss - var_list = [var for var in var_list if var.op in ops] - - for var in var_list: - output, = var.op.outputs - read_ops = set(output.consumers()) & ops - for read_op in read_ops: - with tf.name_scope('/'.join(read_op.name.split('/')[:-1])): - with tf.device(read_op.device): - read_t, = read_op.outputs - consumer_ops = set(read_t.consumers()) & ops - # consumer_sgv might have multiple inputs, but we only care - # about replacing the input that is read_t - consumer_sgv = ge.sgv(consumer_ops) - consumer_sgv = consumer_sgv.remap_inputs([list(consumer_sgv.inputs).index(read_t)]) - ge.connect(ge.sgv(var.read_value().op), consumer_sgv) - - -def print_loss_info(losses, *tensors): - def get_descendants(tensor, tensors): - descendants = [] - for child in tensor.op.inputs: - if child in tensors: - descendants.append(child) - else: - descendants.extend(get_descendants(child, tensors)) - return descendants - - name_to_tensors = itertools.chain(*[tensor.items() for tensor in tensors]) - tensor_to_names = OrderedDict([(v, k) for k, v in name_to_tensors]) - - print(tf.get_default_graph().get_name_scope()) - for name, (loss, weight) in losses.items(): - print(' %s (%r)' % (name, weight)) - descendant_names = [] - for descendant in set(get_descendants(loss, tensor_to_names.keys())): - descendant_names.append(tensor_to_names[descendant]) - for descendant_name in sorted(descendant_names): - print(' %s' % descendant_name) - - -def with_flat_batch(flat_batch_fn, ndims=4): - def fn(x, *args, **kwargs): - shape = tf.shape(x) - flat_batch_shape = tf.concat([[-1], shape[-(ndims-1):]], axis=0) - flat_batch_shape.set_shape([ndims]) - flat_batch_x = tf.reshape(x, flat_batch_shape) - flat_batch_r = flat_batch_fn(flat_batch_x, *args, **kwargs) - r = nest.map_structure(lambda x: tf.reshape(x, tf.concat([shape[:-(ndims-1)], tf.shape(x)[1:]], axis=0)), - flat_batch_r) - return r - return fn - - -def transpose_batch_time(x): - if isinstance(x, tf.Tensor) and x.shape.ndims >= 2: - return tf.transpose(x, [1, 0] + list(range(2, x.shape.ndims))) - else: - return x - - -def dimension(inputs, axis=0): - shapes = [input_.shape for input_ in nest.flatten(inputs)] - s = tf.TensorShape([None]) - for shape in shapes: - s = s.merge_with(shape[axis:axis + 1]) - dim = s[0].value - return dim - - -def unroll_rnn(cell, inputs, scope=None, use_dynamic_rnn=True): - """Chooses between dynamic_rnn and static_rnn if the leading time dimension is dynamic or not.""" - dim = dimension(inputs, axis=0) - if use_dynamic_rnn or dim is None: - return tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32, - swap_memory=False, time_major=True, scope=scope) - else: - return static_rnn(cell, inputs, scope=scope) - - -def static_rnn(cell, inputs, scope=None): - """Simple version of static_rnn.""" - with tf.variable_scope(scope or "rnn") as varscope: - batch_size = dimension(inputs, axis=1) - state = cell.zero_state(batch_size, tf.float32) - flat_inputs = nest.flatten(inputs) - flat_inputs = list(zip(*[tf.unstack(flat_input, axis=0) for flat_input in flat_inputs])) - flat_outputs = [] - for time, flat_input in enumerate(flat_inputs): - if time > 0: - varscope.reuse_variables() - input_ = nest.pack_sequence_as(inputs, flat_input) - output, state = 
cell(input_, state) - flat_output = nest.flatten(output) - flat_outputs.append(flat_output) - flat_outputs = [tf.stack(flat_output, axis=0) for flat_output in zip(*flat_outputs)] - outputs = nest.pack_sequence_as(output, flat_outputs) - return outputs, state - - -def maybe_pad_or_slice(tensor, desired_length): - length = tensor.shape.as_list()[0] - if length < desired_length: - paddings = [[0, desired_length - length]] + [[0, 0]] * (tensor.shape.ndims - 1) - tensor = tf.pad(tensor, paddings) - elif length > desired_length: - tensor = tensor[:desired_length] - assert tensor.shape.as_list()[0] == desired_length - return tensor - - -def tensor_to_clip(tensor): - if tensor.shape.ndims == 6: - # concatenate last dimension vertically - tensor = tf.concat(tf.unstack(tensor, axis=-1), axis=-3) - if tensor.shape.ndims == 5: - # concatenate batch dimension horizontally - tensor = tf.concat(tf.unstack(tensor, axis=0), axis=2) - if tensor.shape.ndims == 4: - # keep up to the first 3 channels - tensor = tf.image.convert_image_dtype(tensor, dtype=tf.uint8, saturate=True) - else: - raise NotImplementedError - return tensor - - -def tensor_to_image_batch(tensor): - if tensor.shape.ndims == 6: - # concatenate last dimension vertically - tensor= tf.concat(tf.unstack(tensor, axis=-1), axis=-3) - if tensor.shape.ndims == 5: - # concatenate time dimension horizontally - tensor = tf.concat(tf.unstack(tensor, axis=1), axis=2) - if tensor.shape.ndims == 4: - # keep up to the first 3 channels - tensor = tf.image.convert_image_dtype(tensor, dtype=tf.uint8, saturate=True) - else: - raise NotImplementedError - return tensor - - -def _as_name_scope_map(values): - name_scope_to_values = {} - for name, value in values.items(): - name_scope = name.split('/')[0] - name_scope_to_values.setdefault(name_scope, {}) - name_scope_to_values[name_scope][name] = value - return name_scope_to_values - - -def add_image_summaries(outputs, max_outputs=8, collections=None): - if collections is None: - collections = [tf.GraphKeys.SUMMARIES, IMAGE_SUMMARIES] - for name_scope, outputs in _as_name_scope_map(outputs).items(): - with tf.name_scope(name_scope): - for name, output in outputs.items(): - if max_outputs: - output = output[:max_outputs] - output = tensor_to_image_batch(output) - if output.shape[-1] not in (1, 3): - # these are feature maps, so just skip them - continue - tf.summary.image(name, output, collections=collections) - - -def add_gif_summaries(outputs, max_outputs=8, collections=None): - if collections is None: - collections = [tf.GraphKeys.SUMMARIES, IMAGE_SUMMARIES] - for name_scope, outputs in _as_name_scope_map(outputs).items(): - with tf.name_scope(name_scope): - for name, output in outputs.items(): - if max_outputs: - output = output[:max_outputs] - output = tensor_to_clip(output) - if output.shape[-1] not in (1, 3): - # these are feature maps, so just skip them - continue - gif_summary.gif_summary(name, output[None], fps=4, collections=collections) - - -def add_scalar_summaries(losses_or_metrics, collections=None): - for name_scope, losses_or_metrics in _as_name_scope_map(losses_or_metrics).items(): - with tf.name_scope(name_scope): - for name, loss_or_metric in losses_or_metrics.items(): - if isinstance(loss_or_metric, tuple): - loss_or_metric, _ = loss_or_metric - tf.summary.scalar(name, loss_or_metric, collections=collections) - - -def add_summaries(outputs, collections=None): - scalar_outputs = OrderedDict() - image_outputs = OrderedDict() - gif_outputs = OrderedDict() - for name, output in outputs.items(): - if 
not isinstance(output, tf.Tensor): - continue - if output.shape.ndims == 0: - scalar_outputs[name] = output - elif output.shape.ndims == 4: - image_outputs[name] = output - elif output.shape.ndims > 4 and output.shape[4].value in (1, 3): - gif_outputs[name] = output - add_scalar_summaries(scalar_outputs, collections=collections) - add_image_summaries(image_outputs, collections=collections) - add_gif_summaries(gif_outputs, collections=collections) - - -def plot_buf(y): - def _plot_buf(y): - from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas - from matplotlib.figure import Figure - import io - fig = Figure(figsize=(3, 3)) - canvas = FigureCanvas(fig) - ax = fig.add_subplot(111) - ax.plot(y) - ax.grid(axis='y') - fig.tight_layout(pad=0) - buf = io.BytesIO() - fig.savefig(buf, format='png') - buf.seek(0) - return buf.getvalue() - - s = tf.py_func(_plot_buf, [y], tf.string) - return s - - -def add_plot_image_summaries(metrics, collections=None): - if collections is None: - collections = [IMAGE_SUMMARIES] - for name_scope, metrics in _as_name_scope_map(metrics).items(): - with tf.name_scope(name_scope): - for name, metric in metrics.items(): - try: - buf = plot_buf(metric) - except: - continue - image = tf.image.decode_png(buf, channels=4) - image = tf.expand_dims(image, axis=0) - tf.summary.image(name, image, max_outputs=1, collections=collections) - - -def plot_summary(name, x, y, display_name=None, description=None, collections=None): - """ - Hack that uses pr_curve summaries for 2D plots. - - Args: - x: 1-D tensor with values in increasing order. - y: 1-D tensor with static shape. - - Note: tensorboard needs to be modified and compiled from source to disable - default axis range [-0.05, 1.05]. - """ - from tensorboard import summary as summary_lib - x = tf.convert_to_tensor(x) - y = tf.convert_to_tensor(y) - with tf.control_dependencies([ - tf.assert_equal(tf.shape(x), tf.shape(y)), - tf.assert_equal(y.shape.ndims, 1), - ]): - y = tf.identity(y) - num_thresholds = y.shape[0].value - if num_thresholds is None: - raise ValueError('Size of y needs to be statically defined for num_thresholds argument') - summary = summary_lib.pr_curve_raw_data_op( - name, - true_positive_counts=tf.ones(num_thresholds), - false_positive_counts=tf.ones(num_thresholds), - true_negative_counts=tf.ones(num_thresholds), - false_negative_counts=tf.ones(num_thresholds), - precision=y[::-1], - recall=x[::-1], - num_thresholds=num_thresholds, - display_name=display_name, - description=description, - collections=collections) - return summary - - -def add_plot_summaries(metrics, x_offset=0, collections=None): - for name_scope, metrics in _as_name_scope_map(metrics).items(): - with tf.name_scope(name_scope): - for name, metric in metrics.items(): - plot_summary(name, x_offset + tf.range(tf.shape(metric)[0]), metric, collections=collections) - - -def add_plot_and_scalar_summaries(metrics, x_offset=0, collections=None): - for name_scope, metrics in _as_name_scope_map(metrics).items(): - with tf.name_scope(name_scope): - for name, metric in metrics.items(): - tf.summary.scalar(name, tf.reduce_mean(metric), collections=collections) - plot_summary(name, x_offset + tf.range(tf.shape(metric)[0]), metric, collections=collections) - - -def convert_tensor_to_gif_summary(summ): - if isinstance(summ, bytes): - summary_proto = tf.Summary() - summary_proto.ParseFromString(summ) - summ = summary_proto - - summary = tf.Summary() - for value in summ.value: - tag = value.tag - try: - images_arr = 
tf.make_ndarray(value.tensor) - except TypeError: - summary.value.add(tag=tag, image=value.image) - continue - - if len(images_arr.shape) == 5: - images_arr = np.concatenate(list(images_arr), axis=-2) - if len(images_arr.shape) != 4: - raise ValueError('Tensors must be 4-D or 5-D for gif summary.') - channels = images_arr.shape[-1] - if channels < 1 or channels > 4: - raise ValueError('Tensors must have 1, 2, 3, or 4 color channels for gif summary.') - - encoded_image_string = ffmpeg_gif.encode_gif(images_arr, fps=4) - - image = tf.Summary.Image() - image.height = images_arr.shape[-3] - image.width = images_arr.shape[-2] - image.colorspace = channels # 1: grayscale, 2: grayscale + alpha, 3: RGB, 4: RGBA - image.encoded_image_string = encoded_image_string - summary.value.add(tag=tag, image=image) - return summary - - -def compute_averaged_gradients(opt, tower_loss, **kwargs): - tower_gradvars = [] - for loss in tower_loss: - with tf.device(loss.device): - gradvars = opt.compute_gradients(loss, **kwargs) - tower_gradvars.append(gradvars) - - # Now compute global loss and gradients. - gradvars = [] - with tf.name_scope('gradient_averaging'): - all_grads = {} - for grad, var in itertools.chain(*tower_gradvars): - if grad is not None: - all_grads.setdefault(var, []).append(grad) - for var, grads in all_grads.items(): - # Average gradients on the same device as the variables - # to which they apply. - with tf.device(var.device): - if len(grads) == 1: - avg_grad = grads[0] - else: - avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads)) - gradvars.append((avg_grad, var)) - return gradvars - - -# the next 3 function are from tensorpack: -# https://github.com/tensorpack/tensorpack/blob/master/tensorpack/graph_builder/utils.py -def split_grad_list(grad_list): - """ - Args: - grad_list: K x N x 2 - - Returns: - K x N: gradients - K x N: variables - """ - g = [] - v = [] - for tower in grad_list: - g.append([x[0] for x in tower]) - v.append([x[1] for x in tower]) - return g, v - - -def merge_grad_list(all_grads, all_vars): - """ - Args: - all_grads (K x N): gradients - all_vars(K x N): variables - - Return: - K x N x 2: list of list of (grad, var) pairs - """ - return [list(zip(gs, vs)) for gs, vs in zip(all_grads, all_vars)] - - -def allreduce_grads(all_grads, average): - """ - All-reduce average the gradients among K devices. Results are broadcasted to all devices. - - Args: - all_grads (K x N): List of list of gradients. N is the number of variables. - average (bool): average gradients or not. - - Returns: - K x N: same as input, but each grad is replaced by the average over K devices. 
- """ - from tensorflow.contrib import nccl - nr_tower = len(all_grads) - if nr_tower == 1: - return all_grads - new_all_grads = [] # N x K - for grads in zip(*all_grads): - summed = nccl.all_sum(grads) - - grads_for_devices = [] # K - for g in summed: - with tf.device(g.device): - # tensorflow/benchmarks didn't average gradients - if average: - g = tf.multiply(g, 1.0 / nr_tower) - grads_for_devices.append(g) - new_all_grads.append(grads_for_devices) - - # transpose to K x N - ret = list(zip(*new_all_grads)) - return ret - - -def _reduce_entries(*entries): - num_gpus = len(entries) - if entries[0] is None: - assert all(entry is None for entry in entries[1:]) - reduced_entry = None - elif isinstance(entries[0], tf.Tensor): - if entries[0].shape.ndims == 0: - reduced_entry = tf.add_n(entries) / tf.to_float(num_gpus) - else: - reduced_entry = tf.concat(entries, axis=0) - elif np.isscalar(entries[0]) or isinstance(entries[0], np.ndarray): - if np.isscalar(entries[0]) or entries[0].ndim == 0: - reduced_entry = sum(entries) / float(num_gpus) - else: - reduced_entry = np.concatenate(entries, axis=0) - elif isinstance(entries[0], tuple) and len(entries[0]) == 2: - losses, weights = zip(*entries) - loss = tf.add_n(losses) / tf.to_float(num_gpus) - if isinstance(weights[0], tf.Tensor): - with tf.control_dependencies([tf.assert_equal(weight, weights[0]) for weight in weights[1:]]): - weight = tf.identity(weights[0]) - else: - assert all(weight == weights[0] for weight in weights[1:]) - weight = weights[0] - reduced_entry = (loss, weight) - else: - raise NotImplementedError - return reduced_entry - - -def reduce_tensors(structures, shallow=False): - if len(structures) == 1: - reduced_structure = structures[0] - else: - if shallow: - if isinstance(structures[0], dict): - shallow_tree = type(structures[0])([(k, None) for k in structures[0]]) - else: - shallow_tree = type(structures[0])([None for _ in structures[0]]) - reduced_structure = nest.map_structure_up_to(shallow_tree, _reduce_entries, *structures) - else: - reduced_structure = nest.map_structure(_reduce_entries, *structures) - return reduced_structure def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=False, restore_to_checkpoint_mapping=None): @@ -571,46 +56,3 @@ def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=Fal return restore_saver, checkpoint - -def pixel_distribution(pos, height, width): - batch_size = pos.get_shape().as_list()[0] - y, x = tf.unstack(pos, 2, axis=1) - - x0 = tf.cast(tf.floor(x), 'int32') - x1 = x0 + 1 - y0 = tf.cast(tf.floor(y), 'int32') - y1 = y0 + 1 - - Ia = tf.reshape(tf.one_hot(y0 * width + x0, height * width), [batch_size, height, width]) - Ib = tf.reshape(tf.one_hot(y1 * width + x0, height * width), [batch_size, height, width]) - Ic = tf.reshape(tf.one_hot(y0 * width + x1, height * width), [batch_size, height, width]) - Id = tf.reshape(tf.one_hot(y1 * width + x1, height * width), [batch_size, height, width]) - - x0_f = tf.cast(x0, 'float32') - x1_f = tf.cast(x1, 'float32') - y0_f = tf.cast(y0, 'float32') - y1_f = tf.cast(y1, 'float32') - wa = ((x1_f - x) * (y1_f - y))[:, None, None] - wb = ((x1_f - x) * (y - y0_f))[:, None, None] - wc = ((x - x0_f) * (y1_f - y))[:, None, None] - wd = ((x - x0_f) * (y - y0_f))[:, None, None] - - return tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id]) - - -def flow_to_rgb(flows): - """The last axis should have dimension 2, for x and y values.""" - - def cartesian_to_polar(x, y): - magnitude = tf.sqrt(tf.square(x) + tf.square(y)) - 
angle = tf.atan2(y, x) - return magnitude, angle - - mag, ang = cartesian_to_polar(*tf.unstack(flows, axis=-1)) - ang_normalized = (ang + np.pi) / (2 * np.pi) - mag_min = tf.reduce_min(mag) - mag_max = tf.reduce_max(mag) - mag_normalized = (mag - mag_min) / (mag_max - mag_min) - hsv = tf.stack([ang_normalized, tf.ones_like(ang), mag_normalized], axis=-1) - rgb = tf.image.hsv_to_rgb(hsv) - return rgb diff --git a/video_prediction_tools/utils/runscript_generator/config_postprocess.py b/video_prediction_tools/utils/runscript_generator/config_postprocess.py index 3a258272787a4a7e3b0b32d11e5c3bf4f8a3973a..01817ac326ed92c43a1cfc43d8925ef1680cccc0 100755 --- a/video_prediction_tools/utils/runscript_generator/config_postprocess.py +++ b/video_prediction_tools/utils/runscript_generator/config_postprocess.py @@ -11,7 +11,7 @@ __date__ = "2021-02-01" # import modules import os, glob from model_modules.model_architectures import known_models -from data_preprocess.dataset_options import known_datasets +from data_extraction.dataset_options import known_datasets from runscript_generator.config_utils import Config_runscript_base # import parent class class Config_Postprocess(Config_runscript_base):
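
Editor's note: the hunks above delete `ffmpeg_gif.py`, whose `encode_gif` pipes raw `uint8` frames into ffmpeg and reads the encoded GIF back from stdout. For readers who still need that behaviour after the removal, the following is a minimal standalone sketch of the same rawvideo-pipe / palettegen approach. It is not part of the repository: the function name is illustrative, and it only assumes an `ffmpeg` binary is available on `PATH`.

```python
import subprocess

import numpy as np


def encode_gif_sketch(frames: np.ndarray, fps: int = 4) -> bytes:
    """Encode a (time, height, width, channels) uint8 array into GIF bytes via ffmpeg."""
    if frames.dtype != np.uint8 or frames.ndim != 4:
        raise ValueError("frames must be a 4-D uint8 array (time, height, width, channels)")
    _, h, w, c = frames.shape
    pix_fmt = {1: "gray", 3: "rgb24", 4: "rgba"}[c]
    cmd = [
        "ffmpeg", "-y",
        "-f", "rawvideo", "-vcodec", "rawvideo",   # raw frames arrive on stdin
        "-r", f"{fps:.02f}", "-s", f"{w}x{h}", "-pix_fmt", pix_fmt,
        "-i", "-",
        # single pass: split the stream, build a palette, then apply it
        "-filter_complex", "[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse",
        "-r", f"{fps:.02f}", "-f", "gif", "-",     # write the GIF to stdout
    ]
    proc = subprocess.run(cmd, input=frames.tobytes(), capture_output=True)
    if proc.returncode:
        raise IOError(proc.stderr.decode("utf-8", errors="replace"))
    return proc.stdout


if __name__ == "__main__":
    # quick smoke test with random frames, mirroring the deleted module's main()
    demo = np.random.randint(0, 256, size=(12, 64, 64, 3), dtype=np.uint8)
    with open("demo.gif", "wb") as f:
        f.write(encode_gif_sketch(demo, fps=4))
```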
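Similarly, the trimmed `tf_utils.py` drops `flow_to_rgb`, which visualises optical flow by encoding flow direction as hue and normalised magnitude as value. A framework-free sketch of the same mapping is given below for reference; it assumes NumPy and matplotlib are available, the function name is hypothetical, and a small epsilon is added to guard against zero-magnitude flow (the deleted TF version had no such guard).

```python
import numpy as np
from matplotlib.colors import hsv_to_rgb


def flow_to_rgb_sketch(flow: np.ndarray) -> np.ndarray:
    """Map a (..., 2) array of (x, y) flow vectors to RGB via an HSV encoding."""
    x, y = flow[..., 0], flow[..., 1]
    magnitude = np.sqrt(x ** 2 + y ** 2)
    angle = np.arctan2(y, x)                      # direction in (-pi, pi]
    hue = (angle + np.pi) / (2 * np.pi)           # normalise angle to [0, 1]
    mag_range = magnitude.max() - magnitude.min()
    value = (magnitude - magnitude.min()) / (mag_range + 1e-12)
    hsv = np.stack([hue, np.ones_like(hue), value], axis=-1)
    return hsv_to_rgb(hsv)                        # saturation fixed at 1, as in the removed code
```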