diff --git a/video_prediction_tools/HPC_scripts/meta_postprocess_era5_template.sh b/video_prediction_tools/HPC_scripts/meta_postprocess_template.sh similarity index 100% rename from video_prediction_tools/HPC_scripts/meta_postprocess_era5_template.sh rename to video_prediction_tools/HPC_scripts/meta_postprocess_template.sh diff --git a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step2_template.sh b/video_prediction_tools/HPC_scripts/preprocess_data_era5_step2_template.sh deleted file mode 100644 index daa48d352ce6b1eca9c2f76692e68ca3e786273e..0000000000000000000000000000000000000000 --- a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step2_template.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/bin/bash -x -#SBATCH --account=<your_project> -#SBATCH --nodes=1 -#SBATCH --ntasks=13 -##SBATCH --ntasks-per-node=13 -#SBATCH --cpus-per-task=1 -#SBATCH --output=DataPreprocess_era5_step2-out.%j -#SBATCH --error=DataPreprocess_era5_step2-err.%j -#SBATCH --time=04:00:00 -#SBATCH --gres=gpu:0 -#SBATCH --partition=batch -#SBATCH --mail-type=ALL -#SBATCH --mail-user=me@somewhere.com - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# auxiliary variables -WORK_DIR="$(pwd)" -BASE_DIR=$(dirname "$WORK_DIR") -# Name of virtual environment -VIRT_ENV_NAME="my_venv" -# !!! ADAPAT DEPENDING ON USAGE OF CONTAINER !!! -# For container usage, comment in the follwoing lines -# Name of container image (must be available in working directory) -CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif" -WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh" - -# sanity checks -if [[ ! -f ${CONTAINER_IMG} ]]; then - echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." - exit 1 -fi - -if [[ ! -f ${WRAPPER} ]]; then - echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." - exit 1 -fi - -# clean-up modules to avoid conflicts between host and container settings -module purge - -# declare directory-variables which will be modified by config_runscript.py -source_dir=/my/path/to/pkl/files/ -destination_dir=/my/path/to/tfrecords/files - -sequence_length=20 -sequences_per_file=10 -# run Preprocessing (step 2 where Tf-records are generated) -export CUDA_VISIBLE_DEVICES=0 -## One node, single GPU -srun --mpi=pspmix --cpu-bind=none \ - singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \ - python3 ../main_scripts/main_preprocess_data_step2.py -source_dir ${source_dir} -dest_dir ${destination_dir} \ - -sequence_length ${sequence_length} -sequences_per_file ${sequences_per_file} - -# WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above) -# Activate virtual environment if needed (and possible) -#if [ -z ${VIRTUAL_ENV} ]; then -# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then -# echo "Activating virtual environment..." -# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate -# else -# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." 
-# exit 1 -# fi -#fi -# -# Loading modules -#module purge -#source ../env_setup/modules_train.sh -#export CUDA_VISIBLE_DEVICES=0 -# -# srun python3 ../main_scripts/main_preprocess_data_step2.py -source_dir ${source_dir} -dest_dir ${destination_dir} \ -# -sequence_length ${sequence_length} -sequences_per_file ${sequences_per_file} diff --git a/video_prediction_tools/HPC_scripts/preprocess_data_moving_mnist_template.sh b/video_prediction_tools/HPC_scripts/preprocess_data_moving_mnist_template.sh deleted file mode 100644 index f72950255efa181ca95b9b4f13c81efafe1e7733..0000000000000000000000000000000000000000 --- a/video_prediction_tools/HPC_scripts/preprocess_data_moving_mnist_template.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/bin/bash -x -#SBATCH --account=<your_project> -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -##SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=1 -#SBATCH --output=DataPreprocess_moving_mnist-out.%j -#SBATCH --error=DataPreprocess_moving_mnist-err.%j -#SBATCH --time=04:00:00 -#SBATCH --partition=batch -#SBATCH --mail-type=ALL -#SBATCH --mail-user=me@somewhere.com - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# Name of virtual environment -VIRT_ENV_NAME="my_venv" - -# !!! ADAPAT DEPENDING ON USAGE OF CONTAINER !!! -# For container usage, comment in the follwoing lines -# Name of container image (must be available in working directory) -CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif" -WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh" - -# sanity checks -if [[ ! -f ${CONTAINER_IMG} ]]; then - echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." - exit 1 -fi - -if [[ ! -f ${WRAPPER} ]]; then - echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." - exit 1 -fi - -# clean-up modules to avoid conflicts between host and container settings -module purge - -# declare directory-variables which will be modified generate_runscript.py -source_dir=/my/path/to/mnist/raw/data/ -destination_dir=/my/path/to/mnist/tfrecords/ - -# run Preprocessing (step 2 where Tf-records are generated) -# run postprocessing/generation of model results including evaluation metrics -export CUDA_VISIBLE_DEVICES=0 -## One node, single GPU -srun --mpi=pspmix --cpu-bind=none \ - singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \ - python3 ../video_prediction/datasets/moving_mnist.py ${source_dir} ${destination_dir} - -# WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above) -# Activate virtual environment if needed (and possible) -#if [ -z ${VIRTUAL_ENV} ]; then -# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then -# echo "Activating virtual environment..." -# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate -# else -# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." 
-# exit 1 -# fi -#fi -# -# Loading modules -#module purge -#source ../env_setup/modules_train.sh -#export CUDA_VISIBLE_DEVICES=0 -# -# srun python3 .../video_prediction/datasets/moving_mnist.py ${source_dir} ${destination_dir} \ No newline at end of file diff --git a/video_prediction_tools/HPC_scripts/train_model_moving_mnist_template.sh b/video_prediction_tools/HPC_scripts/train_model_moving_mnist_template.sh deleted file mode 100755 index 322d0fac362119032f558232e8161321434d2f2f..0000000000000000000000000000000000000000 --- a/video_prediction_tools/HPC_scripts/train_model_moving_mnist_template.sh +++ /dev/null @@ -1,82 +0,0 @@ -#!/bin/bash -x -#SBATCH --account=<your_project> -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -##SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=1 -#SBATCH --output=train_moving_mnist-out.%j -#SBATCH --error=train_moving_mnist-err.%j -#SBATCH --time=00:20:00 -#SBATCH --gres=gpu:1 -#SBATCH --partition=gpus -#SBATCH --mail-type=ALL -#SBATCH --mail-user=me@somewhere.com - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# auxiliary variables -WORK_DIR=`pwd` -BASE_DIR=$(dirname "$WORK_DIR") -# Name of virtual environment -VIRT_ENV_NAME="my_venv" -# !!! ADAPAT DEPENDING ON USAGE OF CONTAINER !!! -# For container usage, comment in the follwoing lines -# Name of container image (must be available in working directory) -CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif" -WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh" - -# sanity checks -if [[ ! -f ${CONTAINER_IMG} ]]; then - echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." - exit 1 -fi - -if [[ ! -f ${WRAPPER} ]]; then - echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." - exit 1 -fi - -# clean-up modules to avoid conflicts between host and container settings -module purge - -# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py) - -source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist -destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist - -# for choosing the model, convLSTM,savp, mcnet,vae -model=convLSTM -dataset=moving_mnist -model_hparams=../hparams/${dataset}/${model}/model_hparams.json -destination_dir=${destination_dir}/${model}/"$(date +"%Y%m%dT%H%M")_"$USER"" - -# run training in container -export CUDA_VISIBLE_DEVICES=0 -## One node, single GPU -srun --mpi=pspmix --cpu-bind=none \ - singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \ - python ../main_scripts/train.py --input_dir ${source_dir}/tfrecords/ --dataset ${dataset} --model ${model} \ - --model_hparams_dict ${model_hparams} --output_dir "${destination_dir}"/ - -# WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above) -# Activate virtual environment if needed (and possible) -#if [ -z ${VIRTUAL_ENV} ]; then -# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then -# echo "Activating virtual environment..." -# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate -# else -# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." 
-# exit 1 -# fi -#fi -# -# Loading modules -#module purge -#source ../env_setup/modules_train.sh -#export CUDA_VISIBLE_DEVICES=0 -# -# srun python3 ../main_scripts/train.py --input_dir ${source_dir}/tfrecords/ --dataset ${dataset} --model ${model} \ -# --model_hparams_dict ${model_hparams} --output_dir "${destination_dir}"/ \ No newline at end of file diff --git a/video_prediction_tools/HPC_scripts/train_model_era5_template.sh b/video_prediction_tools/HPC_scripts/train_model_template.sh similarity index 100% rename from video_prediction_tools/HPC_scripts/train_model_era5_template.sh rename to video_prediction_tools/HPC_scripts/train_model_template.sh diff --git a/video_prediction_tools/HPC_scripts/train_model_weatherbench_template.sh b/video_prediction_tools/HPC_scripts/train_model_weatherbench_template.sh deleted file mode 100644 index 44ccf018d2896553ad360d5c5dbd0c398b7b54d8..0000000000000000000000000000000000000000 --- a/video_prediction_tools/HPC_scripts/train_model_weatherbench_template.sh +++ /dev/null @@ -1,78 +0,0 @@ -#!/bin/bash -x -#SBATCH --account=<your_project> -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -#SBATCH --output=train_model_era5-out.%j -#SBATCH --error=train_model_era5-err.%j -#SBATCH --time=24:00:00 -#SBATCH --gres=gpu:1 -#SBATCH --partition=some_partition -#SBATCH --mail-type=ALL -#SBATCH --mail-user=me@somewhere.com - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# auxiliary variables -WORK_DIR="$(pwd)" -BASE_DIR=$(dirname "$WORK_DIR") -# Name of virtual environment -VIRT_ENV_NAME="my_venv" -# !!! ADAPAT DEPENDING ON USAGE OF CONTAINER !!! -# For container usage, comment in the follwoing lines -# Name of container image (must be available in working directory) -CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif" -WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh" - -# sanity checks -if [[ ! -f ${CONTAINER_IMG} ]]; then - echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." - exit 1 -fi - -if [[ ! -f ${WRAPPER} ]]; then - echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." - exit 1 -fi - -# clean-up modules to avoid conflicts between host and container settings -module purge - -# declare directory-variables which will be modified by generate_runscript.py -source_dir=/my/path/to/tfrecords/files -destination_dir=/my/model/output/path - -# valid identifiers for model-argument are: convLSTM, savp, mcnet and vae -model=convLSTM -datasplit_dict=${destination_dir}/data_split.json -model_hparams=${destination_dir}/model_hparams.json - -# run training in container -export CUDA_VISIBLE_DEVICES=0 -## One node, single GPU -srun --mpi=pspmix --cpu-bind=none \ - singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \ - python3 "${BASE_DIR}"/main_scripts/main_train_models.py --input_dir ${source_dir} --datasplit_dict ${datasplit_dict} \ - --dataset weatherbench --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/ - -# WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above) -# Activate virtual environment if needed (and possible) -#if [ -z ${VIRTUAL_ENV} ]; then -# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then -# echo "Activating virtual environment..." -# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate -# else -# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." 
-# exit 1 -# fi -#fi -# -# Loading modules -#module purge -#source ../env_setup/modules_train.sh -#export CUDA_VISIBLE_DEVICES=0 -# -# srun python3 "${BASE_DIR}"/main_scripts/main_train_models.py --input_dir ${source_dir} --datasplit_dict ${datasplit_dict} \ -# --dataset era5 --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/ \ No newline at end of file diff --git a/video_prediction_tools/HPC_scripts/visualize_postprocess_moving_mnist_template.sh b/video_prediction_tools/HPC_scripts/visualize_postprocess_moving_mnist_template.sh deleted file mode 100755 index 142193121fb12ea792d0350eac859652512438a1..0000000000000000000000000000000000000000 --- a/video_prediction_tools/HPC_scripts/visualize_postprocess_moving_mnist_template.sh +++ /dev/null @@ -1,80 +0,0 @@ -#!/bin/bash -x -#SBATCH --account=<your_project> -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -##SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=1 -#SBATCH --output=generate_era5-out.%j -#SBATCH --error=generate_era5-err.%j -#SBATCH --time=00:20:00 -#SBATCH --gres=gpu:1 -#SBATCH --partition=develgpus -#SBATCH --mail-type=ALL -#SBATCH --mail-user=me@somewhere.com - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# auxiliary variables -WORK_DIR="$(pwd)" -BASE_DIR=$(dirname "$WORK_DIR") -# Name of virtual environment -VIRT_ENV_NAME="my_venv" -# !!! ADAPAT DEPENDING ON USAGE OF CONTAINER !!! -# For container usage, comment in the follwoing lines -# Name of container image (must be available in working directory) -CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif" -WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh" - -# sanity checks -if [[ ! -f ${CONTAINER_IMG} ]]; then - echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." - exit 1 -fi - -if [[ ! -f ${WRAPPER} ]]; then - echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." - exit 1 -fi - -# clean-up modules to avoid conflicts between host and container settings -module purge - -# declare directory-variables which will be modified by config_runscript.py -source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist -checkpoint_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist -results_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist -# name of model -model=convLSTM - -# run postprocessing/generation of model results including evaluation metrics -export CUDA_VISIBLE_DEVICES=0 -## One node, single GPU -srun --mpi=pspmix --cpu-bind=none \ - singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \ - python3 ../scripts/generate_movingmnist.py --input_dir ${source_dir}/ --dataset_hparams sequence_length=20 \ - --checkpoint ${checkpoint_dir}/${model} --mode test --model ${model} --results_dir ${results_dir}/${model} \ - --batch_size 2 --dataset era5 > generate_era5-out."${SLURM_JOB_ID}" - -# WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above) -# Activate virtual environment if needed (and possible) -#if [ -z ${VIRTUAL_ENV} ]; then -# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then -# echo "Activating virtual environment..." -# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate -# else -# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." 
-# exit 1 -# fi -#fi -# -# Loading modules -#module purge -#source ../env_setup/modules_train.sh -#export CUDA_VISIBLE_DEVICES=0 -# -# srun python3 ../scripts/generate_movingmnist.py --input_dir ${source_dir}/ --dataset_hparams sequence_length=20 \ -# --checkpoint ${checkpoint_dir}/${model} --mode test --model ${model} --results_dir ${results_dir}/${model} \ -# --batch_size 2 --dataset era5 > generate_era5-out."${SLURM_JOB_ID}" \ No newline at end of file diff --git a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh b/video_prediction_tools/HPC_scripts/visualize_postprocess_template.sh similarity index 100% rename from video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh rename to video_prediction_tools/HPC_scripts/visualize_postprocess_template.sh diff --git a/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json b/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json deleted file mode 100644 index bc5f8983a5aa6b0b2ba3d560bc4c2391995794a4..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json +++ /dev/null @@ -1,10 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "max_epochs": 2, - "context_frames": 12 -} - - - diff --git a/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json b/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json deleted file mode 100644 index 0ccf44e6370f765857204317f172c866865b4b35..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "batch_size": 16, - "lr": 0.0002, - "beta1": 0.5, - "beta2": 0.999, - "l1_weight": 100.0, - "l2_weight": 0.0, - "kl_weight": 0.0, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.1, - "vae_gan_feature_cdist_weight": 0.0, - "gan_feature_cdist_weight": 10.0, - "state_weight": 0.0, - "nz": 32, - "max_epochs":2, - "context_frames":12 -} diff --git a/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json b/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json deleted file mode 100644 index 770f9ff516a630ff031b94bb2c8a2b41c1686eec..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.001, - "beta1": 0.9, - "beta2": 0.999, - "l1_weight": 1.0, - "l2_weight": 0.0, - "kl_weight": 1e-05, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.0, - "state_weight": 0.0, - "nz": 32, - "max_epochs":2, - "context_frames":12 -} diff --git a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json deleted file mode 100644 index f36e1c0b44279ad2e4f9e741c7bfade0a5aa0a05..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.0002, - "beta1": 0.5, - "beta2": 0.999, - "l1_weight": 100.0, - "l2_weight": 0.0, - "kl_weight": 0.01, - "video_sn_vae_gan_weight": 0.1, - "video_sn_gan_weight": 0.1, - "vae_gan_feature_cdist_weight": 10.0, - "gan_feature_cdist_weight": 0.0, - "state_weight": 0.0, - "nz": 16, - "max_epochs":4, - "context_frames": 12, - "opt_var": "0", - "decay_steps":[3000,9000], - "end_lr": 0.00000008 -} - - diff --git 
a/video_prediction_tools/hparams/era5/vae/model_hparams_template.json b/video_prediction_tools/hparams/era5/vae/model_hparams_template.json deleted file mode 100644 index 1306627e24bec0888600fb88fcaa937e5f01dbd7..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/era5/vae/model_hparams_template.json +++ /dev/null @@ -1,14 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "nz":16, - "max_epochs":2, - "context_frames":12, - "weight_recon":1, - "loss_fun": "rmse", - "shuffle_on_val": true -} - - - diff --git a/video_prediction_tools/hparams/gzprcp/savp/model_hparams_template.json b/video_prediction_tools/hparams/gzprcp/savp/model_hparams_template.json deleted file mode 100644 index 7c1ab72eea7ad1b341a66a76c4a88d1524450417..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/gzprcp/savp/model_hparams_template.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "batch_size": 16, - "lr": 0.0002, - "beta1": 0.5, - "beta2": 0.999, - "l1_weight": 100.0, - "l2_weight": 0.0, - "kl_weight": 0.01, - "video_sn_vae_gan_weight": 0.1, - "video_sn_gan_weight": 0.1, - "vae_gan_feature_cdist_weight": 10.0, - "gan_feature_cdist_weight": 0.0, - "state_weight": 0.0, - "nz": 16, - "max_epochs":2, - "context_frames":10, - "sequence_length":30 -} - - diff --git a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json deleted file mode 100644 index 2b341c11e9853c974c84a7573a70aead6b985d65..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json +++ /dev/null @@ -1,12 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "max_epochs":20, - "context_frames":10, - "loss_fun":"cross_entropy", - "opt_var": "all" -} - - - diff --git a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json deleted file mode 100644 index 6cda5552d437b9b283a40bdde77eb1d3b3497b36..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json +++ /dev/null @@ -1,11 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "max_epochs":20, - "context_frames":10, - "loss_fun":"cross_entropy" -} - - - diff --git a/video_prediction_tools/hparams/weatherbench/mcnet/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/mcnet/model_hparams_template.json deleted file mode 100644 index bc5f8983a5aa6b0b2ba3d560bc4c2391995794a4..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/weatherbench/mcnet/model_hparams_template.json +++ /dev/null @@ -1,10 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "max_epochs": 2, - "context_frames": 12 -} - - - diff --git a/video_prediction_tools/hparams/weatherbench/ours_gan/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/ours_gan/model_hparams_template.json deleted file mode 100644 index 0ccf44e6370f765857204317f172c866865b4b35..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/weatherbench/ours_gan/model_hparams_template.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "batch_size": 16, - "lr": 0.0002, - "beta1": 0.5, - "beta2": 0.999, - "l1_weight": 100.0, - "l2_weight": 0.0, - "kl_weight": 0.0, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.1, - "vae_gan_feature_cdist_weight": 0.0, - "gan_feature_cdist_weight": 10.0, - "state_weight": 0.0, - "nz": 32, - 
"max_epochs":2, - "context_frames":12 -} diff --git a/video_prediction_tools/hparams/weatherbench/ours_vae_l1/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/ours_vae_l1/model_hparams_template.json deleted file mode 100644 index 770f9ff516a630ff031b94bb2c8a2b41c1686eec..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/weatherbench/ours_vae_l1/model_hparams_template.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.001, - "beta1": 0.9, - "beta2": 0.999, - "l1_weight": 1.0, - "l2_weight": 0.0, - "kl_weight": 1e-05, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.0, - "state_weight": 0.0, - "nz": 32, - "max_epochs":2, - "context_frames":12 -} diff --git a/video_prediction_tools/hparams/weatherbench/savp/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/savp/model_hparams_template.json deleted file mode 100644 index f36e1c0b44279ad2e4f9e741c7bfade0a5aa0a05..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/weatherbench/savp/model_hparams_template.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.0002, - "beta1": 0.5, - "beta2": 0.999, - "l1_weight": 100.0, - "l2_weight": 0.0, - "kl_weight": 0.01, - "video_sn_vae_gan_weight": 0.1, - "video_sn_gan_weight": 0.1, - "vae_gan_feature_cdist_weight": 10.0, - "gan_feature_cdist_weight": 0.0, - "state_weight": 0.0, - "nz": 16, - "max_epochs":4, - "context_frames": 12, - "opt_var": "0", - "decay_steps":[3000,9000], - "end_lr": 0.00000008 -} - - diff --git a/video_prediction_tools/hparams/weatherbench/vae/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/vae/model_hparams_template.json deleted file mode 100644 index 1306627e24bec0888600fb88fcaa937e5f01dbd7..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/weatherbench/vae/model_hparams_template.json +++ /dev/null @@ -1,14 +0,0 @@ - -{ - "batch_size": 10, - "lr": 0.001, - "nz":16, - "max_epochs":2, - "context_frames":12, - "weight_recon":1, - "loss_fun": "rmse", - "shuffle_on_val": true -} - - - diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index cf9fc63f01483cc9d0d9fe6925e10fdc2607cd4a..a93642776beeb7b6a8ad9db0968b9be08594fb4b 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -113,6 +113,7 @@ class TrainModel(object): """ if self.model_hparams_dict: with open(self.model_hparams_dict, 'r') as f: + print("self.model_hparams_dict",self.model_hparams_dict) self.model_hparams_dict_load = json.loads(f.read()) else: raise FileNotFoundError("hparam directory doesn't exist! 
please check {}!".format(self.model_hparams_dict)) @@ -171,7 +172,7 @@ class TrainModel(object): :param mode: "train" used the model graph in train process; "test" for postprocessing step """ VideoPredictionModel = models.get_model_class(self.model) - self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict, mode=mode) + self.video_model = VideoPredictionModel(hparams_dict_config=self.model_hparams_dict, mode=mode) def setup_graph(self): """ @@ -209,7 +210,8 @@ class TrainModel(object): with open(os.path.join(self.output_dir, "dataset_hparams.json"), "w") as f: f.write(json.dumps(dataset.hparams, sort_keys=True, indent=4)) with open(os.path.join(self.output_dir, "model_hparams.json"), "w") as f: - f.write(json.dumps(video_model.hparams, sort_keys=True, indent=4)) + print("video_model.get_hparams",video_model.get_hparams) + f.write(json.dumps(video_model.get_hparams, sort_keys=True, indent=4)) #with open(os.path.join(self.output_dir, "data_dict.json"), "w") as f: # f.write(json.dumps(dataset.data_dict, sort_keys=True, indent=4)) @@ -407,45 +409,24 @@ class TrainModel(object): Fetch variables in the graph, this can be custermized based on models and also the needs of users """ # This is the basic fetch for all the models - fetch_list = ["train_op", "summary_op", "global_step"] + fetch_list = ["train_op", "summary_op", "global_step","total_loss"] # Append fetches depending on model to be trained - if self.video_model.__class__.__name__ == "McNetVideoPredictionModel": - fetch_list = fetch_list + ["L_p", "L_gdl", "L_GAN"] - self.saver_loss = fetch_list[-3] # ML: Is this a reasonable choice? + if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": + self.saver_loss = fetch_list[-1] self.saver_loss_name = "Loss" if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": - fetch_list = fetch_list + ["inputs", "total_loss"] - self.saver_loss = fetch_list[-1] - self.saver_loss_name = "Total loss" - if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": - fetch_list = fetch_list + ["g_losses", "d_losses", "d_loss", "g_loss", ("g_losses", "gen_l1_loss")] - # Add loss that is tracked - self.saver_loss = fetch_list[-1][1] - self.saver_loss_dict = fetch_list[-1][0] - self.saver_loss_name = "Generator L1 loss" - if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": - fetch_list = fetch_list + ["latent_loss", "recon_loss", "total_loss"] - self.saver_loss = fetch_list[-2] - self.saver_loss_name = "Reconstruction loss" - if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel": - fetch_list = fetch_list + ["inputs", "total_loss"] - self.saver_loss = fetch_list[-1] - self.saver_loss_name = "Total loss" - if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": - fetch_list = fetch_list + ["inputs", "total_loss"] + fetch_list = fetch_list + ["inputs"] self.saver_loss = fetch_list[-1] +<<<<<<< HEAD self.saver_loss_name = "Total loss" if self.video_model.__class__.__name__ == "WeatherBenchModel": fetch_list = fetch_list + ["total_loss"] self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss" - else: - raise ("self.saver_loss is not set up for your video model class {}".format(self.video_model.__class__.__name__ )) - - + else: + self.saver_loss = "total_loss" self.fetches = self.generate_fetches(fetch_list) - return self.fetches def create_fetches_for_val(self): @@ -502,27 +483,16 @@ class TrainModel(object): Print the training results /validation results from the 
training step. """ method = TrainModel.print_results.__name__ - train_epoch = step/self.steps_per_epoch print("%{0}: Progress global step {1:d} epoch {2:.1f}".format(method, step + 1, train_epoch)) - if self.video_model.__class__.__name__ == "McNetVideoPredictionModel": - print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"], results["L_p"], - results["L_gdl"],results["L_GAN"])) - elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel" or self.video_model.__class__.__name__ == "WeatherBenchModel": + if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": print ("Total_loss:{}".format(results["total_loss"])) - elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": - print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}, gen_l1_loss: {}" - .format(results["g_losses"], results["d_losses"], results["g_loss"], results["d_loss"], - results["gen_l1_loss"])) - elif self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": - print("Total_loss:{}; latent_losses:{}; reconst_loss:{}" - .format(results["total_loss"], results["latent_loss"], results["recon_loss"])) elif self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": print("Total_loss:{}" .format(results["total_loss"])) else: - print("%{0}: Printing results of model '{1}' is not implemented yet".format(method, self.video_model.__class__.__name__)) - + print("Total_loss:{}" + .format(results["total_loss"])) @staticmethod def plot_train(train_losses, val_losses, loss_name, output_dir): """ diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py index 9ab53f1918b2dd57d43e77ee8f4dd5a7556b2d5a..d2bcc4d0d93533e7306d30a6ff655cb2d676be12 100644 --- a/video_prediction_tools/model_modules/model_architectures.py +++ b/video_prediction_tools/model_modules/model_architectures.py @@ -1,20 +1,12 @@ def known_models(): """ An auxilary function - ours_vae_l1 and ours_gan are from savp papers :return: dictionary of known model architectures """ model_mappings = { - 'ground_truth': 'GroundTruthVideoPredictionModel', - 'savp': 'SAVPVideoPredictionModel', - 'vae': 'VanillaVAEVideoPredictionModel', 'convLSTM': 'VanillaConvLstmVideoPredictionModel', - 'mcnet': 'McNetVideoPredictionModel', 'convLSTM_gan': "ConvLstmGANVideoPredictionModel", - 'ours_vae_l1': 'SAVPVideoPredictionModel', - 'ours_gan': 'SAVPVideoPredictionModel', "weatherBench": "WeatherBenchModel" - 'precrnn_v2': 'PredRNNv2VideoPredictionModel' } return model_mappings diff --git a/video_prediction_tools/model_modules/video_prediction/flow_ops.py b/video_prediction_tools/model_modules/video_prediction/flow_ops.py deleted file mode 100644 index 88114a58f51524fc9dad48a74a201ab4805f3c15..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/flow_ops.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-FileCopyrightText: 2017 Simon Meister -# -# SPDX-License-Identifier: MIT - -import tensorflow as tf - - -def image_warp(im, flow): - """Performs a backward warp of an image using the predicted flow. - - Args: - im: Batch of images. [num_batch, height, width, channels] - flow: Batch of flow vectors. [num_batch, height, width, 2] - Returns: - warped: transformed image of the same shape as the input image. 
- - Implementation taken from here: https://github.com/simonmeister/UnFlow - """ - with tf.variable_scope('image_warp'): - - num_batch, height, width, channels = tf.unstack(tf.shape(im)) - max_x = tf.cast(width - 1, 'int32') - max_y = tf.cast(height - 1, 'int32') - zero = tf.zeros([], dtype='int32') - - # We have to flatten our tensors to vectorize the interpolation - im_flat = tf.reshape(im, [-1, channels]) - flow_flat = tf.reshape(flow, [-1, 2]) - - # Floor the flow, as the final indices are integers - # The fractional part is used to control the bilinear interpolation. - flow_floor = tf.to_int32(tf.floor(flow_flat)) - bilinear_weights = flow_flat - tf.floor(flow_flat) - - # Construct base indices which are displaced with the flow - pos_x = tf.tile(tf.range(width), [height * num_batch]) - grid_y = tf.tile(tf.expand_dims(tf.range(height), 1), [1, width]) - pos_y = tf.tile(tf.reshape(grid_y, [-1]), [num_batch]) - - x = flow_floor[:, 0] - y = flow_floor[:, 1] - xw = bilinear_weights[:, 0] - yw = bilinear_weights[:, 1] - - # Compute interpolation weights for 4 adjacent pixels - # expand to num_batch * height * width x 1 for broadcasting in add_n below - wa = tf.expand_dims((1 - xw) * (1 - yw), 1) # top left pixel - wb = tf.expand_dims((1 - xw) * yw, 1) # bottom left pixel - wc = tf.expand_dims(xw * (1 - yw), 1) # top right pixel - wd = tf.expand_dims(xw * yw, 1) # bottom right pixel - - x0 = pos_x + x - x1 = x0 + 1 - y0 = pos_y + y - y1 = y0 + 1 - - x0 = tf.clip_by_value(x0, zero, max_x) - x1 = tf.clip_by_value(x1, zero, max_x) - y0 = tf.clip_by_value(y0, zero, max_y) - y1 = tf.clip_by_value(y1, zero, max_y) - - dim1 = width * height - batch_offsets = tf.range(num_batch) * dim1 - base_grid = tf.tile(tf.expand_dims(batch_offsets, 1), [1, dim1]) - base = tf.reshape(base_grid, [-1]) - - base_y0 = base + y0 * width - base_y1 = base + y1 * width - idx_a = base_y0 + x0 - idx_b = base_y1 + x0 - idx_c = base_y0 + x1 - idx_d = base_y1 + x1 - - Ia = tf.gather(im_flat, idx_a) - Ib = tf.gather(im_flat, idx_b) - Ic = tf.gather(im_flat, idx_c) - Id = tf.gather(im_flat, idx_d) - - warped_flat = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id]) - warped = tf.reshape(warped_flat, [num_batch, height, width, channels]) - warped.set_shape(im.shape) - - return warped diff --git a/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py index fd743cb9052605e009c8ed0c31c5722d32238dd7..919b0d5ac64f4f9cc081d8aba681974870e91f3f 100644 --- a/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py +++ b/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py @@ -95,16 +95,8 @@ class BasicConvLSTMCell(ConvRNNCell): #Bing20200930#replace with non-linear convolutional layers #concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True) input_h_con = tf.concat(axis = 3, values = input_h) - concat = conv_layer(input_h_con, self.filter_size, 1, self.num_features*4, "decode_1", activate="sigmoid") - - print("concat1:",concat) - # i = input_gate, j = new_input, f = forget_gate, o = output_gate + concat = conv_layer(input_h_con, self.filter_size, 1, self.num_features*4, "decode_1", activate="sigmoid") i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat) - print("input gate i:",i) - print("new_input j:",j) - print("forget gate:",f) - print("output gate:",o) - new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * 
self._activation(j)) new_h = self._activation(new_c) * tf.nn.sigmoid(o) @@ -113,8 +105,7 @@ class BasicConvLSTMCell(ConvRNNCell): new_state = LSTMStateTuple(new_c, new_h) else: new_state = tf.concat(axis = 3, values = [new_c, new_h]) - print("new h", new_h) - print("new state",new_state) + return new_h, new_state @@ -150,14 +141,11 @@ def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=No matrix = tf.get_variable( "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype = dtype) if len(args) == 1: - print("args[0]:",args[0]) + res = tf.nn.conv2d(args[0], matrix, strides = [1, 1, 1, 1], padding = 'SAME') - print("res1:",res) + else: - print("matrix:",matrix) - print("tf.concat(axis = 3, values = args):",tf.concat(axis = 3, values = args)) res = tf.nn.conv2d(tf.concat(axis = 3, values = args), matrix, strides = [1, 1, 1, 1], padding = 'SAME') - print("res2:",res) if not bias: return res bias_term = tf.get_variable( diff --git a/video_prediction_tools/model_modules/video_prediction/layers/__init__.py b/video_prediction_tools/model_modules/video_prediction/layers/__init__.py index 8530ffd70f0899c8d8e0832d0dcd377b78bbe349..8b137891791fe96927ad78e64b0aad7bded08bdc 100644 --- a/video_prediction_tools/model_modules/video_prediction/layers/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/layers/__init__.py @@ -1 +1 @@ -from .normalization import fused_instance_norm + diff --git a/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py b/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py index 79d7653a5cc55d72abb7ea2fbdb22ca8be4c3b67..6f7c4f38b222afecb2b21d36bedc938b7813399a 100644 --- a/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py +++ b/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py @@ -58,11 +58,11 @@ def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.l def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer() , activate="relu"): - print("conv_layer activation function",activate) + with tf.variable_scope('{0}_conv'.format(idx)) as scope: input_channels = inputs.get_shape()[-1] - weights = _variable_with_weight_decay('weights',shape = [kernel_size, kernel_size, + weights = _variable_with_weight_decay('weights', shape = [kernel_size, kernel_size, input_channels, num_features], stddev = 0.01, wd = weight_decay) biases = _variable_on_gpu('biases', [num_features], initializer) @@ -97,7 +97,7 @@ def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, initial output_shape = tf.stack( [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features]) - print ("output_shape",output_shape) + conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides = [1, stride, stride, 1], padding = 'SAME') conv_biased = tf.nn.bias_add(conv, biases) if activate == "linear": @@ -140,6 +140,7 @@ def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01, else: ip = tf.add(tf.matmul(inputs_processed, weights), biases) return tf.nn.elu(ip, name = str(idx) + '_fc') + def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None): with tf.variable_scope('{0}_bn'.format(idx)) as scope: @@ -159,5 +160,20 @@ def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None): else: return tf.nn.batch_normalization(inputs,pop_mean,pop_var,beta,scale,epsilon) -def 
bn_layers_wrapper(inputs, is_training): - pass + +class batch_norm(object): + def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"): + with tf.variable_scope(name): + self.epsilon = epsilon + self.momentum = momentum + self.name = name + + def __call__(self, x, train=True): + return tf.contrib.layers.batch_norm(x, + decay=self.momentum, + updates_collections=None, + epsilon=self.epsilon, + scale=True, + is_training=train, + scope=self.name) + diff --git a/video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py b/video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py deleted file mode 100644 index 0b2e41ea2ebcef4247cceaa547fc97b73799ccc1..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py +++ /dev/null @@ -1,182 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -import math -import numpy as np -import tensorflow as tf - -from tensorflow.python.framework import ops -from model_modules.video_prediction.utils.mcnet_utils import * - - -def batch_norm(inputs, name, train=True, reuse=False): - return tf.contrib.layers.batch_norm(inputs=inputs,is_training=train, - reuse=reuse,scope=name,scale=True) - - -def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name="conv2d", reuse=False, padding='SAME'): - with tf.variable_scope(name, reuse=reuse): - w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], - initializer=tf.contrib.layers.xavier_initializer()) - conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding=padding) - - biases = tf.get_variable('biases', [output_dim], - initializer=tf.constant_initializer(0.0)) - conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) - - return conv - - -def deconv2d(input_, output_shape,k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, - name="deconv2d", reuse=False, with_w=False, padding='SAME'): - with tf.variable_scope(name, reuse=reuse): - # filter : [height, width, output_channels, in_channels] - w = tf.get_variable('w', [k_h, k_h, output_shape[-1], - input_.get_shape()[-1]], - initializer=tf.contrib.layers.xavier_initializer()) - - try: - deconv = tf.nn.conv2d_transpose(input_, w, - output_shape=output_shape, - strides=[1, d_h, d_w, 1], - padding=padding) - - # Support for verisons of TensorFlow before 0.7.0 - except AttributeError: - deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, - strides=[1, d_h, d_w, 1]) - biases = tf.get_variable('biases', [output_shape[-1]], - initializer=tf.constant_initializer(0.0)) - deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) - - if with_w: - return deconv, w, biases - else: - return deconv - - -def lrelu(x, leak=0.2, name="lrelu"): - with tf.variable_scope(name): - f1 = 0.5 * (1 + leak) - f2 = 0.5 * (1 - leak) - return f1 * x + f2 * abs(x) - - -def relu(x): - return tf.nn.relu(x) - - -def tanh(x): - return tf.nn.tanh(x) - - -def shape2d(a): - """ - a: a int or tuple/list of length 2 - """ - if type(a) == int: - return [a, a] - if isinstance(a, (list, tuple)): - assert len(a) == 2 - return list(a) - raise RuntimeError("Illegal shape: {}".format(a)) - - -def shape4d(a): - # for use with tensorflow - return [1] + shape2d(a) + [1] - - -def UnPooling2x2ZeroFilled(x): - out = tf.concat(axis=3, values=[x, tf.zeros_like(x)]) - out = tf.concat(axis=2, values=[out, tf.zeros_like(out)]) - - sh = x.get_shape().as_list() - if None not in 
sh[1:]: - out_size = [-1, sh[1] * 2, sh[2] * 2, sh[3]] - return tf.reshape(out, out_size) - else: - sh = tf.shape(x) - return tf.reshape(out, [-1, sh[1] * 2, sh[2] * 2, sh[3]]) - - -def MaxPooling(x, shape, stride=None, padding='VALID'): - """ - MaxPooling on images. - :param input: NHWC tensor. - :param shape: int or [h, w] - :param stride: int or [h, w]. default to be shape. - :param padding: 'valid' or 'same'. default to 'valid' - :returns: NHWC tensor. - """ - padding = padding.upper() - shape = shape4d(shape) - if stride is None: - stride = shape - else: - stride = shape4d(stride) - - return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding) - - -#@layer_register() -def FixedUnPooling(x, shape): - """ - Unpool the input with a fixed mat to perform kronecker product with. - :param input: NHWC tensor - :param shape: int or [h, w] - :returns: NHWC tensor - """ - shape = shape2d(shape) - - # a faster implementation for this special case - return UnPooling2x2ZeroFilled(x) - - -def gdl(gen_frames, gt_frames, alpha): - """ - Calculates the sum of GDL losses between the predicted and gt frames. - @param gen_frames: The predicted frames at each scale. - @param gt_frames: The ground truth frames at each scale - @param alpha: The power to which each gradient term is raised. - @return: The GDL loss. - """ - # create filters [-1, 1] and [[1],[-1]] - # for diffing to the left and down respectively. - pos = tf.constant(np.identity(3), dtype=tf.float32) - neg = -1 * pos - # [-1, 1] - filter_x = tf.expand_dims(tf.stack([neg, pos]), 0) - # [[1],[-1]] - filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) - strides = [1, 1, 1, 1] # stride of (1, 1) - padding = 'SAME' - - gen_dx = tf.abs(tf.nn.conv2d(gen_frames, filter_x, strides, padding=padding)) - gen_dy = tf.abs(tf.nn.conv2d(gen_frames, filter_y, strides, padding=padding)) - gt_dx = tf.abs(tf.nn.conv2d(gt_frames, filter_x, strides, padding=padding)) - gt_dy = tf.abs(tf.nn.conv2d(gt_frames, filter_y, strides, padding=padding)) - - grad_diff_x = tf.abs(gt_dx - gen_dx) - grad_diff_y = tf.abs(gt_dy - gen_dy) - - gdl_loss = tf.reduce_mean((grad_diff_x ** alpha + grad_diff_y ** alpha)) - - # condense into one tensor and avg - return gdl_loss - - -def linear(input_, output_size, name, stddev=0.02, bias_start=0.0, - reuse=False, with_w=False): - shape = input_.get_shape().as_list() - - with tf.variable_scope(name, reuse=reuse): - matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, - tf.random_normal_initializer(stddev=stddev)) - bias = tf.get_variable("bias", [output_size], - initializer=tf.constant_initializer(bias_start)) - if with_w: - return tf.matmul(input_, matrix) + bias, matrix, bias - else: - return tf.matmul(input_, matrix) + bias diff --git a/video_prediction_tools/model_modules/video_prediction/layers/normalization.py b/video_prediction_tools/model_modules/video_prediction/layers/normalization.py deleted file mode 100644 index e6f79effdb83ec43224dcf808e8d7e7d55fcdc60..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/layers/normalization.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Contains the normalization layer classes and their functional aliases.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -from tensorflow.contrib.framework.python.ops import variables -from tensorflow.contrib.layers.python.layers import utils -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import variable_scope - - -DATA_FORMAT_NCHW = 'NCHW' -DATA_FORMAT_NHWC = 'NHWC' - - -def fused_instance_norm(inputs, - center=True, - scale=True, - epsilon=1e-6, - activation_fn=None, - param_initializers=None, - reuse=None, - variables_collections=None, - outputs_collections=None, - trainable=True, - data_format=DATA_FORMAT_NHWC, - scope=None): - """Functional interface for the instance normalization layer. - - Reference: https://arxiv.org/abs/1607.08022. - - "Instance Normalization: The Missing Ingredient for Fast Stylization" - Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky - - Args: - inputs: A tensor with 2 or more dimensions, where the first dimension has - `batch_size`. The normalization is over all but the last dimension if - `data_format` is `NHWC` and the second dimension if `data_format` is - `NCHW`. - center: If True, add offset of `beta` to normalized tensor. If False, `beta` - is ignored. - scale: If True, multiply by `gamma`. If False, `gamma` is - not used. When the next layer is linear (also e.g. `nn.relu`), this can be - disabled since the scaling can be done by the next layer. - epsilon: Small float added to variance to avoid dividing by zero. - activation_fn: Activation function, default set to None to skip it and - maintain a linear activation. - param_initializers: Optional initializers for beta, gamma, moving mean and - moving variance. - reuse: Whether or not the layer and its variables should be reused. To be - able to reuse the layer scope must be given. - variables_collections: Optional collections for the variables. - outputs_collections: Collections to add the outputs. - trainable: If `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - data_format: A string. `NHWC` (default) and `NCHW` are supported. - scope: Optional scope for `variable_scope`. - - Returns: - A `Tensor` representing the output of the operation. - - Raises: - ValueError: If `data_format` is neither `NHWC` nor `NCHW`. - ValueError: If the rank of `inputs` is undefined. - ValueError: If rank or channels dimension of `inputs` is undefined. - """ - inputs = ops.convert_to_tensor(inputs) - inputs_shape = inputs.shape - inputs_rank = inputs.shape.ndims - - if inputs_rank is None: - raise ValueError('Inputs %s has undefined rank.' 
% inputs.name) - if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): - raise ValueError('data_format has to be either NCHW or NHWC.') - - with variable_scope.variable_scope( - scope, 'InstanceNorm', [inputs], reuse=reuse) as sc: - if data_format == DATA_FORMAT_NCHW: - reduction_axis = 1 - # For NCHW format, rather than relying on implicit broadcasting, we - # explicitly reshape the params to params_shape_broadcast when computing - # the moments and the batch normalization. - params_shape_broadcast = list( - [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)]) - else: - reduction_axis = inputs_rank - 1 - params_shape_broadcast = None - moments_axes = list(range(inputs_rank)) - del moments_axes[reduction_axis] - del moments_axes[0] - params_shape = inputs_shape[reduction_axis:reduction_axis + 1] - if not params_shape.is_fully_defined(): - raise ValueError('Inputs %s has undefined channels dimension %s.' % ( - inputs.name, params_shape)) - - # Allocate parameters for the beta and gamma of the normalization. - beta, gamma = None, None - dtype = inputs.dtype.base_dtype - if param_initializers is None: - param_initializers = {} - if center: - beta_collections = utils.get_variable_collections( - variables_collections, 'beta') - beta_initializer = param_initializers.get( - 'beta', init_ops.zeros_initializer()) - beta = variables.model_variable('beta', - shape=params_shape, - dtype=dtype, - initializer=beta_initializer, - collections=beta_collections, - trainable=trainable) - if params_shape_broadcast: - beta = array_ops.reshape(beta, params_shape_broadcast) - if scale: - gamma_collections = utils.get_variable_collections( - variables_collections, 'gamma') - gamma_initializer = param_initializers.get( - 'gamma', init_ops.ones_initializer()) - gamma = variables.model_variable('gamma', - shape=params_shape, - dtype=dtype, - initializer=gamma_initializer, - collections=gamma_collections, - trainable=trainable) - if params_shape_broadcast: - gamma = array_ops.reshape(gamma, params_shape_broadcast) - - if data_format == DATA_FORMAT_NHWC: - inputs = array_ops.transpose(inputs, list(range(1, reduction_axis)) + [0, reduction_axis]) - if data_format == DATA_FORMAT_NCHW: - inputs = array_ops.transpose(inputs, list(range(2, inputs_rank)) + [0, reduction_axis]) - hw, n, c = inputs.shape.as_list()[:-2], inputs.shape[-2].value, inputs.shape[-1].value - inputs = array_ops.reshape(inputs, [1] + hw + [n * c]) - if inputs.shape.ndims != 4: - # combine all the spatial dimensions into only two, e.g. 
[D, H, W] -> [DH, W] - if inputs.shape.ndims > 4: - inputs_ndims4_shape = [1, hw[0], -1, n * c] - else: - inputs_ndims4_shape = [1, 1, -1, n * c] - inputs = array_ops.reshape(inputs, inputs_ndims4_shape) - beta = array_ops.reshape(array_ops.tile(beta[None, :], [n, 1]), [-1]) - gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [n, 1]), [-1]) - - outputs, _, _ = nn.fused_batch_norm( - inputs, gamma, beta, epsilon=epsilon, - data_format=DATA_FORMAT_NHWC, name='instancenorm') - - outputs = array_ops.reshape(outputs, hw + [n, c]) - if data_format == DATA_FORMAT_NHWC: - outputs = array_ops.transpose(outputs, [inputs_rank - 2] + list(range(inputs_rank - 2)) + [inputs_rank - 1]) - if data_format == DATA_FORMAT_NCHW: - outputs = array_ops.transpose(outputs, [inputs_rank - 2, inputs_rank - 1] + list(range(inputs_rank - 2))) - - # if data_format == DATA_FORMAT_NHWC: - # inputs = array_ops.transpose(inputs, [0, reduction_axis] + list(range(1, reduction_axis))) - # inputs_nchw_shape = inputs.shape - # inputs = array_ops.reshape(inputs, [1, -1] + inputs_nchw_shape.as_list()[2:]) - # if inputs.shape.ndims != 4: - # # combine all the spatial dimensions into only two, e.g. [D, H, W] -> [DH, W] - # if inputs.shape.ndims > 4: - # inputs_ndims4_shape = inputs.shape.as_list()[:2] + [-1, inputs_nchw_shape.as_list()[-1]] - # else: - # inputs_ndims4_shape = inputs.shape.as_list()[:2] + [1, -1] - # inputs = array_ops.reshape(inputs, inputs_ndims4_shape) - # beta = array_ops.reshape(array_ops.tile(beta[None, :], [inputs_nchw_shape[0].value, 1]), [-1]) - # gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [inputs_nchw_shape[0].value, 1]), [-1]) - # - # outputs, _, _ = nn.fused_batch_norm( - # inputs, gamma, beta, epsilon=epsilon, - # data_format=DATA_FORMAT_NCHW, name='instancenorm') - # - # outputs = array_ops.reshape(outputs, inputs_nchw_shape) - # if data_format == DATA_FORMAT_NHWC: - # outputs = array_ops.transpose(outputs, [0] + list(range(2, inputs_rank)) + [1]) - - if activation_fn is not None: - outputs = activation_fn(outputs) - return utils.collect_named_outputs(outputs_collections, sc.name, outputs) diff --git a/video_prediction_tools/model_modules/video_prediction/losses.py b/video_prediction_tools/model_modules/video_prediction/losses.py index 4ba586d5a51e8ff065c98afebae26981b961f997..bdf4f651fce3896fd01fcf34092cfff9c68a029c 100644 --- a/video_prediction_tools/model_modules/video_prediction/losses.py +++ b/video_prediction_tools/model_modules/video_prediction/losses.py @@ -4,8 +4,6 @@ import tensorflow as tf -from model_modules.video_prediction.ops import sigmoid_kl_with_logits - def l1_loss(pred, target): return tf.reduce_mean(tf.abs(target - pred)) @@ -57,15 +55,3 @@ def gan_loss(logits, labels, gan_loss_type): raise ValueError('Unknown GAN loss type %s' % gan_loss_type) return loss - -def kl_loss(mu, log_sigma_sq, mu2=None, log_sigma2_sq=None): - if mu2 is None and log_sigma2_sq is None: - sigma_sq = tf.exp(log_sigma_sq) - return -0.5 * tf.reduce_mean(tf.reduce_sum(1 + log_sigma_sq - tf.square(mu) - sigma_sq, axis=-1)) - else: - mu1 = mu - log_sigma1_sq = log_sigma_sq - return tf.reduce_mean(tf.reduce_sum( - (log_sigma2_sq - log_sigma1_sq) / 2 - + (tf.exp(log_sigma1_sq) + tf.square(mu1 - mu2)) / (2 * tf.exp(log_sigma2_sq)) - - 1 / 2, axis=-1)) diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py index 
1bb913f18ccc7b9af2053b446584594bc6875c1b..10c5c72e932e38597731a61555f56803d639bde6 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py @@ -2,21 +2,12 @@ # # SPDX-License-Identifier: MIT -from .base_model import BaseVideoPredictionModel -from .base_model import VideoPredictionModel -from .non_trainable_model import NonTrainableVideoPredictionModel -from .non_trainable_model import GroundTruthVideoPredictionModel -from .non_trainable_model import RepeatVideoPredictionModel -from .savp_model import SAVPVideoPredictionModel -from .vanilla_vae_model import VanillaVAEVideoPredictionModel + from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel -from .mcnet_model import McNetVideoPredictionModel from .test_model import TestModelVideoPredictionModel from model_modules.model_architectures import known_models from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel from .weatherBench3DCNN import WeatherBenchModel -#from .vanilla_predrnnv2 import PredRNNv2VideoPredictionModel - def get_model_class(model): model_mappings = known_models() diff --git a/video_prediction_tools/model_modules/video_prediction/models/base_model.py b/video_prediction_tools/model_modules/video_prediction/models/base_model.py deleted file mode 100644 index 8ac05c98732eeca53c0356845faa3c8db59cb527..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/base_model.py +++ /dev/null @@ -1,880 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import functools -import itertools -import os -import re -from collections import OrderedDict -import numpy as np -import tensorflow as tf -print('tensorflow version: {}'.format(tf.__version__)) -print(tf.contrib.training.HParams) -from tensorflow.contrib.training import HParams -from tensorflow.python.util import nest -import model_modules.video_prediction as vp -from model_modules.video_prediction.utils import tf_utils -from model_modules.video_prediction.utils.tf_utils import compute_averaged_gradients, reduce_tensors, local_device_setter, \ - replace_read_ops, print_loss_info, transpose_batch_time, add_gif_summaries, add_scalar_summaries, \ - add_plot_and_scalar_summaries, add_summaries - - -class BaseVideoPredictionModel(object): - def __init__(self, mode='train', hparams_dict=None, hparams=None, - num_gpus=None, eval_num_samples=100, - eval_num_samples_for_diversity=10, eval_parallel_iterations=1): - """ - Base video prediction model. - - Trainable and non-trainable video prediction models can be derived - from this base class. - - Args: - mode: `'train'` or `'test'`. - hparams_dict: a dict of `name=value` pairs, where `name` must be - defined in `self.get_default_hparams()`. - hparams: a string of comma separated list of `name=value` pairs, - where `name` must be defined in `self.get_default_hparams()`. - These values overrides any values in hparams_dict (if any). 
- """ - if mode not in ('train', 'test'): - raise ValueError('mode must be train or test, but %s given' % mode) - self.mode = mode - cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0') - if cuda_visible_devices == '': - max_num_gpus = 0 - else: - max_num_gpus = len(cuda_visible_devices.split(',')) - if num_gpus is None: - num_gpus = max_num_gpus - elif num_gpus > max_num_gpus: - raise ValueError('num_gpus=%d is greater than the number of visible devices %d' % (num_gpus, max_num_gpus)) - self.num_gpus = num_gpus - self.eval_num_samples = eval_num_samples - self.eval_num_samples_for_diversity = eval_num_samples_for_diversity - self.eval_parallel_iterations = eval_parallel_iterations - self.hparams = self.parse_hparams(hparams_dict, hparams) - if self.hparams.context_frames == -1: - raise ValueError('Invalid context_frames %r. It might have to be ' - 'specified.' % self.hparams.context_frames) - if self.hparams.sequence_length == -1: - raise ValueError('Invalid sequence_length %r. It might have to be ' - 'specified.' % self.hparams.sequence_length) - - # should be overriden by descendant class if the model is stochastic - self.deterministic = True - - # member variables that should be set by `self.build_graph` - self.inputs = None - self.gen_images = None - self.outputs = None - self.metrics = None - self.eval_outputs = None - self.eval_metrics = None - self.accum_eval_metrics = None - self.saveable_variables = None - self.post_init_ops = None - # ML 2021-06-23: Do not hide global step in self.saveable_variables - self.global_step = None - - def get_default_hparams_dict(self): - """ - The keys of this dict define valid hyperparameters for instances of - this class. A class inheriting from this one should override this - method if it has a different set of hyperparameters. - - Returns: - A dict with the following hyperparameters. - - context_frames: the number of ground-truth frames to pass in at - start. Must be specified during instantiation. - sequence_length: the number of frames in the video sequence, - including the context frames, so this model predicts - `sequence_length - context_frames` future frames. Must be - specified during instantiation. - repeat: the number of repeat actions (if applicable). 
- opt_var :string: "0","1",..."n", or "all", the target variable to be optimized in the loss function, if "all" means optimize all the variables and channels - - """ - hparams = dict( - context_frames=-1, - sequence_length=-1, - repeat=1, - opt_var="0" - ) - return hparams - - def get_default_hparams(self): - return HParams(**self.get_default_hparams_dict()) - - def parse_hparams(self, hparams_dict, hparams): - parsed_hparams = self.get_default_hparams().override_from_dict(hparams_dict or {}) - if hparams: - if not isinstance(hparams, (list, tuple)): - hparams = [hparams] - for hparam in hparams: - parsed_hparams.parse(hparam) - return parsed_hparams - - def build_graph(self, inputs): - self.inputs = inputs - - def metrics_fn(self, inputs, outputs): - metrics = OrderedDict() - sequence_length = tf.shape(inputs['images'])[0] - context_frames = self.hparams.context_frames - future_length = sequence_length - context_frames - # target_images and pred_images include only the future frames - target_images = inputs['images'][-future_length:] - pred_images = outputs['gen_images'][-future_length:] - metric_fns = [ - ('psnr', vp.metrics.psnr), - ('mse', vp.metrics.mse), - ('ssim', vp.metrics.ssim), - #('lpips', vp.metrics.lpips), #bing : remove lpips metric course the url fetching issue - ] - for metric_name, metric_fn in metric_fns: - metrics[metric_name] = tf.reduce_mean(metric_fn(target_images, pred_images)) - return metrics - - def eval_outputs_and_metrics_fn(self, inputs, outputs, num_samples=None, - num_samples_for_diversity=None, parallel_iterations=None): - num_samples = num_samples or self.eval_num_samples - num_samples_for_diversity = num_samples_for_diversity or self.eval_num_samples_for_diversity - parallel_iterations = parallel_iterations or self.eval_parallel_iterations - - sequence_length, batch_size = inputs['images'].shape[:2].as_list() - if batch_size is None: - batch_size = tf.shape(inputs['images'])[1] - if sequence_length is None: - sequence_length = tf.shape(inputs['images'])[0] - context_frames = self.hparams.context_frames - future_length = sequence_length - context_frames - # the outputs include all the frames, whereas the metrics include only the future frames - eval_outputs = OrderedDict() - eval_metrics = OrderedDict() - metric_fns = [ - ('psnr', vp.metrics.psnr), - ('mse', vp.metrics.mse), - ('ssim', vp.metrics.ssim), - # ('lpips', vp.metrics.lpips), #bing - ] - # images and gen_images include all the frames - images = inputs['images'] - gen_images = outputs['gen_images'] - # target_images and pred_images include only the future frames - target_images = inputs['images'][-future_length:] - pred_images = outputs['gen_images'][-future_length:] - # ground truth is the same for deterministic and stochastic models - eval_outputs['eval_images'] = images - if self.deterministic: - for metric_name, metric_fn in metric_fns: - metric = metric_fn(target_images, pred_images) - eval_metrics['eval_%s/min' % metric_name] = metric - eval_metrics['eval_%s/avg' % metric_name] = metric - eval_metrics['eval_%s/max' % metric_name] = metric - eval_outputs['eval_gen_images'] = gen_images - else: - def where_axis1(cond, x, y): - return transpose_batch_time(tf.where(cond, transpose_batch_time(x), transpose_batch_time(y))) - - def sort_criterion(x): - return tf.reduce_mean(x, axis=0) - - def accum_gen_images_and_metrics_fn(a, unused): - with tf.variable_scope(self.generator_scope, reuse=True): - outputs_sample = self.generator_fn(inputs) - gen_images_sample = outputs_sample['gen_images'] - 
pred_images_sample = gen_images_sample[-future_length:] - # set the posisbly static shape since it might not have been inferred correctly - pred_images_sample = tf.reshape(pred_images_sample, tf.shape(a['eval_pred_images_last'])) - for name, metric_fn in metric_fns: - metric = metric_fn(target_images, pred_images_sample) # time, batch_size - cond_min = tf.less(sort_criterion(metric), sort_criterion(a['eval_%s/min' % name])) - cond_max = tf.greater(sort_criterion(metric), sort_criterion(a['eval_%s/max' % name])) - a['eval_%s/min' % name] = where_axis1(cond_min, metric, a['eval_%s/min' % name]) - a['eval_%s/sum' % name] = metric + a['eval_%s/sum' % name] - a['eval_%s/max' % name] = where_axis1(cond_max, metric, a['eval_%s/max' % name]) - a['eval_gen_images_%s/min' % name] = where_axis1(cond_min, gen_images_sample, a['eval_gen_images_%s/min' % name]) - a['eval_gen_images_%s/sum' % name] = gen_images_sample + a['eval_gen_images_%s/sum' % name] - a['eval_gen_images_%s/max' % name] = where_axis1(cond_max, gen_images_sample, a['eval_gen_images_%s/max' % name]) - #bing - # a['eval_diversity'] = tf.cond( - # tf.logical_and(tf.less(0, a['eval_sample_ind']), - # tf.less_equal(a['eval_sample_ind'], num_samples_for_diversity)), - # lambda: -vp.metrics.lpips(a['eval_pred_images_last'], pred_images_sample) + a['eval_diversity'], - # lambda: a['eval_diversity']) - a['eval_sample_ind'] = 1 + a['eval_sample_ind'] - a['eval_pred_images_last'] = pred_images_sample - return a - - initializer = {} - for name, _ in metric_fns: - initializer['eval_gen_images_%s/min' % name] = tf.zeros_like(gen_images) - initializer['eval_gen_images_%s/sum' % name] = tf.zeros_like(gen_images) - initializer['eval_gen_images_%s/max' % name] = tf.zeros_like(gen_images) - initializer['eval_%s/min' % name] = tf.fill([future_length, batch_size], float('inf')) - initializer['eval_%s/sum' % name] = tf.zeros([future_length, batch_size]) - initializer['eval_%s/max' % name] = tf.fill([future_length, batch_size], float('-inf')) - #initializer['eval_diversity'] = tf.zeros([future_length, batch_size]) - initializer['eval_sample_ind'] = tf.zeros((), dtype=tf.int32) - initializer['eval_pred_images_last'] = tf.zeros_like(pred_images) - - eval_outputs_and_metrics = tf.foldl( - accum_gen_images_and_metrics_fn, tf.zeros([num_samples, 0]), initializer=initializer, back_prop=False, - parallel_iterations=parallel_iterations) - - for name, _ in metric_fns: - eval_outputs['eval_gen_images_%s/min' % name] = eval_outputs_and_metrics['eval_gen_images_%s/min' % name] - eval_outputs['eval_gen_images_%s/avg' % name] = eval_outputs_and_metrics['eval_gen_images_%s/sum' % name] / float(num_samples) - eval_outputs['eval_gen_images_%s/max' % name] = eval_outputs_and_metrics['eval_gen_images_%s/max' % name] - eval_metrics['eval_%s/min' % name] = eval_outputs_and_metrics['eval_%s/min' % name] - eval_metrics['eval_%s/avg' % name] = eval_outputs_and_metrics['eval_%s/sum' % name] / float(num_samples) - eval_metrics['eval_%s/max' % name] = eval_outputs_and_metrics['eval_%s/max' % name] - #eval_metrics['eval_diversity'] = eval_outputs_and_metrics['eval_diversity'] / float(num_samples_for_diversity) - return eval_outputs, eval_metrics - - def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): - - method = BaseVideoPredictionModel.restore.__name__ - if checkpoints: - var_list = self.saveable_variables - # possibly restore from multiple checkpoints. useful if subset of weights - # (e.g. generator or discriminator) are on different checkpoints. 
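The stochastic evaluation branch deleted a few lines above folds over num_samples generator draws and keeps, per batch element, the whole time series of the best and worst sample (selected by its time-mean) together with a running sum for the average. A NumPy sketch of that accumulator, with the LPIPS-based diversity term left out; shapes and the random metric values are illustrative only:

    import numpy as np

    def accumulate_metric(sample_metrics):
        """Fold over per-sample metric arrays of shape (time, batch): keep, per
        batch element, the sample whose time-mean is lowest / highest, plus a
        running sum for the average (mirrors the deleted tf.foldl body)."""
        best = np.full_like(sample_metrics[0], np.inf)
        worst = np.full_like(sample_metrics[0], -np.inf)
        total = np.zeros_like(sample_metrics[0])
        for metric in sample_metrics:                        # one entry per stochastic sample
            cond_min = metric.mean(axis=0) < best.mean(axis=0)
            cond_max = metric.mean(axis=0) > worst.mean(axis=0)
            best = np.where(cond_min[None, :], metric, best)
            worst = np.where(cond_max[None, :], metric, worst)
            total += metric
        return {"min": best, "avg": total / len(sample_metrics), "max": worst}

    rng = np.random.default_rng(0)
    samples = [rng.random((8, 4)) for _ in range(5)]         # 5 samples, (future_length, batch)
    stats = accumulate_metric(samples)
    assert stats["min"].shape == stats["avg"].shape == (8, 4)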
- if not isinstance(checkpoints, (list, tuple)): - checkpoints = [checkpoints] - # automatically skip global_step if more than one checkpoint is provided - skip_global_step = len(checkpoints) > 1 - savers = [] - for checkpoint in checkpoints: - print("%{0}: Creating restore saver from checkpoint '{1}'".format(method, checkpoint)) - saver, _ = tf_utils.get_checkpoint_restore_saver( - checkpoint, var_list, skip_global_step=skip_global_step, - restore_to_checkpoint_mapping=restore_to_checkpoint_mapping) - savers.append(saver) - restore_op = [saver.saver_def.restore_op_name for saver in savers] - sess.run(restore_op) - return True - else: - return False - -class VideoPredictionModel(BaseVideoPredictionModel): - def __init__(self, - generator_fn, - discriminator_fn=None, - generator_scope='generator', - discriminator_scope='discriminator', - aggregate_nccl=False, - mode='train', - hparams_dict=None, - hparams=None, - **kwargs): - """ - Trainable video prediction model with CPU and multi-GPU support. - - If num_gpus <= 1, the devices for the ops in `self.build_graph` are - automatically chosen by TensorFlow (i.e. `tf.device` is not specified), - otherwise they are explicitly chosen. - - Args: - generator_fn: callable that takes in inputs and returns a dict of - tensors. - discriminator_fn: callable that takes in fake/real data (and - optionally conditioned on inputs) and returns a dict of - tensors. - hparams_dict: a dict of `name=value` pairs, where `name` must be - defined in `self.get_default_hparams()`. - hparams: a string of comma separated list of `name=value` pairs, - where `name` must be defined in `self.get_default_hparams()`. - These values overrides any values in hparams_dict (if any). - """ - super(VideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) - self.generator_fn = functools.partial(generator_fn, mode=self.mode, hparams=self.hparams) - self.discriminator_fn = functools.partial(discriminator_fn, mode=self.mode, hparams=self.hparams) if discriminator_fn else None - self.generator_scope = generator_scope - self.discriminator_scope = discriminator_scope - self.aggregate_nccl = aggregate_nccl - - if any(self.hparams.lr_boundaries): - global_step = tf.train.get_or_create_global_step() - lr_values = list(self.hparams.lr * 0.1 ** np.arange(len(self.hparams.lr_boundaries) + 1)) - self.learning_rate = tf.train.piecewise_constant(global_step, self.hparams.lr_boundaries, lr_values) - elif any(self.hparams.decay_steps): - lr, end_lr = self.hparams.lr, self.hparams.end_lr - start_step, end_step = self.hparams.decay_steps - if start_step == end_step: - schedule = tf.cond(tf.less(tf.train.get_or_create_global_step(), start_step), - lambda: 0.0, lambda: 1.0) - else: - step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step) - schedule = tf.to_float(step - start_step) / tf.to_float(end_step - start_step) - self.learning_rate = lr + (end_lr - lr) * schedule - else: - self.learning_rate = self.hparams.lr - - if self.hparams.kl_weight: - if self.hparams.kl_anneal == 'none': - self.kl_weight = tf.constant(self.hparams.kl_weight, tf.float32) - elif self.hparams.kl_anneal == 'sigmoid': - k = self.hparams.kl_anneal_k - if k == -1.0: - raise ValueError('Invalid kl_anneal_k %d when kl_anneal is sigmoid.' 
% k) - iter_num = tf.train.get_or_create_global_step() - self.kl_weight = self.hparams.kl_weight / (1 + k * tf.exp(-tf.to_float(iter_num) / k)) - elif self.hparams.kl_anneal == 'linear': - start_step, end_step = self.hparams.kl_anneal_steps - step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step) - self.kl_weight = self.hparams.kl_weight * tf.to_float(step - start_step) / tf.to_float(end_step - start_step) - else: - raise NotImplementedError - else: - self.kl_weight = None - - # member variables that should be set by `self.build_graph` - # (in addition to the ones in the base class) - self.gen_images_enc = None - self.g_losses = None - self.d_losses = None - self.g_loss = None - self.d_loss = None - self.g_vars = None - self.d_vars = None - self.train_op = None - self.summary_op = None - self.image_summary_op = None - self.eval_summary_op = None - self.accum_eval_summary_op = None - self.accum_eval_metrics_reset_op = None - - def get_default_hparams_dict(self): - """ - The keys of this dict define valid hyperparameters for instances of - this class. A class inheriting from this one should override this - method if it has a different set of hyperparameters. - - Returns: - A dict with the following hyperparameters. - - batch_size: batch size for training. - lr: learning rate. if decay steps is non-zero, this is the - learning rate for steps <= decay_step. - end_lr: learning rate for steps >= end_decay_step if decay_steps - is non-zero, ignored otherwise. - decay_steps: (decay_step, end_decay_step) tuple. - max_steps: number of training steps. - beta1: momentum term of Adam. - beta2: momentum term of Adam. - context_frames: the number of ground-truth frames to pass in at - start. Must be specified during instantiation. - sequence_length: the number of frames in the video sequence, - including the context frames, so this model predicts - `sequence_length - context_frames` future frames. Must be - specified during instantiation. - """ - default_hparams = super(VideoPredictionModel, self).get_default_hparams_dict() - hparams = dict( - batch_size=16, - lr=0.001, - end_lr=0.0, - decay_steps=(200000, 300000), - lr_boundaries=(0,), - max_epochs=35, - beta1=0.9, - beta2=0.999, - context_frames=-1, - sequence_length=-1, - clip_length=10, - l1_weight=0.0, - l2_weight=1.0, - vgg_cdist_weight=0.0, - feature_l2_weight=0.0, - ae_l2_weight=0.0, - state_weight=0.0, - tv_weight=0.0, - image_sn_gan_weight=0.0, - image_sn_vae_gan_weight=0.0, - images_sn_gan_weight=0.0, - images_sn_vae_gan_weight=0.0, - video_sn_gan_weight=0.0, - video_sn_vae_gan_weight=0.0, - gan_feature_l2_weight=0.0, - gan_feature_cdist_weight=0.0, - vae_gan_feature_l2_weight=0.0, - vae_gan_feature_cdist_weight=0.0, - gan_loss_type='LSGAN', - joint_gan_optimization=False, - kl_weight=0.0, - kl_anneal='linear', - kl_anneal_k=-1.0, - kl_anneal_steps=(50000, 100000), - z_l1_weight=0.0, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - def tower_fn(self, inputs): - """ - This method doesn't have side-effects. `inputs`, `targets`, and - `outputs` are batch-major but internal calculations use time-major - tensors. 
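tower_fn's contract above hinges on transpose_batch_time (imported from tf_utils at the top of the deleted file): the public tensors are batch-major, while everything inside the tower is time-major. A NumPy stand-in for that helper, assuming it simply swaps the first two axes and passes lower-rank values through unchanged:

    import numpy as np

    def transpose_batch_time(x):
        """(batch, time, ...) <-> (time, batch, ...): swap the two leading axes
        and leave everything else untouched."""
        x = np.asarray(x)
        if x.ndim < 2:                 # scalars / vectors pass through (assumption)
            return x
        return np.swapaxes(x, 0, 1)

    batch_major = np.zeros((16, 20, 64, 64, 3))   # (batch, sequence_length, H, W, C)
    time_major = transpose_batch_time(batch_major)
    assert time_major.shape == (20, 16, 64, 64, 3)
    assert transpose_batch_time(time_major).shape == batch_major.shape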
- """ - # batch-major to time-major - inputs = nest.map_structure(transpose_batch_time, inputs) - - with tf.variable_scope(self.generator_scope): - gen_outputs = self.generator_fn(inputs) - - if self.discriminator_fn: - with tf.variable_scope(self.discriminator_scope) as discrim_scope: - discrim_outputs = self.discriminator_fn(inputs, gen_outputs) - # post-update discriminator tensors (i.e. after the discriminator weights have been updated) - with tf.variable_scope(discrim_scope, reuse=True): - discrim_outputs_post = self.discriminator_fn(inputs, gen_outputs) - else: - discrim_outputs = {} - discrim_outputs_post = {} - - outputs = [gen_outputs, discrim_outputs] - total_num_outputs = sum([len(output) for output in outputs]) - outputs = OrderedDict(itertools.chain(*[output.items() for output in outputs])) - assert len(outputs) == total_num_outputs # ensure no output is lost because of repeated keys - - if isinstance(self.learning_rate, tf.Tensor): - outputs['learning_rate'] = self.learning_rate - if isinstance(self.kl_weight, tf.Tensor): - outputs['kl_weight'] = self.kl_weight - - if self.mode == 'train': - with tf.name_scope("discriminator_loss"): - d_losses = self.discriminator_loss_fn(inputs, outputs) - print_loss_info(d_losses, inputs, outputs) - with tf.name_scope("generator_loss"): - g_losses = self.generator_loss_fn(inputs, outputs) - print_loss_info(g_losses, inputs, outputs) - if discrim_outputs_post: - outputs_post = OrderedDict(itertools.chain(gen_outputs.items(), discrim_outputs_post.items())) - # generator losses after the discriminator weights have been updated - g_losses_post = self.generator_loss_fn(inputs, outputs_post) - else: - g_losses_post = g_losses - else: - d_losses = {} - g_losses = {} - g_losses_post = {} - with tf.name_scope("metrics"): - metrics = self.metrics_fn(inputs, outputs) - with tf.name_scope("eval_outputs_and_metrics"): - eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs) - - # time-major to batch-major - outputs_tuple = (outputs, eval_outputs) - outputs_tuple = nest.map_structure(transpose_batch_time, outputs_tuple) - losses_tuple = (d_losses, g_losses, g_losses_post) - losses_tuple = nest.map_structure(tf.convert_to_tensor, losses_tuple) - loss_tuple = tuple(tf.accumulate_n([loss * weight for loss, weight in losses.values()]) - if losses else tf.zeros(()) for losses in losses_tuple) - metrics_tuple = (metrics, eval_metrics) - metrics_tuple = nest.map_structure(transpose_batch_time, metrics_tuple) - return outputs_tuple, losses_tuple, loss_tuple, metrics_tuple - - def build_graph(self, inputs,finetune=False): - BaseVideoPredictionModel.build_graph(self, inputs) - - global_step = tf.train.get_or_create_global_step() - # ML 2021-06-23: Do not hide global step in self.saveable_variables - self.global_step = global_step - # Capture the variables created from here until the train_op for the - # saveable_variables. Note that if variables are being reused (e.g. - # they were created by a previously built model), those variables won't - # be captured here. 
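Inside tower_fn, every entry of the loss dictionaries is a (loss, weight) pair, and the scalar objective handed to the optimizer is their weighted sum (the tf.accumulate_n call above). The same convention is reused later by generator_loss_fn and discriminator_loss_fn. A small illustration of that reduction with made-up loss names and values:

    def total_loss(named_losses):
        """Collapse a mapping of name -> (loss, weight) into one scalar,
        as tower_fn does with tf.accumulate_n([loss * weight, ...])."""
        if not named_losses:
            return 0.0
        return sum(loss * weight for loss, weight in named_losses.values())

    g_losses = {
        "gen_l2_loss": (0.031, 1.0),             # hypothetical values
        "gen_image_sn_gan_loss": (0.52, 0.1),
    }
    print(round(total_loss(g_losses), 6))        # 0.083 = 0.031 * 1.0 + 0.52 * 0.1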
- original_global_variables = tf.global_variables() - - if self.num_gpus <= 1: # cpu or 1 gpu - outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(self.inputs) - self.outputs, self.eval_outputs = outputs_tuple - self.d_losses, self.g_losses, g_losses_post = losses_tuple - self.d_loss, self.g_loss, g_loss_post = loss_tuple - self.metrics, self.eval_metrics = metrics_tuple - - self.d_vars = tf.trainable_variables(self.discriminator_scope) - self.g_vars = tf.trainable_variables(self.generator_scope) - g_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) - d_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) - - - if self.mode == 'train' and (self.d_losses or self.g_losses): - with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): - if self.d_losses: - with tf.name_scope('d_compute_gradients'): - d_gradvars = d_optimizer.compute_gradients(self.d_loss, var_list=self.d_vars) - with tf.name_scope('d_apply_gradients'): - d_train_op = d_optimizer.apply_gradients(d_gradvars) - - else: - d_train_op = tf.no_op() - with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []): - if g_losses_post: - if not self.hparams.joint_gan_optimization: - replace_read_ops(g_loss_post, self.d_vars) - with tf.name_scope('g_compute_gradients'): - g_gradvars = g_optimizer.compute_gradients(g_loss_post, var_list=self.g_vars) - with tf.name_scope('g_apply_gradients'): - g_train_op = g_optimizer.apply_gradients(g_gradvars) - else: - g_train_op = tf.no_op() - with tf.control_dependencies([g_train_op]): - train_op = tf.assign_add(global_step, 1) - self.train_op = train_op - else: - self.train_op = None - - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [global_step] + global_variables - self.post_init_ops = [] - else: - if tf.get_variable_scope().name: - # This is because how variable scope works with empty strings when it's not the root scope, causing - # repeated forward slashes. 
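The single-GPU training branch above strings the updates together with control dependencies: the discriminator step runs first (a tf.no_op when there are no discriminator losses), the generator step runs after it, and only then is the global step incremented. A framework-free sketch of that ordering with hypothetical update callables, ignoring the joint_gan_optimization case:

    def train_step(d_update, g_update, state):
        """One optimisation step in the same order as the control-dependency
        chain above: discriminator, then generator, then the step counter."""
        if d_update is not None:
            d_update(state)              # d_optimizer.apply_gradients(...), else a no-op
        if g_update is not None:
            g_update(state)              # g_optimizer.apply_gradients(...), else a no-op
        state["global_step"] += 1        # tf.assign_add(global_step, 1)
        return state

    state = {"global_step": 0}
    train_step(lambda s: None, lambda s: None, state)
    assert state["global_step"] == 1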
- raise NotImplementedError('Unable to handle multi-gpu model created within a non-root variable scope.') - - tower_inputs = [OrderedDict() for _ in range(self.num_gpus)] - for name, input in self.inputs.items(): - input_splits = tf.split(input, self.num_gpus) # assumes batch_size is divisible by num_gpus - for i in range(self.num_gpus): - tower_inputs[i][name] = input_splits[i] - - tower_outputs_tuple = [] - tower_d_losses = [] - tower_g_losses = [] - tower_g_losses_post = [] - tower_d_loss = [] - tower_g_loss = [] - tower_g_loss_post = [] - tower_metrics_tuple = [] - for i in range(self.num_gpus): - worker_device = '/gpu:%d' % i - if self.aggregate_nccl: - scope_name = '' if i == 0 else 'v%d' % i - scope_reuse = False - device_setter = worker_device - else: - scope_name = '' - scope_reuse = i > 0 - device_setter = local_device_setter(worker_device=worker_device) - with tf.variable_scope(scope_name, reuse=scope_reuse): - with tf.device(device_setter): - outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(tower_inputs[i]) - tower_outputs_tuple.append(outputs_tuple) - d_losses, g_losses, g_losses_post = losses_tuple - tower_d_losses.append(d_losses) - tower_g_losses.append(g_losses) - tower_g_losses_post.append(g_losses_post) - d_loss, g_loss, g_loss_post = loss_tuple - tower_d_loss.append(d_loss) - tower_g_loss.append(g_loss) - tower_g_loss_post.append(g_loss_post) - tower_metrics_tuple.append(metrics_tuple) - self.d_vars = tf.trainable_variables(self.discriminator_scope) - self.g_vars = tf.trainable_variables(self.generator_scope) - - if self.aggregate_nccl: - scope_replica = lambda scope, i: ('' if i == 0 else 'v%d/' % i) + scope - tower_d_vars = [tf.trainable_variables( - scope_replica(self.discriminator_scope, i)) for i in range(self.num_gpus)] - tower_g_vars = [tf.trainable_variables( - scope_replica(self.generator_scope, i)) for i in range(self.num_gpus)] - assert self.d_vars == tower_d_vars[0] - assert self.g_vars == tower_g_vars[0] - tower_d_optimizer = [tf.train.AdamOptimizer( - self.learning_rate, self.hparams.beta1, self.hparams.beta2) for _ in range(self.num_gpus)] - tower_g_optimizer = [tf.train.AdamOptimizer( - self.learning_rate, self.hparams.beta1, self.hparams.beta2) for _ in range(self.num_gpus)] - - if self.mode == 'train' and (any(tower_d_losses) or any(tower_g_losses)): - tower_d_gradvars = [] - tower_g_gradvars = [] - tower_d_train_op = [] - tower_g_train_op = [] - with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): - if any(tower_d_losses): - for i in range(self.num_gpus): - with tf.device('/gpu:%d' % i): - with tf.name_scope(scope_replica('d_compute_gradients', i)): - d_gradvars = tower_d_optimizer[i].compute_gradients( - tower_d_loss[i], var_list=tower_d_vars[i]) - tower_d_gradvars.append(d_gradvars) - - all_d_grads, all_d_vars = tf_utils.split_grad_list(tower_d_gradvars) - all_d_grads = tf_utils.allreduce_grads(all_d_grads, average=True) - tower_d_gradvars = tf_utils.merge_grad_list(all_d_grads, all_d_vars) - - for i in range(self.num_gpus): - with tf.device('/gpu:%d' % i): - with tf.name_scope(scope_replica('d_apply_gradients', i)): - d_train_op = tower_d_optimizer[i].apply_gradients(tower_d_gradvars[i]) - tower_d_train_op.append(d_train_op) - d_train_op = tf.group(*tower_d_train_op) - else: - d_train_op = tf.no_op() - with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []): - if any(tower_g_losses_post): - for i in range(self.num_gpus): - with tf.device('/gpu:%d' % i): - if not 
self.hparams.joint_gan_optimization: - replace_read_ops(tower_g_loss_post[i], tower_d_vars[i]) - - with tf.name_scope(scope_replica('g_compute_gradients', i)): - g_gradvars = tower_g_optimizer[i].compute_gradients( - tower_g_loss_post[i], var_list=tower_g_vars[i]) - tower_g_gradvars.append(g_gradvars) - - all_g_grads, all_g_vars = tf_utils.split_grad_list(tower_g_gradvars) - all_g_grads = tf_utils.allreduce_grads(all_g_grads, average=True) - tower_g_gradvars = tf_utils.merge_grad_list(all_g_grads, all_g_vars) - - for i, g_gradvars in enumerate(tower_g_gradvars): - with tf.device('/gpu:%d' % i): - with tf.name_scope(scope_replica('g_apply_gradients', i)): - g_train_op = tower_g_optimizer[i].apply_gradients(g_gradvars) - tower_g_train_op.append(g_train_op) - g_train_op = tf.group(*tower_g_train_op) - else: - g_train_op = tf.no_op() - with tf.control_dependencies([g_train_op]): - train_op = tf.assign_add(global_step, 1) - self.train_op = train_op - else: - self.train_op = None - - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - tower_saveable_vars = [[] for _ in range(self.num_gpus)] - for var in global_variables: - m = re.match('v(\d+)/.*', var.name) - i = int(m.group(1)) if m else 0 - tower_saveable_vars[i].append(var) - self.saveable_variables = [global_step] + tower_saveable_vars[0] - - post_init_ops = [] - for i, saveable_vars in enumerate(tower_saveable_vars[1:], 1): - assert len(saveable_vars) == len(tower_saveable_vars[0]) - for var, var0 in zip(saveable_vars, tower_saveable_vars[0]): - assert var.name == 'v%d/%s' % (i, var0.name) - post_init_ops.append(var.assign(var0.read_value())) - self.post_init_ops = post_init_ops - else: # not self.aggregate_nccl (i.e. aggregation in cpu) - g_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) - d_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) - - if self.mode == 'train' and (any(tower_d_losses) or any(tower_g_losses)): - with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): - if any(tower_d_losses): - with tf.name_scope('d_compute_gradients'): - d_gradvars = compute_averaged_gradients( - d_optimizer, tower_d_loss, var_list=self.d_vars) - with tf.name_scope('d_apply_gradients'): - d_train_op = d_optimizer.apply_gradients(d_gradvars) - else: - d_train_op = tf.no_op() - with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []): - if any(tower_g_losses_post): - for g_loss_post in tower_g_loss_post: - if not self.hparams.joint_gan_optimization: - replace_read_ops(g_loss_post, self.d_vars) - with tf.name_scope('g_compute_gradients'): - g_gradvars = compute_averaged_gradients( - g_optimizer, tower_g_loss_post, var_list=self.g_vars) - with tf.name_scope('g_apply_gradients'): - g_train_op = g_optimizer.apply_gradients(g_gradvars) - else: - g_train_op = tf.no_op() - with tf.control_dependencies([g_train_op]): - train_op = tf.assign_add(global_step, 1) - self.train_op = train_op - else: - self.train_op = None - - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [global_step] + global_variables - self.post_init_ops = [] - - # Device that runs the ops to apply global gradient updates. 
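Both multi-GPU paths shown above reduce to the same idea before gradients are applied: average the per-tower gradients variable by variable (allreduce_grads(average=True) on the NCCL path, compute_averaged_gradients on the CPU path). A NumPy sketch of that averaging step, ignoring device placement and the (gradient, variable) pairing used by the optimizers:

    import numpy as np

    def average_tower_gradients(tower_grads):
        """Average per-tower gradients position by position;
        tower_grads[i][j] is tower i's gradient for variable j."""
        return [np.mean(np.stack(grads_for_var), axis=0)
                for grads_for_var in zip(*tower_grads)]

    tower_grads = [
        [np.array([0.2, -0.4]), np.array([[1.0]])],   # gradients computed on gpu:0
        [np.array([0.4, -0.2]), np.array([[3.0]])],   # gradients computed on gpu:1
    ]
    averaged = average_tower_gradients(tower_grads)
    print(averaged[0], averaged[1])                   # [ 0.3 -0.3] [[2.]]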
- consolidation_device = '/cpu:0' - with tf.device(consolidation_device): - with tf.name_scope('consolidation'): - self.outputs, self.eval_outputs = reduce_tensors(tower_outputs_tuple) - self.d_losses = reduce_tensors(tower_d_losses, shallow=True) - self.g_losses = reduce_tensors(tower_g_losses, shallow=True) - self.metrics, self.eval_metrics = reduce_tensors(tower_metrics_tuple) - self.d_loss = reduce_tensors(tower_d_loss) - self.g_loss = reduce_tensors(tower_g_loss) - - original_local_variables = set(tf.local_variables()) - self.accum_eval_metrics = OrderedDict() - for name, eval_metric in self.eval_metrics.items(): - _, self.accum_eval_metrics['accum_' + name] = tf.metrics.mean_tensor(eval_metric) - local_variables = set(tf.local_variables()) - original_local_variables - self.accum_eval_metrics_reset_op = tf.group([tf.assign(v, tf.zeros_like(v)) for v in local_variables]) - - original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - add_summaries(self.inputs) - add_summaries(self.outputs) - add_scalar_summaries(self.d_losses) - add_scalar_summaries(self.g_losses) - add_scalar_summaries(self.metrics) - if self.d_losses: - add_scalar_summaries({'d_loss': self.d_loss}) - if self.g_losses: - add_scalar_summaries({'g_loss': self.g_loss}) - if self.d_losses and self.g_losses: - add_scalar_summaries({'loss': self.d_loss + self.g_loss}) - summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries - # split summaries into non-image summaries and image summaries - self.summary_op = tf.summary.merge(list(summaries - set(tf.get_collection(tf_utils.IMAGE_SUMMARIES)))) - self.image_summary_op = tf.summary.merge(list(summaries & set(tf.get_collection(tf_utils.IMAGE_SUMMARIES)))) - - original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - add_gif_summaries(self.eval_outputs) - add_plot_and_scalar_summaries( - {name: tf.reduce_mean(metric, axis=0) for name, metric in self.eval_metrics.items()}, - x_offset=self.hparams.context_frames + 1) - summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries - self.eval_summary_op = tf.summary.merge(list(summaries)) - - original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - add_plot_and_scalar_summaries( - {name: tf.reduce_mean(metric, axis=0) for name, metric in self.accum_eval_metrics.items()}, - x_offset=self.hparams.context_frames + 1) - summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries - self.accum_eval_summary_op = tf.summary.merge(list(summaries)) - - def generator_loss_fn(self, inputs, outputs): - hparams = self.hparams - opt_var = self.hparams.opt_var - gen_losses = OrderedDict() - if opt_var == "all": - gen_images = outputs.get("gen_images_enc", outputs["gen_images"]) - target_images = inputs["images"][1:] - print("The model is optimized on all variables/channels in the loss function") - elif opt_var != "all" and isinstance(opt_var,str): - opt_var = int(opt_var) - print("The model is optimized on the {} variable/channel in the loss function".format(opt_var)) - gen_images = outputs.get("gen_images_enc", outputs["gen_images"])[:, :, :, :, opt_var:opt_var+1] - target_images = inputs["images"][1:][:, :, :, :, opt_var:opt_var+1] - else: - raise ValueError("The opt_var in the hyper-parameters setting should be int or 'all'") - - - if hparams.l1_weight: - gen_l1_loss = vp.losses.l1_loss(gen_images, target_images) - gen_losses["gen_l1_loss"] = (gen_l1_loss, hparams.l1_weight) - if hparams.l2_weight: - gen_l2_loss = vp.losses.l2_loss(gen_images, 
target_images) - gen_losses["gen_l2_loss"] = (gen_l2_loss, hparams.l2_weight) - if hparams.vgg_cdist_weight: - gen_vgg_cdist_loss = vp.metrics.vgg_cosine_distance(gen_images, target_images) - gen_losses['gen_vgg_cdist_loss'] = (gen_vgg_cdist_loss, hparams.vgg_cdist_weight) - if hparams.feature_l2_weight: - gen_features = outputs.get('gen_features_enc', outputs['gen_features']) - target_features = outputs['features'][1:] - gen_feature_l2_loss = vp.losses.l2_loss(gen_features, target_features) - gen_losses["gen_feature_l2_loss"] = (gen_feature_l2_loss, hparams.feature_l2_weight) - if hparams.ae_l2_weight: - gen_images_dec = outputs.get('gen_images_dec_enc', outputs['gen_images_dec']) # they both should be the same - target_images = inputs['images'] - gen_ae_l2_loss = vp.losses.l2_loss(gen_images_dec, target_images) - gen_losses["gen_ae_l2_loss"] = (gen_ae_l2_loss, hparams.ae_l2_weight) - if hparams.state_weight: - gen_states = outputs.get('gen_states_enc', outputs['gen_states']) - target_states = inputs['states'][1:] - gen_state_loss = vp.losses.l2_loss(gen_states, target_states) - gen_losses["gen_state_loss"] = (gen_state_loss, hparams.state_weight) - if hparams.tv_weight: - gen_flows = outputs.get('gen_flows_enc', outputs['gen_flows']) - flow_diff1 = gen_flows[..., 1:, :, :, :] - gen_flows[..., :-1, :, :, :] - flow_diff2 = gen_flows[..., :, 1:, :, :] - gen_flows[..., :, :-1, :, :] - # sum over the multiple transformations but take the mean for the other dimensions - gen_tv_loss = (tf.reduce_mean(tf.reduce_sum(tf.abs(flow_diff1), axis=(-2, -1))) + - tf.reduce_mean(tf.reduce_sum(tf.abs(flow_diff2), axis=(-2, -1)))) - gen_losses['gen_tv_loss'] = (gen_tv_loss, hparams.tv_weight) - gan_weights = {'_image_sn': hparams.image_sn_gan_weight, - '_images_sn': hparams.images_sn_gan_weight, - '_video_sn': hparams.video_sn_gan_weight} - for infix, gan_weight in gan_weights.items(): - if gan_weight: - gen_gan_loss = vp.losses.gan_loss(outputs['discrim%s_logits_fake' % infix], 1.0, hparams.gan_loss_type) - gen_losses["gen%s_gan_loss" % infix] = (gen_gan_loss, gan_weight) - if gan_weight and (hparams.gan_feature_l2_weight or hparams.gan_feature_cdist_weight): - i_feature = 0 - discrim_features_fake = [] - discrim_features_real = [] - while True: - discrim_feature_fake = outputs.get('discrim%s_feature%d_fake' % (infix, i_feature)) - discrim_feature_real = outputs.get('discrim%s_feature%d_real' % (infix, i_feature)) - if discrim_feature_fake is None or discrim_feature_real is None: - break - discrim_features_fake.append(discrim_feature_fake) - discrim_features_real.append(discrim_feature_real) - i_feature += 1 - if hparams.gan_feature_l2_weight: - gen_gan_feature_l2_loss = sum([vp.losses.l2_loss(discrim_feature_fake, discrim_feature_real) - for discrim_feature_fake, discrim_feature_real in zip(discrim_features_fake, discrim_features_real)]) - gen_losses["gen%s_gan_feature_l2_loss" % infix] = (gen_gan_feature_l2_loss, hparams.gan_feature_l2_weight) - if hparams.gan_feature_cdist_weight: - gen_gan_feature_cdist_loss = sum([vp.losses.cosine_distance(discrim_feature_fake, discrim_feature_real) - for discrim_feature_fake, discrim_feature_real in zip(discrim_features_fake, discrim_features_real)]) - gen_losses["gen%s_gan_feature_cdist_loss" % infix] = (gen_gan_feature_cdist_loss, hparams.gan_feature_cdist_weight) - vae_gan_weights = {'_image_sn': hparams.image_sn_vae_gan_weight, - '_images_sn': hparams.images_sn_vae_gan_weight, - '_video_sn': hparams.video_sn_vae_gan_weight} - for infix, vae_gan_weight in 
vae_gan_weights.items(): - if vae_gan_weight: - gen_vae_gan_loss = vp.losses.gan_loss(outputs['discrim%s_logits_enc_fake' % infix], 1.0, hparams.gan_loss_type) - gen_losses["gen%s_vae_gan_loss" % infix] = (gen_vae_gan_loss, vae_gan_weight) - if vae_gan_weight and (hparams.vae_gan_feature_l2_weight or hparams.vae_gan_feature_cdist_weight): - i_feature = 0 - discrim_features_enc_fake = [] - discrim_features_enc_real = [] - while True: - discrim_feature_enc_fake = outputs.get('discrim%s_feature%d_enc_fake' % (infix, i_feature)) - discrim_feature_enc_real = outputs.get('discrim%s_feature%d_enc_real' % (infix, i_feature)) - if discrim_feature_enc_fake is None or discrim_feature_enc_real is None: - break - discrim_features_enc_fake.append(discrim_feature_enc_fake) - discrim_features_enc_real.append(discrim_feature_enc_real) - i_feature += 1 - if hparams.vae_gan_feature_l2_weight: - gen_vae_gan_feature_l2_loss = sum([vp.losses.l2_loss(discrim_feature_enc_fake, discrim_feature_enc_real) - for discrim_feature_enc_fake, discrim_feature_enc_real in zip(discrim_features_enc_fake, discrim_features_enc_real)]) - gen_losses["gen%s_vae_gan_feature_l2_loss" % infix] = (gen_vae_gan_feature_l2_loss, hparams.vae_gan_feature_l2_weight) - if hparams.vae_gan_feature_cdist_weight: - gen_vae_gan_feature_cdist_loss = sum([vp.losses.cosine_distance(discrim_feature_enc_fake, discrim_feature_enc_real) - for discrim_feature_enc_fake, discrim_feature_enc_real in zip(discrim_features_enc_fake, discrim_features_enc_real)]) - gen_losses["gen%s_vae_gan_feature_cdist_loss" % infix] = (gen_vae_gan_feature_cdist_loss, hparams.vae_gan_feature_cdist_weight) - if hparams.kl_weight: - gen_kl_loss = vp.losses.kl_loss(outputs['zs_mu_enc'], outputs['zs_log_sigma_sq_enc'], - outputs.get('zs_mu_prior'), outputs.get('zs_log_sigma_sq_prior')) - gen_losses["gen_kl_loss"] = (gen_kl_loss, self.kl_weight) # possibly annealed kl_weight - return gen_losses - - def discriminator_loss_fn(self, inputs, outputs): - hparams = self.hparams - discrim_losses = OrderedDict() - gan_weights = {'_image_sn': hparams.image_sn_gan_weight, - '_images_sn': hparams.images_sn_gan_weight, - '_video_sn': hparams.video_sn_gan_weight} - for infix, gan_weight in gan_weights.items(): - if gan_weight: - discrim_gan_loss_real = vp.losses.gan_loss(outputs['discrim%s_logits_real' % infix], 1.0, hparams.gan_loss_type) - discrim_gan_loss_fake = vp.losses.gan_loss(outputs['discrim%s_logits_fake' % infix], 0.0, hparams.gan_loss_type) - discrim_gan_loss = discrim_gan_loss_real + discrim_gan_loss_fake - discrim_losses["discrim%s_gan_loss" % infix] = (discrim_gan_loss, gan_weight) - vae_gan_weights = {'_image_sn': hparams.image_sn_vae_gan_weight, - '_images_sn': hparams.images_sn_vae_gan_weight, - '_video_sn': hparams.video_sn_vae_gan_weight} - for infix, vae_gan_weight in vae_gan_weights.items(): - if vae_gan_weight: - discrim_vae_gan_loss_real = vp.losses.gan_loss(outputs['discrim%s_logits_enc_real' % infix], 1.0, hparams.gan_loss_type) - discrim_vae_gan_loss_fake = vp.losses.gan_loss(outputs['discrim%s_logits_enc_fake' % infix], 0.0, hparams.gan_loss_type) - discrim_vae_gan_loss = discrim_vae_gan_loss_real + discrim_vae_gan_loss_fake - discrim_losses["discrim%s_vae_gan_loss" % infix] = (discrim_vae_gan_loss, vae_gan_weight) - return discrim_losses diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py index 
c4035963dc37b840e55cdacb2c7288b10e458cfd..5bad1430f9c7499965899fb948bd46bb8e7686c9 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py @@ -3,35 +3,27 @@ # SPDX-License-Identifier: MIT __email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong,Yanji" +__author__ = "Bing Gong" __date__ = "2021-04-13" -from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames import tensorflow as tf -from model_modules.video_prediction.layers import layer_def as ld -from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames +from model_modules.video_prediction.layers.layer_def import batch_norm +from model_modules.video_prediction.models.vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel as convLSTM from .our_base_model import BaseModels class ConvLstmGANVideoPredictionModel(BaseModels): - def __init__(self, hparams_dict=None, mode='train', **kwargs): - """ - This is class for building convLSTM_GAN architecture by using updated hparameters - args: - mode :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model - hparams_dict: dict, the dictionary contains the hparaemters names and values - """ - super().__init__(hparams_dict) - self.hparams = self.get_hparams() - self.mode = mode + + def __init__(self, hparams_dict_config=None, mode='train'): + super().__init__(hparams_dict_config, mode) self.bd1 = batch_norm(name = "bd1") self.bd2 = batch_norm(name = "bd2") + self.bd3 = batch_norm(name = "dis3") - def get_hparams(self): + def parse_hparams(self, hparams): """ - obtain the hparams from the dict to the class variables + Obtain the hparams from the dict to the class variables """ - method = BaseModels.get_hparams.__name__ - try: self.context_frames = self.hparams.context_frames self.max_epochs = self.hparams.max_epochs @@ -41,168 +33,153 @@ class ConvLstmGANVideoPredictionModel(BaseModels): self.recon_weight = self.hparams.recon_weight self.learning_rate = self.hparams.lr self.sequence_length = self.hparams.sequence_length + self.opt_var = self.hparams.opt_var self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.ngf = self.hparams.ngf self.ndf = self.hparams.ndf - except Exception as error: - print("Method %{}: error: {}".format(method,error)) - raise("Method %{}: the hparameter dictionary must include parameters above".format(method)) + print("error: {}".format(error)) + raise ValueError("Method %{}: the hyper-parameter dictionary must include parameters above") def build_graph(self, x: tf.Tensor): - self.is_build_graph = False self.inputs = x + self.global_step = tf.train.get_or_create_global_step() original_global_variables = tf.global_variables() - # Architecture - self.build_model() - # define loss function - self.total_loss = (1-self.recon_weight) * self.G_loss + self.recon_weight*self.recon_loss - self.D_loss = (1-self.recon_weight) * self.D_loss + + #Build graph + x_hat = self.build_model(x) + + #Get losses (reconstruciton loss, total loss and descriminator loss) + self.total_loss = self.get_loss(x, x_hat) + + #Define optimizer + self.train_op = self.optimizer(self.total_loss) + + #Save to outputs + self.outputs["gen_images"] = x_hat + self.outputs["total_loss"] = self.total_loss + # Summary op 
+ sum_dict = {"total_loss": self.total_loss, + "D_loss": self.D_loss, + "G_loss": self.G_loss, + "D_loss_fake": self.D_loss_fake, + "D_loss_real": self.D_loss_real, + "recon_loss": self.recon_loss} + + self.summary_op = self.summary(**sum_dict) + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self.is_build_graph = True + return self.is_build_graph + + def get_loss(self, x: tf.Tensor, x_hat: tf.Tensor): + """ + We use the loss from vanilla convolutional LSTM as reconstruction loss + """ + self.G_loss = self.get_gen_loss() + self.D_loss = self.get_disc_loss() + self._get_vars() + #self.recon_loss = self.get_loss(self, x, x_hat) #use the loss from vanilla convLSTM + + if self.opt_var == "all": + x = x[:, self.context_frames:, :, :, :] + print("The model is optimzied on all the variables in the loss function") + elif self.opt_var != "all" and isinstance(self.opt_var, str): + self.opt_var = int(self.opt_var) + print("The model is optimized on the {} variable in the loss function".format(self.opt_var)) + x = x[:, self.context_frames:, :, :, self.opt_var] + x_hat = x_hat[:, :, :, :, self.opt_var] + else: + raise ValueError("The opt var in the hyperparameters setup should be '0','1','2' indicate the index of target variable to be optimised or 'all' indicating optimize all the variables") + + if self.loss_fun == "mse": + self.recon_loss = tf.reduce_mean(tf.square(x - x_hat)) + elif self.loss_fun == "cross_entropy": + x_flatten = tf.reshape(x, [-1]) + x_hat_predict_frames_flatten = tf.reshape(x_hat, [-1]) + bce = tf.keras.losses.BinaryCrossentropy() + self.recon_loss = bce(x_flatten, x_hat_predict_frames_flatten) + else: + raise ValueError("Loss function is not selected properly, you should chose either 'mse' or 'cross_entropy'") + + self.D_loss = (1 - self.recon_weight) * self.D_loss + total_loss = (1-self.recon_weight) * self.G_loss + self.recon_weight*self.recon_loss + return total_loss + + def optimizer(self, *args): if self.mode == "train": if self.recon_weight == 1: - print("Only train generator- convLSTM") - self.train_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + print("Only train generator- ConvLSTM") + train_op = tf.train.AdamOptimizer(learning_rate = + self.learning_rate).\ + minimize(self.total_loss, var_list=self.gen_vars) else: - print("Training distriminator") - self.D_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.D_loss, var_list=self.disc_vars) + print("Training discriminator") + self.D_solver = tf.train.AdamOptimizer(learning_rate =self.learning_rate).\ + minimize(self.D_loss, var_list=self.disc_vars) with tf.control_dependencies([self.D_solver]): print("Training generator....") - self.G_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + self.G_solver = tf.train.AdamOptimizer(learning_rate =self.learning_rate).\ + minimize(self.total_loss, var_list=self.gen_vars) with tf.control_dependencies([self.G_solver]): - self.train_op = tf.assign_add(self.global_step,1) + train_op = tf.assign_add(self.global_step, 1) else: - self.train_op = None + train_op = None + return train_op - self.outputs["gen_images"] = self.gen_images - self.outputs["total_loss"] = self.total_loss - # Summary op - tf.summary.scalar("total_loss", self.total_loss) - tf.summary.scalar("D_loss", self.D_loss) - 
tf.summary.scalar("G_loss", self.G_loss) - tf.summary.scalar("D_loss_fake", self.D_loss_fake) - tf.summary.scalar("D_loss_real", self.D_loss_real) - tf.summary.scalar("recon_loss",self.recon_loss) - self.summary_op = tf.summary.merge_all() - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - self.is_build_graph = True - return self.is_build_graph - @staticmethod - def Unet_ConvLSTM_cell(x: tf.Tensor, ngf: int, hidden: tf.Tensor): + def build_model(self, x): """ - Build up a Unet ConvLSTM cell for each time stamp i - params: x: the input at timestamp i - params: ngf: the numnber of filters for convoluational layers - params: hidden: the hidden state from the previous timestamp t-1 - return: - outputs: the predict frame at timestamp i - hidden: the hidden state at current timestamp i + Define gan architectures """ - input_shape = x.get_shape().as_list() - num_channels = input_shape[3] - with tf.variable_scope("down_scale", reuse = tf.AUTO_REUSE): - conv1f = ld.conv_layer(x, 3 , 1, ngf, 1, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - conv1s = ld.conv_layer(conv1f, 3, 1, ngf, 2, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - pool1 = tf.layers.max_pooling2d(conv1s, pool_size=(2, 2), strides=(2, 2)) - print('pool1 shape: ',pool1.shape) - - conv2f = ld.conv_layer(pool1, 3, 1, ngf * 2, 3, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - conv2s = ld.conv_layer(conv2f, 3, 1, ngf * 2, 4, initializer = tf.contrib.layers.xavier_initializer(), activate = "relu") - pool2 = tf.layers.max_pooling2d(conv2s, pool_size=(2, 2), strides=(2, 2)) - print('pool2 shape: ',pool2.shape) - - conv3f = ld.conv_layer(pool2, 3, 1, ngf * 4, 5, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - conv3s = ld.conv_layer(conv3f, 3, 1, ngf * 4, 6, initializer = tf.contrib.layers.xavier_initializer(), activate = "relu") - pool3 = tf.layers.max_pooling2d(conv3s, pool_size=(2, 2), strides=(2, 2)) - print('pool3 shape: ',pool3.shape) - - convLSTM_input = pool3 - #convLSTM_input = tf.layers.dropout(pool2, 0.8) - - convLSTM4, hidden = ConvLstmGANVideoPredictionModel.convLSTM_cell(convLSTM_input, hidden) - print('convLSTM4 shape: ',convLSTM4.shape) - - with tf.variable_scope("upscale", reuse = tf.AUTO_REUSE): - deconv5 = ld.transpose_conv_layer(convLSTM4, 2, 2, ngf * 4, 1, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - print('deconv5 shape: ',deconv5.shape) - up5 = tf.concat([deconv5, conv3s], axis=3) - print('up5 shape: ',up5.shape) - - conv5f = ld.conv_layer(up5, 3, 1, ngf * 4, 2, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") - conv5s = ld.conv_layer(conv5f, 3, 1, ngf * 4, 3, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") - print('conv5s shape:',conv5s.shape) - - deconv6 = ld.transpose_conv_layer(conv5s, 2, 2, ngf * 2, 4, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") - print('deconv6 shape: ',deconv6.shape) - up6 = tf.concat([deconv6, conv2s], axis=3) - print('up6 shape: ',up6.shape) - - conv6f = ld.conv_layer(up6, 3, 1, ngf * 2, 5, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") - conv6s = ld.conv_layer(conv6f, 3, 1, ngf * 2, 6, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") - print('conv6s shape:',conv6s.shape) - - deconv7 = ld.transpose_conv_layer(conv6s, 2, 2, ngf, 7, initializer = 
tf.contrib.layers.xavier_initializer(), activate="relu") - print('deconv7 shape: ',deconv7.shape) - up7 = tf.concat([deconv7, conv1s], axis=3) - print('up7 shape: ',up7.shape) - - conv7f = ld.conv_layer(up7, 3, 1, ngf, 8, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") - conv7s = ld.conv_layer(conv7f, 3, 1, ngf, 9, initializer = tf.contrib.layers.xavier_initializer(),activate= "relu") - print('conv7s shape:',conv7s.shape) - - conv7t = ld.conv_layer(conv7s, 3, 1, num_channels, 10, initializer = tf.contrib.layers.xavier_initializer(),activate="relu") - outputs = ld.conv_layer(conv7t, 1, 1, num_channels, 11, initializer = tf.contrib.layers.xavier_initializer(),activate="linear") - print('outputs shape: ',outputs.shape) - - return outputs, hidden - - def generator(self, x: tf.Tensor): + #conditional GAN + x_hat = self.generator(x) + + self.D_real, self.D_real_logits = self.discriminator(self.inputs[:, self.context_frames:, :, :, 0:1]) + self.D_fake, self.D_fake_logits = self.discriminator(x_hat[:, self.context_frames - 1:, :, :, 0:1]) + + return x_hat + + + def generator(self,x): """ - Function to build up the generator architecture, here we take Unet_ConvLSTM as generator + Function to build up the generator architecture args: input images: a input tensor with dimension (n_batch,sequence_length,height,width,channel) - output images: (n_batch,forecast_length,height,width,channel) """ - network_template = tf.make_template('network', ConvLstmGANVideoPredictionModel.Unet_ConvLSTM_cell) with tf.variable_scope("generator", reuse = tf.AUTO_REUSE): - # create network - x_hat = [] - #This is for training (optimization of convLSTM layer) - hidden_g = None - for i in range(self.sequence_length-1): - print('i: ',i) - if i < self.context_frames: - x_1_g, hidden_g = network_template(x[:, i, :, :, :], self.ngf, hidden_g) - else: - x_1_g, hidden_g = network_template(x_1_g, self.ngf, hidden_g) - x_hat.append(x_1_g) - # pack them all together - x_hat = tf.stack(x_hat) - self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) - print('self.x_hat shape is: ',self.x_hat.shape) - return self.x_hat - - def discriminator(self, x): + network_template = tf.make_template('network', + convLSTM.convLSTM_cell) # make the template to share the variables + + x_hat = convLSTM.convLSTM_network(self.inputs, + self.sequence_length, + self.context_frames, + network_template) + return x_hat + + + def discriminator(self, vid): """ - Function that get discriminator architecture + Function that get discriminator architecture """ with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): - conv1 = tf.layers.conv3d(x, 4, kernel_size=[4,4,4], strides=[1,2,2], padding="SAME", name="dis1") - conv1 = ConvLstmGANVideoPredictionModel.lrelu(conv1) - #conv2 = tf.layers.conv3d(conv1, 1, kernel_size=[4,4,4], strides=[1,2,2], padding="SAME", name="dis2") - conv2 = tf.reshape(conv1, [-1,1]) - #fc1 = ConvLstmGANVideoPredictionModel.lrelu(self.bd1(ConvLstmGANVideoPredictionModel.linear(conv2, output_size=256, scope='d_fc1'))) - fc2 = ConvLstmGANVideoPredictionModel.lrelu(self.bd2(ConvLstmGANVideoPredictionModel.linear(conv2, output_size=64, scope='d_fc2'))) - out_logit = ConvLstmGANVideoPredictionModel.linear(fc2, 1, scope='d_fc3') - out = tf.nn.sigmoid(out_logit) - #out,out_logit = self.Conv3Dnet(x,self.ndf) - return out, out_logit + conv1 = tf.layers.conv3d(vid,64,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME", name="dis1") + conv1 = self._lrelu(conv1) + conv2 = tf.layers.conv3d(conv1, 128, 
kernel_size=[4,4,4],strides=[2,2,2],padding="SAME", name="dis2") + conv2 = self._lrelu(self.bd1(conv2)) + conv3 = tf.layers.conv3d(conv2, 256, kernel_size=[4,4,4],strides=[2,2,2],padding="SAME" ,name="dis3") + conv3 = self._lrelu(self.bd2(conv3)) + conv4 = tf.layers.conv3d(conv3, 512, kernel_size=[4,4,4],strides=[2,2,2],padding="SAME", name="dis4") + conv4 = self._lrelu(self.bd3(conv4)) + conv5 = tf.layers.conv3d(conv4, 1, kernel_size=[2,4,4],strides=[1,1,1],padding="SAME", name="dis5") + conv5 = tf.reshape(conv5, [-1,1]) + conv5sigmoid = tf.nn.sigmoid(conv5) + return conv5sigmoid, conv5 def get_disc_loss(self): """ @@ -212,8 +189,10 @@ class ConvLstmGANVideoPredictionModel(BaseModels): gen_labels = tf.zeros_like(self.D_fake) self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real_logits, labels=real_labels)) self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=gen_labels)) - self.D_loss = self.D_loss_real + self.D_loss_fake - return self.D_loss + D_loss = self.D_loss_real + self.D_loss_fake + return D_loss + + def get_gen_loss(self): """ @@ -223,129 +202,24 @@ class ConvLstmGANVideoPredictionModel(BaseModels): Return the loss of generator given inputs """ real_labels = tf.ones_like(self.D_fake) - self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=real_labels)) - return self.G_loss + G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, + labels=real_labels)) + return G_loss - def get_vars(self): + def _get_vars(self): """ Get trainable variables from discriminator and generator """ self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] - - def build_model(self): - """ - Define gan architectures - """ - self.gen_images = self.generator(self.inputs) - self.D_real, self.D_real_logits = self.discriminator(self.inputs[:,self.context_frames:, :, :, 0:1]) # use the first varibale as targeted - #self.D_fake, self.D_fake_logits = self.discriminator(self.gen_images[:,:,:,:,0:1]) #0:1 - self.D_fake, self.D_fake_logits = self.discriminator(self.gen_images[:,self.context_frames-1:, :, :, 0:1]) #0:1 - - self.get_gen_loss() - self.get_disc_loss() - self.get_vars() - if self.loss_fun == "rmse": - #self.recon_loss = tf.reduce_mean(tf.square(self.inputs[:, self.context_frames:,:,:,0] - self.gen_images[:,:,:,:,0])) - self.recon_loss = tf.reduce_mean(tf.square(self.inputs[:, self.context_frames:, :, :, 0] - self.gen_images[:, self.context_frames-1:, :, :, 0])) - elif self.loss_fun == "cross_entropy": - x_flatten = tf.reshape(self.inputs[:, self.context_frames:,:,:,0],[-1]) - #x_hat_predict_frames_flatten = tf.reshape(self.gen_images[:,:,:,:,0],[-1]) - x_hat_predict_frames_flatten = tf.reshape(self.gen_images[:,self.context_frames-1:, :, :, 0], [-1]) - bce = tf.keras.losses.BinaryCrossentropy() - self.recon_loss = bce(x_flatten, x_hat_predict_frames_flatten) - else: - raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'") - - @staticmethod - def convLSTM_cell(inputs, hidden): - y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset - # conv lstm cell - cell_shape = y_0.get_shape().as_list() - channels = cell_shape[-1] - with tf.variable_scope('conv_lstm', initializer = 
tf.random_uniform_initializer(-.01, 0.1)): - cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64) - if hidden is None: - hidden = cell.zero_state(y_0, tf.float32) - output, hidden = cell(y_0, hidden) - return output, hidden - #output_shape = output.get_shape().as_list() - #z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) - ###we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction - #x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") - #print('x_hat shape is: ',x_hat.shape) - #return x_hat, hidden - - def get_noise(self, x, sigma=0.2): - """ - Function for creating noise: Given the dimensions (n_batch,n_seq, n_height, n_width, channel) - """ - x_shape = x.get_shape().as_list() - noise = sigma * tf.random.uniform(minval=-1., maxval=1., shape=x_shape) - x = x + noise - return x - - def Conv3Dnet_v1(self, x, ndf): - conv1 = tf.layers.conv3d(x, ndf, kernel_size = [4, 4, 4], strides = [1, 2, 2], padding = "SAME", name = 'conv1') - conv1 = self.lrelu(conv1) - # conv2 = tf.layers.conv3d(conv1,ndf*2,kernel_size=[4,4,4],strides=[1,2,2],padding="SAME",name='conv2') - # conv2 = self.lrelu(conv2) - conv3 = tf.layers.conv3d(conv1, 1, kernel_size = [4, 4, 4], strides = [1, 1, 1], padding = "SAME", name = 'conv3') - fl = tf.reshape(conv3, [-1, 1]) - print('fl shape: ', fl.shape) - fc1 = self.lrelu(self.bd1(self.linear(fl, 256, scope = 'fc1'))) - print('fc1 shape: ', fc1.shape) - fc2 = self.lrelu(self.bd2(self.linear(fc1, 64, scope = 'fc2'))) - print('fc2 shape: ', fc2.shape) - out_logit = self.linear(fc2, 1, scope = 'out') - out = tf.nn.sigmoid(out_logit) - return out, out_logit - - - def Conv3Dnet_v2(self, x, ndf): - """ - args: - input images: a input tensor with dimension (n_batch,forecast_length,height,width,channel) - output images: - """ - conv1 = Conv3D(ndf, 4, strides = (1, 2, 2), padding = 'same', kernel_initializer = 'he_normal')(x) - bn1 = BatchNormalization()(conv1) - bn1 = LeakyReLU(0.2)(bn1) - pool1 = MaxPooling3D(pool_size = (1, 2, 2), padding = 'same')(bn1) - noise1 = self.get_noise(pool1) - - conv2 = Conv3D(ndf * 2, 4, strides = (1, 2, 2), padding = 'same', kernel_initializer = 'he_normal')(noise1) - bn2 = BatchNormalization()(conv2) - bn2 = LeakyReLU(0.2)(bn2) - pool2 = MaxPooling3D(pool_size = (1, 2, 2), padding = 'same')(bn2) - noise2 = self.get_noise(pool2) - - conv3 = Conv3D(ndf * 4, 4, strides = (1, 2, 2), padding = 'same', kernel_initializer = 'he_normal')(noise2) - bn3 = BatchNormalization()(conv3) - bn3 = LeakyReLU(0.2)(bn3) - pool3 = MaxPooling3D(pool_size = (1, 2, 2), padding = 'same')(bn3) - - conv4 = Conv3D(1, 4, 1, padding = 'same')(pool3) - - fl = tf.reshape(conv4, [-1, 1]) - drop1 = Dropout(0.3)(fl) - fc1 = Dense(1024, activation = 'relu')(drop1) - drop2 = Dropout(0.3)(fc1) - fc2 = Dense(512, activation = 'relu')(drop2) - out_logit = Dense(1, activation = 'linear')(fc2) - out = tf.nn.sigmoid(out_logit) - return out, out_logit - - - @staticmethod - def lrelu(x, leak=0.2, name='lrelu'): - return tf.maximum(x, leak * x) - @staticmethod - def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): - shape = input_.get_shape().as_list() + def _lrelu(self, x, leak=0.2): + return tf.maximum(x, leak * x) + + def _linear(self, input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): + shape = input_.get_shape().as_list() with tf.variable_scope(scope or "Linear"): matrix = 
tf.get_variable("Matrix", [shape[1], output_size], tf.float32, tf.random_normal_initializer(stddev = stddev)) @@ -356,19 +230,5 @@ class ConvLstmGANVideoPredictionModel(BaseModels): else: return tf.matmul(input_, matrix) + bias -class batch_norm(object): - def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): - with tf.variable_scope(name): - self.epsilon = epsilon - self.momentum = momentum - self.name = name - - def __call__(self, x, train=True): - return tf.contrib.layers.batch_norm(x, - decay=self.momentum, - updates_collections=None, - epsilon=self.epsilon, - scale=True, - is_training=train, - scope=self.name) + diff --git a/video_prediction_tools/model_modules/video_prediction/models/dna_model.py b/video_prediction_tools/model_modules/video_prediction/models/dna_model.py deleted file mode 100644 index 8badf600f62c21d71cd81d8c2bfcde2f75e91d34..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/dna_model.py +++ /dev/null @@ -1,475 +0,0 @@ -# Copyright 2016 The TensorFlow Authors All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Model architecture for predictive model, including CDNA, DNA, and STP.""" - -import itertools - -import numpy as np -import tensorflow as tf -import tensorflow.contrib.slim as slim -from tensorflow.contrib.layers.python import layers as tf_layers -from model_modules.video_prediction.models import VideoPredictionModel -from .sna_model import basic_conv_lstm_cell - - -# Amount to use when lower bounding tensors -RELU_SHIFT = 1e-12 - - -def construct_model(images, - actions=None, - states=None, - iter_num=-1.0, - kernel_size=(5, 5), - k=-1, - use_state=True, - num_masks=10, - stp=False, - cdna=True, - dna=False, - context_frames=2, - pix_distributions=None): - """Build convolutional lstm video predictor using STP, CDNA, or DNA. - - Args: - images: tensor of ground truth image sequences - actions: tensor of action sequences - states: tensor of ground truth state sequences - iter_num: tensor of the current training iteration (for sched. sampling) - k: constant used for scheduled sampling. -1 to feed in own prediction. - use_state: True to include state and action in prediction - num_masks: the number of different pixel motion predictions (and - the number of masks for each of those predictions) - stp: True to use Spatial Transformer Predictor (STP) - cdna: True to use Convoluational Dynamic Neural Advection (CDNA) - dna: True to use Dynamic Neural Advection (DNA) - context_frames: number of ground truth frames to pass in before - feeding in own predictions - Returns: - gen_images: predicted future image frames - gen_states: predicted future states - - Raises: - ValueError: if more than one network option specified or more than 1 mask - specified for DNA model. 
- """ - DNA_KERN_SIZE = kernel_size[0] - - if stp + cdna + dna != 1: - raise ValueError('More than one, or no network option specified.') - batch_size, img_height, img_width, color_channels = images[0].get_shape()[0:4] - lstm_func = basic_conv_lstm_cell - - # Generated robot states and images. - gen_states, gen_images = [], [] - gen_pix_distrib = [] - gen_masks = [] - current_state = states[0] - - if k == -1: - feedself = True - else: - # Scheduled sampling: - # Calculate number of ground-truth frames to pass in. - num_ground_truth = tf.to_int32( - tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k))))) - feedself = False - - # LSTM state sizes and states. - lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32])) - lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None - lstm_state5, lstm_state6, lstm_state7 = None, None, None - - for t, action in enumerate(actions): - # Reuse variables after the first timestep. - reuse = bool(gen_images) - - done_warm_start = len(gen_images) > context_frames - 1 - with slim.arg_scope( - [lstm_func, slim.layers.conv2d, slim.layers.fully_connected, - tf_layers.layer_norm, slim.layers.conv2d_transpose], - reuse=reuse): - - if feedself and done_warm_start: - # Feed in generated image. - prev_image = gen_images[-1] - if pix_distributions is not None: - prev_pix_distrib = gen_pix_distrib[-1] - elif done_warm_start: - # Scheduled sampling - prev_image = scheduled_sample(images[t], gen_images[-1], batch_size, - num_ground_truth) - else: - # Always feed in ground_truth - prev_image = images[t] - if pix_distributions is not None: - prev_pix_distrib = pix_distributions[t] - # prev_pix_distrib = tf.expand_dims(prev_pix_distrib, -1) - - # Predicted state is always fed back in - state_action = tf.concat(axis=1, values=[action, current_state]) - - enc0 = slim.layers.conv2d( - prev_image, - 32, [5, 5], - stride=2, - scope='scale1_conv1', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'layer_norm1'}) - - hidden1, lstm_state1 = lstm_func( - enc0, lstm_state1, lstm_size[0], scope='state1') - hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2') - hidden2, lstm_state2 = lstm_func( - hidden1, lstm_state2, lstm_size[1], scope='state2') - hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3') - enc1 = slim.layers.conv2d( - hidden2, hidden2.get_shape()[3], [3, 3], stride=2, scope='conv2') - - hidden3, lstm_state3 = lstm_func( - enc1, lstm_state3, lstm_size[2], scope='state3') - hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4') - hidden4, lstm_state4 = lstm_func( - hidden3, lstm_state4, lstm_size[3], scope='state4') - hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5') - enc2 = slim.layers.conv2d( - hidden4, hidden4.get_shape()[3], [3, 3], stride=2, scope='conv3') - - # Pass in state and action. 
- smear = tf.reshape( - state_action, - [int(batch_size), 1, 1, int(state_action.get_shape()[1])]) - smear = tf.tile( - smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1]) - if use_state: - enc2 = tf.concat(axis=3, values=[enc2, smear]) - enc3 = slim.layers.conv2d( - enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4') - - hidden5, lstm_state5 = lstm_func( - enc3, lstm_state5, lstm_size[4], scope='state5') # last 8x8 - hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6') - enc4 = slim.layers.conv2d_transpose( - hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1') - - hidden6, lstm_state6 = lstm_func( - enc4, lstm_state6, lstm_size[5], scope='state6') # 16x16 - hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7') - # Skip connection. - hidden6 = tf.concat(axis=3, values=[hidden6, enc1]) # both 16x16 - - enc5 = slim.layers.conv2d_transpose( - hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2') - hidden7, lstm_state7 = lstm_func( - enc5, lstm_state7, lstm_size[6], scope='state7') # 32x32 - hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8') - - # Skip connection. - hidden7 = tf.concat(axis=3, values=[hidden7, enc0]) # both 32x32 - - enc6 = slim.layers.conv2d_transpose( - hidden7, - hidden7.get_shape()[3], 3, stride=2, scope='convt3', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'layer_norm9'}) - - if dna: - # Using largest hidden state for predicting untied conv kernels. - enc7 = slim.layers.conv2d_transpose( - enc6, DNA_KERN_SIZE ** 2, 1, stride=1, scope='convt4') - else: - # Using largest hidden state for predicting a new image layer. - enc7 = slim.layers.conv2d_transpose( - enc6, color_channels, 1, stride=1, scope='convt4') - # This allows the network to also generate one image from scratch, - # which is useful when regions of the image become unoccluded. - transformed = [tf.nn.sigmoid(enc7)] - - if stp: - stp_input0 = tf.reshape(hidden5, [int(batch_size), -1]) - stp_input1 = slim.layers.fully_connected( - stp_input0, 100, scope='fc_stp') - - # disabling capability to generete pixels - reuse_stp = None - if reuse: - reuse_stp = reuse - transformed = stp_transformation(prev_image, stp_input1, num_masks, reuse_stp) - # transformed += stp_transformation(prev_image, stp_input1, num_masks) - - if pix_distributions is not None: - transf_distrib = stp_transformation(prev_pix_distrib, stp_input1, num_masks, reuse=True) - - elif cdna: - cdna_input = tf.reshape(hidden5, [int(batch_size), -1]) - - new_transformed, cdna_kerns = cdna_transformation(prev_image, - cdna_input, - num_masks, - int(color_channels), - kernel_size, - reuse_sc=reuse) - transformed += new_transformed - - if pix_distributions is not None: - if not dna: - transf_distrib = [prev_pix_distrib] - new_transf_distrib, _ = cdna_transformation(prev_pix_distrib, - cdna_input, - num_masks, - prev_pix_distrib.shape[-1].value, - kernel_size, - reuse_sc=True) - transf_distrib += new_transf_distrib - - elif dna: - # Only one mask is supported (more should be unnecessary). 
- if num_masks != 1: - raise ValueError('Only one mask is supported for DNA model.') - transformed = [dna_transformation(prev_image, enc7, DNA_KERN_SIZE)] - - masks = slim.layers.conv2d_transpose( - enc6, num_masks + 1, 1, stride=1, scope='convt7') - masks = tf.reshape( - tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])), - [int(batch_size), int(img_height), int(img_width), num_masks + 1]) - mask_list = tf.split(masks, num_masks + 1, axis=3) - output = mask_list[0] * prev_image - for layer, mask in zip(transformed, mask_list[1:]): - output += layer * mask - gen_images.append(output) - gen_masks.append(mask_list) - - if dna and pix_distributions is not None: - transf_distrib = [dna_transformation(prev_pix_distrib, enc7, DNA_KERN_SIZE)] - - if pix_distributions is not None: - pix_distrib_output = mask_list[0] * prev_pix_distrib - for layer, mask in zip(transf_distrib, mask_list[1:]): - pix_distrib_output += layer * mask - pix_distrib_output /= tf.reduce_sum(pix_distrib_output, axis=(1, 2), keepdims=True) - gen_pix_distrib.append(pix_distrib_output) - - if int(current_state.get_shape()[1]) == 0: - current_state = tf.zeros_like(state_action) - else: - current_state = slim.layers.fully_connected( - state_action, - int(current_state.get_shape()[1]), - scope='state_pred', - activation_fn=None) - gen_states.append(current_state) - - return gen_images, gen_states, gen_masks, gen_pix_distrib - - -## Utility functions -def stp_transformation(prev_image, stp_input, num_masks): - """Apply spatial transformer predictor (STP) to previous image. - - Args: - prev_image: previous image to be transformed. - stp_input: hidden layer to be used for computing STN parameters. - num_masks: number of masks and hence the number of STP transformations. - Returns: - List of images transformed by the predicted STP parameters. - """ - # Only import spatial transformer if needed. - from spatial_transformer import transformer - - identity_params = tf.convert_to_tensor( - np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32)) - transformed = [] - for i in range(num_masks - 1): - params = slim.layers.fully_connected( - stp_input, 6, scope='stp_params' + str(i), - activation_fn=None) + identity_params - transformed.append(transformer(prev_image, params)) - - return transformed - - -def cdna_transformation(prev_image, cdna_input, num_masks, color_channels, kernel_size, reuse_sc=None): - """Apply convolutional dynamic neural advection to previous image. - - Args: - prev_image: previous image to be transformed. - cdna_input: hidden lyaer to be used for computing CDNA kernels. - num_masks: the number of masks and hence the number of CDNA transformations. - color_channels: the number of color channels in the images. - Returns: - List of images transformed by the predicted CDNA kernels. - """ - batch_size = int(cdna_input.get_shape()[0]) - height = int(prev_image.get_shape()[1]) - width = int(prev_image.get_shape()[2]) - - # Predict kernels using linear function of last hidden layer. - cdna_kerns = slim.layers.fully_connected( - cdna_input, - kernel_size[0] * kernel_size[1] * num_masks, - scope='cdna_params', - activation_fn=None, - reuse=reuse_sc) - - # Reshape and normalize. 
- cdna_kerns = tf.reshape( - cdna_kerns, [batch_size, kernel_size[0], kernel_size[1], 1, num_masks]) - cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT - norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True) - cdna_kerns /= norm_factor - - # Treat the color channel dimension as the batch dimension since the same - # transformation is applied to each color channel. - # Treat the batch dimension as the channel dimension so that - # depthwise_conv2d can apply a different transformation to each sample. - cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3]) - cdna_kerns = tf.reshape(cdna_kerns, [kernel_size[0], kernel_size[1], batch_size, num_masks]) - # Swap the batch and channel dimensions. - prev_image = tf.transpose(prev_image, [3, 1, 2, 0]) - - # Transform image. - transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME') - - # Transpose the dimensions to where they belong. - transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks]) - transformed = tf.transpose(transformed, [3, 1, 2, 0, 4]) - transformed = tf.unstack(transformed, axis=-1) - return transformed, cdna_kerns - - -def dna_transformation(prev_image, dna_input, kernel_size): - """Apply dynamic neural advection to previous image. - - Args: - prev_image: previous image to be transformed. - dna_input: hidden lyaer to be used for computing DNA transformation. - Returns: - List of images transformed by the predicted CDNA kernels. - """ - # Construct translated images. - pad_along_height = (kernel_size[0] - 1) - pad_along_width = (kernel_size[1] - 1) - pad_top = pad_along_height // 2 - pad_bottom = pad_along_height - pad_top - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - prev_image_pad = tf.pad(prev_image, [[0, 0], - [pad_top, pad_bottom], - [pad_left, pad_right], - [0, 0]]) - image_height = int(prev_image.get_shape()[1]) - image_width = int(prev_image.get_shape()[2]) - - inputs = [] - for xkern in range(kernel_size[0]): - for ykern in range(kernel_size[1]): - inputs.append( - tf.expand_dims( - tf.slice(prev_image_pad, [0, xkern, ykern, 0], - [-1, image_height, image_width, -1]), [3])) - inputs = tf.concat(axis=3, values=inputs) - - # Normalize channels to 1. - kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT - kernel = tf.expand_dims( - kernel / tf.reduce_sum( - kernel, [3], keepdims=True), [4]) - return tf.reduce_sum(kernel * inputs, [3], keepdims=False) - - -def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth): - """Sample batch with specified mix of ground truth and generated data points. - - Args: - ground_truth_x: tensor of ground-truth data points. - generated_x: tensor of generated data points. - batch_size: batch size - num_ground_truth: number of ground-truth examples to include in batch. - Returns: - New batch with num_ground_truth sampled from ground_truth_x and the rest - from generated_x. 
- """ - idx = tf.random_shuffle(tf.range(int(batch_size))) - ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth)) - generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size))) - - ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx) - generated_examps = tf.gather(generated_x, generated_idx) - return tf.dynamic_stitch([ground_truth_idx, generated_idx], - [ground_truth_examps, generated_examps]) - - -def generator_fn(inputs, hparams=None): - images = tf.unstack(inputs['images'], axis=0) - actions = tf.unstack(inputs['actions'], axis=0) - states = tf.unstack(inputs['states'], axis=0) - pix_distributions = tf.unstack(inputs['pix_distribs'], axis=0) if 'pix_distribs' in inputs else None - iter_num = tf.to_float(tf.train.get_or_create_global_step()) - - gen_images, gen_states, gen_masks, gen_pix_distrib = \ - construct_model(images, - actions, - states, - iter_num=iter_num, - kernel_size=hparams.kernel_size, - k=hparams.schedule_sampling_k, - num_masks=hparams.num_masks, - cdna=hparams.transformation == 'cdna', - dna=hparams.transformation == 'dna', - stp=hparams.transformation == 'stp', - context_frames=hparams.context_frames, - pix_distributions=pix_distributions) - outputs = { - 'gen_images': tf.stack(gen_images, axis=0), - 'gen_states': tf.stack(gen_states, axis=0), - 'masks': tf.stack([tf.stack(gen_mask_list, axis=-1) for gen_mask_list in gen_masks], axis=0), - } - if 'pix_distribs' in inputs: - outputs['gen_pix_distribs'] = tf.stack(gen_pix_distrib, axis=0) - gen_images = outputs['gen_images'][hparams.context_frames - 1:] - return gen_images, outputs - - -class DNAVideoPredictionModel(VideoPredictionModel): - def __init__(self, *args, **kwargs): - super(DNAVideoPredictionModel, self).__init__( - generator_fn, *args, **kwargs) - - def get_default_hparams_dict(self): - default_hparams = super(DNAVideoPredictionModel, self).get_default_hparams_dict() - hparams = dict( - batch_size=32, - l1_weight=0.0, - l2_weight=1.0, - transformation='cdna', - kernel_size=(9, 9), - num_masks=10, - schedule_sampling_k=900.0, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - def parse_hparams(self, hparams_dict, hparams): - hparams = super(DNAVideoPredictionModel, self).parse_hparams(hparams_dict, hparams) - if self.mode == 'test': - def override_hparams_maybe(name, value): - orig_value = hparams.values()[name] - if orig_value != value: - print('Overriding hparams from %s=%r to %r for mode=%s.' 
% - (name, orig_value, value, self.mode)) - hparams.set_hparam(name, value) - override_hparams_maybe('schedule_sampling_k', -1) - return hparams diff --git a/video_prediction_tools/model_modules/video_prediction/models/linear_regression_model.py b/video_prediction_tools/model_modules/video_prediction/models/linear_regression_model.py deleted file mode 100644 index 296ae16a1f70fcd5047481ab46cdd6e14abe1e83..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/linear_regression_model.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong" -__date__ = "2022-04-13" - -from .our_base_model import BaseModels -import tensorflow as tf - - -class VanillaConvLstmVideoPredictionModel(BaseModels): - - def __init__(self, hparams_dict=None, **kwargs): - """ - This is class for building convLSTM architecture by using updated hparameters - args: - hparams_dict : dict, the dictionary contains the hparaemters names and values - """ - super().__init__(hparams_dict) - self.get_hparams() - - def get_hparams(self): - """ - obtain the hparams from the dict to the class variables - """ - method = BaseModels.get_hparams.__name__ - - try: - self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.max_epochs = self.hparams.max_epochs - self.batch_size = self.hparams.batch_size - self.shuffle_on_val = self.hparams.shuffle_on_val - self.opt_var = self.hparams.opt_var - self.learning_rate = self.hparams.lr - - print("The model hparams have been parsed successfully! ") - except Exception as error: - print("Method %{}: error: {}".format(method, error)) - raise("Method %{}: the hparameter dictionary must include the params defined above!".format(method)) - - def build_graph(self, x: tf.Tensor): - - self.is_build_graph = False - self.inputs = x - self.global_step = tf.train.get_or_create_global_step() - original_global_variables = tf.global_variables() - - self.build_model() - - - # This is the loss function (MSE): - # Optimize all target variables/channels - if self.opt_var == "all": - x = self.inputs[:, self.context_frames:, :, :, :] - x_hat = self.x_hat_predict_frames[:, :, :, :, :] - print("The model is optimzied on all the variables in the loss function") - elif self.opt_var != "all" and isinstance(self.opt_var, str): - self.opt_var = int(self.opt_var) - print("The model is optimized on the {} variable in the loss function".format(self.opt_var)) - x = self.inputs[:, self.context_frames:, :, :, self.opt_var] - x_hat = self.x_hat_predict_frames[:, :, :, :, self.opt_var] - else: - raise ValueError( - "The opt var in the hyper-parameters setup should be '0','1','2' indicate the index of target variable to be optimised or 'all' indicating optimize all the variables") - - #loss function is mean squre error - self.total_loss = tf.reduce_mean(tf.square(x - x_hat)) - - self.train_op = tf.train.AdamOptimizer( - learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) - - self.outputs["gen_images"] = self.x_hat - - # Summary op - self.loss_summary = tf.summary.scalar("total_loss", self.total_loss) - self.summary_op = tf.summary.merge_all() - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - self.is_build_graph = True - return self.is_build_graph - - - - 
def build_model(self): - pass diff --git a/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py b/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py deleted file mode 100644 index 80a49160159840ec6219d0d5a488d676c7257d67..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py +++ /dev/null @@ -1,468 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong" -__date__ = "2020-08-22" - - -import itertools -import numpy as np -import tensorflow as tf - -from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames -from model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell -from model_modules.video_prediction.layers.mcnet_ops import * -from model_modules.video_prediction.utils.mcnet_utils import * -import os - -class McNetVideoPredictionModel(BaseVideoPredictionModel): - def __init__(self, mode='train', hparams_dict=None, - hparams=None, **kwargs): - super(McNetVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) - self.mode = mode - self.lr = self.hparams.lr - self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) - self.df_dim = self.hparams.df_dim - self.gf_dim = self.hparams.gf_dim - self.alpha = self.hparams.alpha - self.beta = self.hparams.beta - self.gen_images_enc = None - self.recon_loss = None - self.latent_loss = None - self.total_loss = None - - def get_default_hparams_dict(self): - """ - The keys of this dict define valid hyperparameters for instances of - this class. A class inheriting from this one should override this - method if it has a different set of hyperparameters. - - Returns: - A dict with the following hyperparameters. - - batch_size: batch size for training. - lr: learning rate. if decay steps is non-zero, this is the - learning rate for steps <= decay_step. - - - - - max_steps: number of training steps. - - - context_frames: the number of ground-truth frames to pass in at - start. Must be specified during instantiation. - sequence_length: the number of frames in the video sequence, - including the context frames, so this model predicts - `sequence_length - context_frames` future frames. Must be - specified during instantiation. 
- df_dim: specific parameters for mcnet - gf_dim: specific parameters for menet - alpha: specific parameters for mcnet - beta: specific paramters for mcnet - - """ - default_hparams = super(McNetVideoPredictionModel, self).get_default_hparams_dict() - hparams = dict( - batch_size=16, - lr=0.001, - max_epochs=350000, - context_frames = 10, - sequence_length = 20, - nz = 16, - gf_dim = 64, - df_dim = 64, - alpha = 1, - beta = 0.0 - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - def build_graph(self, x): - - self.x = x["images"] - self.x_shape = self.x.get_shape().as_list() - self.batch_size = self.x_shape[0] - self.image_size = [self.x_shape[2],self.x_shape[3]] - self.c_dim = self.x_shape[4] - self.diff_shape = [self.batch_size, self.context_frames-1, self.image_size[0], - self.image_size[1], self.c_dim] - self.xt_shape = [self.batch_size, self.image_size[0], self.image_size[1],self.c_dim] - self.is_train = True - - - #self.global_step = tf.Variable(0, name='global_step', trainable=False) - self.global_step = tf.train.get_or_create_global_step() - original_global_variables = tf.global_variables() - - # self.xt = tf.placeholder(tf.float32, self.xt_shape, name='xt') - self.xt = self.x[:, self.context_frames - 1, :, :, :] - - self.diff_in = tf.placeholder(tf.float32, self.diff_shape, name='diff_in') - diff_in_all = [] - for t in range(1, self.context_frames): - prev = self.x[:, t-1:t, :, :, :] - next = self.x[:, t:t+1, :, :, :] - #diff_in = tf.reshape(next - prev, [self.batch_size, 1, self.image_size[0], self.image_size[1], -1]) - print("prev:",prev) - print("next:",next) - diff_in = tf.subtract(next,prev) - print("diff_in:",diff_in) - diff_in_all.append(diff_in) - - self.diff_in = tf.concat(axis = 1, values = diff_in_all) - - cell = BasicConvLSTMCell([self.image_size[0] / 8, self.image_size[1] / 8], [3, 3], 256) - - pred = self.forward(self.diff_in, self.xt, cell) - - - self.G = tf.concat(axis=1, values=pred)#[batch_size,context_frames,image1,image2,channels] - print ("1:self.G:",self.G) - if self.is_train: - - true_sim = self.x[:, self.context_frames:, :, :, :] - - # Bing: the following make sure the channel is three dimension, if the channle is 3 then will be duplicated - if self.c_dim == 1: true_sim = tf.tile(true_sim, [1, 1, 1, 1, 3]) - - # Bing: the raw inputs shape is [batch_size, image_size[0],self.image_size[1], num_seq, channel]. 
tf.transpose will transpoe the shape into - # [batch size*num_seq, image_size0, image_size1, channels], for our era5 case, we do not need transpose - # true_sim = tf.reshape(tf.transpose(true_sim,[0,3,1,2,4]), - # [-1, self.image_size[0], - # self.image_size[1], 3]) - true_sim = tf.reshape(true_sim, [-1, self.image_size[0], self.image_size[1], 3]) - - - - - gen_sim = self.G - - #combine groud truth and predict frames - self.x_hat = tf.concat([self.x[:, :self.context_frames, :, :, :], self.G], 1) - print ("self.x_hat:",self.x_hat) - if self.c_dim == 1: gen_sim = tf.tile(gen_sim, [1, 1, 1, 1, 3]) - # gen_sim = tf.reshape(tf.transpose(gen_sim,[0,3,1,2,4]), - # [-1, self.image_size[0], - # self.image_size[1], 3]) - - gen_sim = tf.reshape(gen_sim, [-1, self.image_size[0], self.image_size[1], 3]) - - - binput = tf.reshape(tf.transpose(self.x[:, :self.context_frames, :, :, :], [0, 1, 2, 3, 4]), - [self.batch_size, self.image_size[0], - self.image_size[1], -1]) - - btarget = tf.reshape(tf.transpose(self.x[:, self.context_frames:, :, :, :], [0, 1, 2, 3, 4]), - [self.batch_size, self.image_size[0], - self.image_size[1], -1]) - bgen = tf.reshape(self.G, [self.batch_size, - self.image_size[0], - self.image_size[1], -1]) - - print ("binput:",binput) - print("btarget:",btarget) - print("bgen:",bgen) - - good_data = tf.concat(axis=3, values=[binput, btarget]) - gen_data = tf.concat(axis=3, values=[binput, bgen]) - self.gen_data = gen_data - print ("2:self.gen_data:", self.gen_data) - with tf.variable_scope("DIS", reuse=False): - self.D, self.D_logits = self.discriminator(good_data) - - with tf.variable_scope("DIS", reuse=True): - self.D_, self.D_logits_ = self.discriminator(gen_data) - - self.L_p = tf.reduce_mean( - tf.square(self.G - self.x[:, self.context_frames:, :, :, :])) - - self.L_gdl = gdl(gen_sim, true_sim, 1.) 
- self.L_img = self.L_p + self.L_gdl - - self.d_loss_real = tf.reduce_mean( - tf.nn.sigmoid_cross_entropy_with_logits( - logits = self.D_logits, labels = tf.ones_like(self.D) - )) - self.d_loss_fake = tf.reduce_mean( - tf.nn.sigmoid_cross_entropy_with_logits( - logits = self.D_logits_, labels = tf.zeros_like(self.D_) - )) - self.d_loss = self.d_loss_real + self.d_loss_fake - self.L_GAN = tf.reduce_mean( - tf.nn.sigmoid_cross_entropy_with_logits( - logits = self.D_logits_, labels = tf.ones_like(self.D_) - )) - - self.loss_sum = tf.summary.scalar("L_img", self.L_img) - self.L_p_sum = tf.summary.scalar("L_p", self.L_p) - self.L_gdl_sum = tf.summary.scalar("L_gdl", self.L_gdl) - self.L_GAN_sum = tf.summary.scalar("L_GAN", self.L_GAN) - self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss) - self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real) - self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake) - - self.total_loss = self.alpha * self.L_img + self.beta * self.L_GAN - self._loss_sum = tf.summary.scalar("total_loss", self.total_loss) - self.g_sum = tf.summary.merge([self.L_p_sum, - self.L_gdl_sum, self.loss_sum, - self.L_GAN_sum]) - self.d_sum = tf.summary.merge([self.d_loss_real_sum, self.d_loss_sum, - self.d_loss_fake_sum]) - - - self.t_vars = tf.trainable_variables() - self.g_vars = [var for var in self.t_vars if 'DIS' not in var.name] - self.d_vars = [var for var in self.t_vars if 'DIS' in var.name] - num_param = 0.0 - for var in self.g_vars: - num_param += int(np.prod(var.get_shape())); - print("Number of parameters: %d" % num_param) - - # Training - self.d_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5).minimize( - self.d_loss, var_list = self.d_vars) - self.g_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5).minimize( - self.alpha * self.L_img + self.beta * self.L_GAN, var_list = self.g_vars, global_step=self.global_step) - - self.train_op = [self.d_optim,self.g_optim] - self.outputs = {} - self.outputs["gen_images"] = self.x_hat - - - self.summary_op = tf.summary.merge_all() - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - return - - - def forward(self, diff_in, xt, cell): - # Initial state - state = tf.zeros([self.batch_size, self.image_size[0] / 8, - self.image_size[1] / 8, 512]) - reuse = False - # Encoder - for t in range(self.context_frames - 1): - enc_h, res_m = self.motion_enc(diff_in[:, t, :, :, :], reuse = reuse) - h_dyn, state = cell(enc_h, state, scope = 'lstm', reuse = reuse) - reuse = True - pred = [] - # Decoder - for t in range(self.predict_frames): - if t == 0: - h_cont, res_c = self.content_enc(xt, reuse = False) - h_tp1 = self.comb_layers(h_dyn, h_cont, reuse = False) - res_connect = self.residual(res_m, res_c, reuse = False) - x_hat = self.dec_cnn(h_tp1, res_connect, reuse = False) - - else: - - enc_h, res_m = self.motion_enc(diff_in, reuse = True) - h_dyn, state = cell(enc_h, state, scope = 'lstm', reuse = True) - h_cont, res_c = self.content_enc(xt, reuse = reuse) - h_tp1 = self.comb_layers(h_dyn, h_cont, reuse = True) - res_connect = self.residual(res_m, res_c, reuse = True) - x_hat = self.dec_cnn(h_tp1, res_connect, reuse = True) - print ("x_hat :",x_hat) - if self.c_dim == 3: - # Network outputs are BGR so they need to be reversed to use - # rgb_to_grayscale - #x_hat_gray = tf.concat(axis=3,values=[x_hat[:,:,:,2:3], x_hat[:,:,:,1:2],x_hat[:,:,:,0:1]]) - #xt_gray = 
tf.concat(axis=3,values=[xt[:,:,:,2:3], xt[:,:,:,1:2],xt[:,:,:,0:1]]) - - # x_hat_gray = 1./255.*tf.image.rgb_to_grayscale( - # inverse_transform(x_hat_rgb)*255. - # ) - # xt_gray = 1./255.*tf.image.rgb_to_grayscale( - # inverse_transform(xt_rgb)*255. - # ) - - x_hat_gray = x_hat - xt_gray = xt - else: - x_hat_gray = inverse_transform(x_hat) - xt_gray = inverse_transform(xt) - - diff_in = x_hat_gray - xt_gray - xt = x_hat - - - pred.append(tf.reshape(x_hat, [self.batch_size, 1, self.image_size[0], - self.image_size[1], self.c_dim])) - - return pred - - def motion_enc(self, diff_in, reuse): - res_in = [] - - conv1 = relu(conv2d(diff_in, output_dim = self.gf_dim, k_h = 5, k_w = 5, - d_h = 1, d_w = 1, name = 'dyn1_conv1', reuse = reuse)) - res_in.append(conv1) - pool1 = MaxPooling(conv1, [2, 2]) - - conv2 = relu(conv2d(pool1, output_dim = self.gf_dim * 2, k_h = 5, k_w = 5, - d_h = 1, d_w = 1, name = 'dyn_conv2', reuse = reuse)) - res_in.append(conv2) - pool2 = MaxPooling(conv2, [2, 2]) - - conv3 = relu(conv2d(pool2, output_dim = self.gf_dim * 4, k_h = 7, k_w = 7, - d_h = 1, d_w = 1, name = 'dyn_conv3', reuse = reuse)) - res_in.append(conv3) - pool3 = MaxPooling(conv3, [2, 2]) - return pool3, res_in - - def content_enc(self, xt, reuse): - res_in = [] - conv1_1 = relu(conv2d(xt, output_dim = self.gf_dim, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv1_1', reuse = reuse)) - conv1_2 = relu(conv2d(conv1_1, output_dim = self.gf_dim, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv1_2', reuse = reuse)) - res_in.append(conv1_2) - pool1 = MaxPooling(conv1_2, [2, 2]) - - conv2_1 = relu(conv2d(pool1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv2_1', reuse = reuse)) - conv2_2 = relu(conv2d(conv2_1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv2_2', reuse = reuse)) - res_in.append(conv2_2) - pool2 = MaxPooling(conv2_2, [2, 2]) - - conv3_1 = relu(conv2d(pool2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv3_1', reuse = reuse)) - conv3_2 = relu(conv2d(conv3_1, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv3_2', reuse = reuse)) - conv3_3 = relu(conv2d(conv3_2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'cont_conv3_3', reuse = reuse)) - res_in.append(conv3_3) - pool3 = MaxPooling(conv3_3, [2, 2]) - return pool3, res_in - - def comb_layers(self, h_dyn, h_cont, reuse=False): - comb1 = relu(conv2d(tf.concat(axis = 3, values = [h_dyn, h_cont]), - output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'comb1', reuse = reuse)) - comb2 = relu(conv2d(comb1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'comb2', reuse = reuse)) - h_comb = relu(conv2d(comb2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'h_comb', reuse = reuse)) - return h_comb - - def residual(self, input_dyn, input_cont, reuse=False): - n_layers = len(input_dyn) - res_out = [] - for l in range(n_layers): - input_ = tf.concat(axis = 3, values = [input_dyn[l], input_cont[l]]) - out_dim = input_cont[l].get_shape()[3] - res1 = relu(conv2d(input_, output_dim = out_dim, - k_h = 3, k_w = 3, d_h = 1, d_w = 1, - name = 'res' + str(l) + '_1', reuse = reuse)) - res2 = conv2d(res1, output_dim = out_dim, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'res' + str(l) + '_2', reuse = reuse) - res_out.append(res2) - return res_out - - def dec_cnn(self, h_comb, res_connect, reuse=False): - 
- shapel3 = [self.batch_size, int(self.image_size[0] / 4), - int(self.image_size[1] / 4), self.gf_dim * 4] - shapeout3 = [self.batch_size, int(self.image_size[0] / 4), - int(self.image_size[1] / 4), self.gf_dim * 2] - depool3 = FixedUnPooling(h_comb, [2, 2]) - deconv3_3 = relu(deconv2d(relu(tf.add(depool3, res_connect[2])), - output_shape = shapel3, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv3_3', reuse = reuse)) - deconv3_2 = relu(deconv2d(deconv3_3, output_shape = shapel3, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv3_2', reuse = reuse)) - deconv3_1 = relu(deconv2d(deconv3_2, output_shape = shapeout3, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv3_1', reuse = reuse)) - - shapel2 = [self.batch_size, int(self.image_size[0] / 2), - int(self.image_size[1] / 2), self.gf_dim * 2] - shapeout3 = [self.batch_size, int(self.image_size[0] / 2), - int(self.image_size[1] / 2), self.gf_dim] - depool2 = FixedUnPooling(deconv3_1, [2, 2]) - deconv2_2 = relu(deconv2d(relu(tf.add(depool2, res_connect[1])), - output_shape = shapel2, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv2_2', reuse = reuse)) - deconv2_1 = relu(deconv2d(deconv2_2, output_shape = shapeout3, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv2_1', reuse = reuse)) - - shapel1 = [self.batch_size, self.image_size[0], - self.image_size[1], self.gf_dim] - shapeout1 = [self.batch_size, self.image_size[0], - self.image_size[1], self.c_dim] - depool1 = FixedUnPooling(deconv2_1, [2, 2]) - deconv1_2 = relu(deconv2d(relu(tf.add(depool1, res_connect[0])), - output_shape = shapel1, k_h = 3, k_w = 3, d_h = 1, d_w = 1, - name = 'dec_deconv1_2', reuse = reuse)) - xtp1 = tanh(deconv2d(deconv1_2, output_shape = shapeout1, k_h = 3, k_w = 3, - d_h = 1, d_w = 1, name = 'dec_deconv1_1', reuse = reuse)) - return xtp1 - - def discriminator(self, image): - h0 = lrelu(conv2d(image, self.df_dim, name = 'dis_h0_conv')) - h1 = lrelu(batch_norm(conv2d(h0, self.df_dim * 2, name = 'dis_h1_conv'), - "bn1")) - h2 = lrelu(batch_norm(conv2d(h1, self.df_dim * 4, name = 'dis_h2_conv'), - "bn2")) - h3 = lrelu(batch_norm(conv2d(h2, self.df_dim * 8, name = 'dis_h3_conv'), - "bn3")) - h = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'dis_h3_lin') - - return tf.nn.sigmoid(h), h - - def save(self, sess, checkpoint_dir, step): - model_name = "MCNET.model" - - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir) - - self.saver.save(sess, - os.path.join(checkpoint_dir, model_name), - global_step = step) - - def load(self, sess, checkpoint_dir, model_name=None): - print(" [*] Reading checkpoints...") - ckpt = tf.train.get_checkpoint_state(checkpoint_dir) - if ckpt and ckpt.model_checkpoint_path: - ckpt_name = os.path.basename(ckpt.model_checkpoint_path) - if model_name is None: model_name = ckpt_name - self.saver.restore(sess, os.path.join(checkpoint_dir, model_name)) - print(" Loaded model: " + str(model_name)) - return True, model_name - else: - return False, None - - # Execute the forward and the backward pass - - def run_single_step(self, global_step): - print("global_step:", global_step) - try: - train_batch = self.sess.run(self.train_iterator.get_next()) - # z=np.random.uniform(-1,1,size=(self.batch_size,self.nz)) - x = self.sess.run([self.x], feed_dict = {self.x: train_batch["images"]}) - _, g_sum = self.sess.run([self.g_optim, self.g_sum], feed_dict = {self.x: train_batch["images"]}) - _, d_sum = self.sess.run([self.d_optim, self.d_sum], feed_dict = {self.x: train_batch["images"]}) - - gen_data, 
train_loss = self.sess.run([self.gen_data, self.total_loss], - feed_dict = {self.x: train_batch["images"]}) - - except tf.errors.OutOfRangeError: - print("train out of range error") - - try: - val_batch = self.sess.run(self.val_iterator.get_next()) - val_loss = self.sess.run([self.total_loss], feed_dict = {self.x: val_batch["images"]}) - # self.val_writer.add_summary(val_summary, global_step) - except tf.errors.OutOfRangeError: - print("train out of range error") - - return train_loss, val_total_loss - - - diff --git a/video_prediction_tools/model_modules/video_prediction/models/non_trainable_model.py b/video_prediction_tools/model_modules/video_prediction/models/non_trainable_model.py deleted file mode 100644 index aba90339317dafcf114442ca37cb62338e32d8cd..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/non_trainable_model.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -from tensorflow.python.util import nest -from model_modules.video_prediction.utils.tf_utils import transpose_batch_time - -import tensorflow as tf - -from .base_model import BaseVideoPredictionModel - - -class NonTrainableVideoPredictionModel(BaseVideoPredictionModel): - pass - - -class GroundTruthVideoPredictionModel(NonTrainableVideoPredictionModel): - def build_graph(self, inputs): - super(GroundTruthVideoPredictionModel, self).build_graph(inputs) - - self.outputs = OrderedDict() - self.outputs['gen_images'] = self.inputs['images'][:, 1:] - if 'pix_distribs' in self.inputs: - self.outputs['gen_pix_distribs'] = self.inputs['pix_distribs'][:, 1:] - - inputs, outputs = nest.map_structure(transpose_batch_time, (self.inputs, self.outputs)) - with tf.name_scope("metrics"): - metrics = self.metrics_fn(inputs, outputs) - with tf.name_scope("eval_outputs_and_metrics"): - eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs) - self.metrics, self.eval_outputs, self.eval_metrics = nest.map_structure( - transpose_batch_time, (metrics, eval_outputs, eval_metrics)) - - -class RepeatVideoPredictionModel(NonTrainableVideoPredictionModel): - def build_graph(self, inputs): - super(RepeatVideoPredictionModel, self).build_graph(inputs) - - self.outputs = OrderedDict() - tile_pattern = [1, self.hparams.sequence_length - self.hparams.context_frames, 1, 1, 1] - last_context_images = self.inputs['images'][:, self.hparams.context_frames - 1] - self.outputs['gen_images'] = tf.concat([ - self.inputs['images'][:, 1:self.hparams.context_frames - 1], - tf.tile(last_context_images[:, None], tile_pattern)], axis=-1) - if 'pix_distribs' in self.inputs: - last_context_pix_distrib = self.inputs['pix_distribs'][:, self.hparams.context_frames - 1] - self.outputs['gen_pix_distribs'] = tf.concat([ - self.inputs['pix_distribs'][:, 1:self.hparams.context_frames - 1], - tf.tile(last_context_pix_distrib[:, None], tile_pattern)], axis=-1) - - inputs, outputs = nest.map_structure(transpose_batch_time, (self.inputs, self.outputs)) - with tf.name_scope("metrics"): - metrics = self.metrics_fn(inputs, outputs) - with tf.name_scope("eval_outputs_and_metrics"): - eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs) - self.metrics, self.eval_outputs, self.eval_metrics = nest.map_structure( - transpose_batch_time, (metrics, eval_outputs, eval_metrics)) diff --git a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py 
b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py index 3219bda2de305cf36ad178ebcc192ac9a5a37b78..589a6b32b7fe9c80646e53310d54272f696cb88f 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py @@ -14,69 +14,146 @@ import tensorflow as tf class BaseModels(ABC): - def __init__(self, hparams_dict_config=None): - self.hparams_dict_config = hparams_dict_config - self.hparams_dict = self.get_model_hparams_dict() - self.hparams = self.parse_hparams() - # Attributes set during runtime - self.total_loss = None - self.loss_summary = None + def __init__(self, hparams_dict_config=None, mode="train"): + _modes = ["train","val"] + + if mode not in _modes: + raise ValueError("The mode must be 'train' or 'val'") + else: + self.mode = mode + + self.__model = None + self.hparams = self.hparams_options(hparams_dict_config) + self.parse_hparams(self.hparams) + + # Compile options, must be customized in the sub-class + self.inputs = None + self.train_op = None self.total_loss = None self.outputs = {} - self.train_op = None + self.loss_summary = None self.summary_op = None - self.inputs = None - self.global_step = None + self.global_step = tf.train.get_or_create_global_step() self.saveable_variables = None - self.is_build_graph = None - self.x_hat = None - self.x_hat_predict_frames = None + self._is_build_graph_set = False - - def get_model_hparams_dict(self): - """ - Get model_hparams_dict from json file - """ - if self.hparams_dict_config: - with open(self.hparams_dict_config, 'r') as f: + + def hparams_options(self, hparams_dict_config:str): + if hparams_dict_config: + with open(hparams_dict_config, 'r') as f: hparams_dict = json.loads(f.read()) else: - raise FileNotFoundError("hyper-parameter directory doesn't exist! please check {}!".format(self.hparams_dict_config)) + raise FileNotFoundError("hyper-parameter file doesn't exist! Please check {}!".format(hparams_dict_config)) + return dotdict(hparams_dict) - return hparams_dict - def parse_hparams(self): + @abstractmethod + def parse_hparams(self, hparams)->None: """ - Obtain the parameters from directory + Parse the hyper-parameters into class attributes. + Examples: + ... code-block:: python + def parse_hparams(self, hparams): + try: + self.context_frames = hparams.context_frames + self.max_epochs = hparams.max_epochs + self.batch_size = hparams.batch_size + self.shuffle_on_val = hparams.shuffle_on_val + self.loss_fun = hparams.loss_fun + + except Exception as e: + raise ValueError(f"missing hyperparameter: {e.args[0]}") """ + pass - hparams = dotdict(self.hparams_dict) - return hparams - - @abstractmethod + @property def get_hparams(self): + return self.hparams + + + def build_graph(self, x: tf.Tensor)->bool: """ - obtain the hparams from the dict to the class variables + Build the graph and attach the optimizer to it using TensorFlow operations. + + Example: + ... 
code-block:: python + def build_graph(self, x): + original_global_variables = tf.global_variables() + x_hat = self.build_model(x) + self.train_loss = self.get_loss(x,x_hat) + self.train_op = self.optimizer(self.train_loss) + self.outputs["gen_images"] = x_hat + self.summary_op = self.summary() #This is optional + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self._is_build_graph_set=True + return self._is_build_graph_set + """ - method = BaseModels.get_hparams.__name__ + self.inputs = x + original_global_variables = tf.global_variables() + x_hat = self.build_model(x) + self.total_loss = self.get_loss(x, x_hat) + self.train_op = self.optimizer(self.total_loss) + self.outputs["gen_images"] = x_hat + self.summary_op = self.summary(total_loss = self.total_loss) + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self._is_build_graph_set = True + return self._is_build_graph_set + + + def optimizer(self, total_loss): + """ + Define the optimizer + Example: + ... code-block:: python + def optimizer(self, total_loss): + train_op = tf.train.AdamOptimizer( + learning_rate = self.lr).minimize(total_loss, global_step = self.global_step) + return train_op + """ + train_op = tf.train.AdamOptimizer( + learning_rate = self.lr).minimize(total_loss, global_step = self.global_step) + return train_op + - try: - self.context_frames = self.hparams.context_frames - self.max_epochs = self.hparams.max_epochs - self.batch_size = self.hparams.batch_size - self.shuffle_on_val = self.hparams.shuffle_on_val - self.loss_fun = self.hparams.loss_fun - except Exception as error: - print("Method %{}: error: {}".format(method,error)) - raise("Method %{}: the hparameter dictionary must include " - "'context_frames','max_epochs','batch_size','shuffle_on_val' 'loss_fun'".format(method)) @abstractmethod - def build_graph(self, x: tf.Tensor): + def get_loss(self, x:tf.Tensor, x_hat:tf.Tensor)->tf.Tensor: + """ + :param x : Input tensors + :param x_hat: Prediction/output tensors + :return : the loss tensor + """ pass + def summary(self, **kwargs): + """ + Return the summary operation that can be used for TensorBoard + """ + for key, value in kwargs.items(): + tf.summary.scalar(key, value) + summary_op = tf.summary.merge_all() + return summary_op + + + @abstractmethod - def build_model(self): + def build_model(self, x)->tf.Tensor: + """ + Create the network. + Example: see vanilla_convLSTM_model.py; it must return the predicted frames and save them to self.outputs, + which is used for calculating the loss + """ pass + + + + + + + + diff --git a/video_prediction_tools/model_modules/video_prediction/models/savp_model.py b/video_prediction_tools/model_modules/video_prediction/models/savp_model.py deleted file mode 100644 index 17d72563e5dffac8bc1ff1b863d083d603d0a69b..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/savp_model.py +++ /dev/null @@ -1,996 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import collections -import functools -import itertools -from collections import OrderedDict -import numpy as np -import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops -from model_modules.video_prediction.models 
import VideoPredictionModel -from model_modules.video_prediction.models import networks -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils - -# Amount to use when lower bounding tensors -RELU_SHIFT = 1e-12 - - -def posterior_fn(inputs, hparams): - images = inputs['images'] - image_pairs = tf.concat([images[:-1], images[1:]], axis=-1) - if 'actions' in inputs: - image_pairs = tile_concat( - [image_pairs, inputs['actions'][..., None, None, :]], axis=-1) - - h = tf_utils.with_flat_batch(networks.encoder)( - image_pairs, nef=hparams.nef, n_layers=hparams.n_layers, norm_layer=hparams.norm_layer) - - if hparams.use_e_rnn: - with tf.variable_scope('layer_%d' % (hparams.n_layers + 1)): - h = tf_utils.with_flat_batch(dense, 2)(h, hparams.nef * 4) - - if hparams.rnn == 'lstm': - RNNCell = tf.contrib.rnn.BasicLSTMCell - elif hparams.rnn == 'gru': - RNNCell = tf.contrib.rnn.GRUCell - else: - raise NotImplementedError - with tf.variable_scope('%s' % hparams.rnn): - rnn_cell = RNNCell(hparams.nef * 4) - h, _ = tf_utils.unroll_rnn(rnn_cell, h) - - with tf.variable_scope('z_mu'): - z_mu = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) - with tf.variable_scope('z_log_sigma_sq'): - z_log_sigma_sq = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) - z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10) - outputs = {'zs_mu': z_mu, 'zs_log_sigma_sq': z_log_sigma_sq} - return outputs - - -def prior_fn(inputs, hparams): - images = inputs['images'] - image_pairs = tf.concat([images[:hparams.context_frames - 1], images[1:hparams.context_frames]], axis=-1) - if 'actions' in inputs: - image_pairs = tile_concat( - [image_pairs, inputs['actions'][..., None, None, :]], axis=-1) - - h = tf_utils.with_flat_batch(networks.encoder)( - image_pairs, nef=hparams.nef, n_layers=hparams.n_layers, norm_layer=hparams.norm_layer) - h_zeros = tf.zeros(tf.concat([[hparams.sequence_length - hparams.context_frames], tf.shape(h)[1:]], axis=0)) - h = tf.concat([h, h_zeros], axis=0) - - with tf.variable_scope('layer_%d' % (hparams.n_layers + 1)): - h = tf_utils.with_flat_batch(dense, 2)(h, hparams.nef * 4) - - if hparams.rnn == 'lstm': - RNNCell = tf.contrib.rnn.BasicLSTMCell - elif hparams.rnn == 'gru': - RNNCell = tf.contrib.rnn.GRUCell - else: - raise NotImplementedError - with tf.variable_scope('%s' % hparams.rnn): - rnn_cell = RNNCell(hparams.nef * 4) - h, _ = tf_utils.unroll_rnn(rnn_cell, h) - - with tf.variable_scope('z_mu'): - z_mu = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) - with tf.variable_scope('z_log_sigma_sq'): - z_log_sigma_sq = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) - z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10) - outputs = {'zs_mu': z_mu, 'zs_log_sigma_sq': z_log_sigma_sq} - return outputs - - -def discriminator_given_video_fn(targets, hparams): - sequence_length, batch_size = targets.shape.as_list()[:2] - clip_length = hparams.clip_length - - # sample an image and apply the image distriminator on that frame - t_sample = tf.random_uniform([batch_size], minval=0, maxval=sequence_length, dtype=tf.int32) - image_sample = tf.gather_nd(targets, tf.stack([t_sample, tf.range(batch_size)], axis=1)) - - # sample a subsequence of length clip_length and apply the images/video discriminators on those frames - t_start = tf.random_uniform([batch_size], minval=0, maxval=sequence_length - clip_length + 1, 
dtype=tf.int32) - t_start_indices = tf.stack([t_start, tf.range(batch_size)], axis=1) - t_offset_indices = tf.stack([tf.range(clip_length), tf.zeros(clip_length, dtype=tf.int32)], axis=1) - indices = t_start_indices[None] + t_offset_indices[:, None] - clip_sample = tf.gather_nd(targets, flatten(indices, 0, 1)) - clip_sample = tf.reshape(clip_sample, [clip_length] + targets.shape.as_list()[1:]) - - outputs = {} - if hparams.image_sn_gan_weight or hparams.image_sn_vae_gan_weight: - with tf.variable_scope('image'): - image_features = networks.image_sn_discriminator(image_sample, ndf=hparams.ndf) - image_features, image_logits = image_features[:-1], image_features[-1] - outputs['discrim_image_sn_logits'] = image_logits - for i, image_feature in enumerate(image_features): - outputs['discrim_image_sn_feature%d' % i] = image_feature - if hparams.video_sn_gan_weight or hparams.video_sn_vae_gan_weight: - with tf.variable_scope('video'): - video_features = networks.video_sn_discriminator(clip_sample, ndf=hparams.ndf) - video_features, video_logits = video_features[:-1], video_features[-1] - outputs['discrim_video_sn_logits'] = video_logits - for i, video_feature in enumerate(video_features): - outputs['discrim_video_sn_feature%d' % i] = video_feature - if hparams.images_sn_gan_weight or hparams.images_sn_vae_gan_weight: - with tf.variable_scope('images'): - images_features = tf_utils.with_flat_batch(networks.image_sn_discriminator)(clip_sample, ndf=hparams.ndf) - images_features, images_logits = images_features[:-1], images_features[-1] - outputs['discrim_images_sn_logits'] = images_logits - for i, images_feature in enumerate(images_features): - outputs['discrim_images_sn_feature%d' % i] = images_feature - return outputs - - -def discriminator_fn(inputs, outputs, mode, hparams): - # do the encoder version first so that it isn't affected by the reuse_variables() call - if hparams.nz == 0: - discrim_outputs_enc_real = collections.OrderedDict() - discrim_outputs_enc_fake = collections.OrderedDict() - else: - images_enc_real = inputs['images'][1:] - images_enc_fake = outputs['gen_images_enc'] - if hparams.use_same_discriminator: - with tf.name_scope("real"): - discrim_outputs_enc_real = discriminator_given_video_fn(images_enc_real, hparams) - tf.get_variable_scope().reuse_variables() - with tf.name_scope("fake"): - discrim_outputs_enc_fake = discriminator_given_video_fn(images_enc_fake, hparams) - else: - with tf.variable_scope('encoder'), tf.name_scope("real"): - discrim_outputs_enc_real = discriminator_given_video_fn(images_enc_real, hparams) - with tf.variable_scope('encoder', reuse=True), tf.name_scope("fake"): - discrim_outputs_enc_fake = discriminator_given_video_fn(images_enc_fake, hparams) - - images_real = inputs['images'][1:] - images_fake = outputs['gen_images'] - with tf.name_scope("real"): - discrim_outputs_real = discriminator_given_video_fn(images_real, hparams) - tf.get_variable_scope().reuse_variables() - with tf.name_scope("fake"): - discrim_outputs_fake = discriminator_given_video_fn(images_fake, hparams) - - discrim_outputs_real = OrderedDict([(k + '_real', v) for k, v in discrim_outputs_real.items()]) - discrim_outputs_fake = OrderedDict([(k + '_fake', v) for k, v in discrim_outputs_fake.items()]) - discrim_outputs_enc_real = OrderedDict([(k + '_enc_real', v) for k, v in discrim_outputs_enc_real.items()]) - discrim_outputs_enc_fake = OrderedDict([(k + '_enc_fake', v) for k, v in discrim_outputs_enc_fake.items()]) - outputs = [discrim_outputs_real, discrim_outputs_fake, - 
discrim_outputs_enc_real, discrim_outputs_enc_fake] - total_num_outputs = sum([len(output) for output in outputs]) - outputs = collections.OrderedDict(itertools.chain(*[output.items() for output in outputs])) - assert len(outputs) == total_num_outputs # ensure no output is lost because of repeated keys - return outputs - - -class SAVPCell(tf.nn.rnn_cell.RNNCell): - def __init__(self, inputs, mode, hparams, reuse=None): - super(SAVPCell, self).__init__(_reuse=reuse) - self.inputs = inputs - self.mode = mode - self.hparams = hparams - - if self.hparams.where_add not in ('input', 'all', 'middle'): - raise ValueError('Invalid where_add %s' % self.hparams.where_add) - - batch_size = inputs['images'].shape[1].value - image_shape = inputs['images'].shape.as_list()[2:] - height, width, _ = image_shape - scale_size = min(height, width) - if scale_size >= 256: - self.encoder_layer_specs = [ - (self.hparams.ngf, False), - (self.hparams.ngf * 2, False), - (self.hparams.ngf * 4, True), - (self.hparams.ngf * 8, True), - (self.hparams.ngf * 8, True), - ] - self.decoder_layer_specs = [ - (self.hparams.ngf * 8, True), - (self.hparams.ngf * 4, True), - (self.hparams.ngf * 2, False), - (self.hparams.ngf, False), - (self.hparams.ngf, False), - ] - elif scale_size >= 128: - self.encoder_layer_specs = [ - (self.hparams.ngf, False), - (self.hparams.ngf * 2, True), - (self.hparams.ngf * 4, True), - (self.hparams.ngf * 8, True), - ] - self.decoder_layer_specs = [ - (self.hparams.ngf * 8, True), - (self.hparams.ngf * 4, True), - (self.hparams.ngf * 2, False), - (self.hparams.ngf, False), - ] - elif scale_size >= 64: - self.encoder_layer_specs = [ - (self.hparams.ngf, True), - (self.hparams.ngf * 2, True), - (self.hparams.ngf * 4, True), - ] - self.decoder_layer_specs = [ - (self.hparams.ngf * 2, True), - (self.hparams.ngf, True), - (self.hparams.ngf, False), - ] - elif scale_size >= 32: - self.encoder_layer_specs = [ - (self.hparams.ngf, True), - (self.hparams.ngf * 2, True), - ] - self.decoder_layer_specs = [ - (self.hparams.ngf, True), - (self.hparams.ngf, False), - ] - else: - print("The minimum of image size is 32") - raise NotImplementedError - assert len(self.encoder_layer_specs) == len(self.decoder_layer_specs) - total_stride = 2 ** len(self.encoder_layer_specs) - if (height % total_stride) or (width % total_stride): - raise ValueError("The image has dimension (%d, %d), but it should be divisible " - "by the total stride, which is %d." 
% (height, width, total_stride)) - - # output_size - num_masks = self.hparams.last_frames * self.hparams.num_transformed_images + \ - int(bool(self.hparams.prev_image_background)) + \ - int(bool(self.hparams.first_image_background and not self.hparams.context_images_background)) + \ - int(bool(self.hparams.last_image_background and not self.hparams.context_images_background)) + \ - int(bool(self.hparams.last_context_image_background and not self.hparams.context_images_background)) + \ - (self.hparams.context_frames if self.hparams.context_images_background else 0) + \ - int(bool(self.hparams.generate_scratch_image)) - output_size = { - 'gen_images': tf.TensorShape(image_shape), - 'transformed_images': tf.TensorShape(image_shape + [num_masks]), - 'masks': tf.TensorShape([height, width, 1, num_masks]), - } - if 'pix_distribs' in inputs: - num_motions = inputs['pix_distribs'].shape[-1].value - output_size['gen_pix_distribs'] = tf.TensorShape([height, width, num_motions]) - output_size['transformed_pix_distribs'] = tf.TensorShape([height, width, num_motions, num_masks]) - if 'states' in inputs: - output_size['gen_states'] = inputs['states'].shape[2:] - if self.hparams.transformation == 'flow': - output_size['gen_flows'] = tf.TensorShape([height, width, 2, self.hparams.last_frames * self.hparams.num_transformed_images]) - output_size['gen_flows_rgb'] = tf.TensorShape([height, width, 3, self.hparams.last_frames * self.hparams.num_transformed_images]) - self._output_size = output_size - - # state_size - conv_rnn_state_sizes = [] - conv_rnn_height, conv_rnn_width = height, width - for out_channels, use_conv_rnn in self.encoder_layer_specs: - conv_rnn_height //= 2 - conv_rnn_width //= 2 - if use_conv_rnn and not self.hparams.ablation_rnn: - conv_rnn_state_sizes.append(tf.TensorShape([conv_rnn_height, conv_rnn_width, out_channels])) - for out_channels, use_conv_rnn in self.decoder_layer_specs: - conv_rnn_height *= 2 - conv_rnn_width *= 2 - if use_conv_rnn and not self.hparams.ablation_rnn: - conv_rnn_state_sizes.append(tf.TensorShape([conv_rnn_height, conv_rnn_width, out_channels])) - if self.hparams.conv_rnn == 'lstm': - conv_rnn_state_sizes = [tf.nn.rnn_cell.LSTMStateTuple(conv_rnn_state_size, conv_rnn_state_size) - for conv_rnn_state_size in conv_rnn_state_sizes] - state_size = {'time': tf.TensorShape([]), - 'gen_image': tf.TensorShape(image_shape), - 'last_images': [tf.TensorShape(image_shape)] * self.hparams.last_frames, - 'conv_rnn_states': conv_rnn_state_sizes} - if 'zs' in inputs and self.hparams.use_rnn_z and not self.hparams.ablation_rnn: - rnn_z_state_size = tf.TensorShape([self.hparams.nz]) - if self.hparams.rnn == 'lstm': - rnn_z_state_size = tf.nn.rnn_cell.LSTMStateTuple(rnn_z_state_size, rnn_z_state_size) - state_size['rnn_z_state'] = rnn_z_state_size - if 'pix_distribs' in inputs: - state_size['gen_pix_distrib'] = tf.TensorShape([height, width, num_motions]) - state_size['last_pix_distribs'] = [tf.TensorShape([height, width, num_motions])] * self.hparams.last_frames - if 'states' in inputs: - state_size['gen_state'] = inputs['states'].shape[2:] - self._state_size = state_size - - if self.hparams.learn_initial_state: - learnable_initial_state_size = {k: v for k, v in state_size.items() - if k in ('conv_rnn_states', 'rnn_z_state')} - else: - learnable_initial_state_size = {} - learnable_initial_state_flat = [] - for i, size in enumerate(nest.flatten(learnable_initial_state_size)): - with tf.variable_scope('initial_state_%d' % i): - state = tf.get_variable('initial_state', size, - 
dtype=tf.float32, initializer=tf.zeros_initializer()) - learnable_initial_state_flat.append(state) - self._learnable_initial_state = nest.pack_sequence_as( - learnable_initial_state_size, learnable_initial_state_flat) - - ground_truth_sampling_shape = [self.hparams.sequence_length - 1 - self.hparams.context_frames, batch_size] - if self.hparams.schedule_sampling == 'none' or self.mode != 'train': - ground_truth_sampling = tf.constant(False, dtype=tf.bool, shape=ground_truth_sampling_shape) - elif self.hparams.schedule_sampling in ('inverse_sigmoid', 'linear'): - if self.hparams.schedule_sampling == 'inverse_sigmoid': - k = self.hparams.schedule_sampling_k - start_step = self.hparams.schedule_sampling_steps[0] - iter_num = tf.to_float(tf.train.get_or_create_global_step()) - prob = (k / (k + tf.exp((iter_num - start_step) / k))) - prob = tf.cond(tf.less(iter_num, start_step), lambda: 1.0, lambda: prob) - elif self.hparams.schedule_sampling == 'linear': - start_step, end_step = self.hparams.schedule_sampling_steps - step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step) - prob = 1.0 - tf.to_float(step - start_step) / tf.to_float(end_step - start_step) - log_probs = tf.log([1 - prob, prob]) - ground_truth_sampling = tf.multinomial([log_probs] * batch_size, ground_truth_sampling_shape[0]) - ground_truth_sampling = tf.cast(tf.transpose(ground_truth_sampling, [1, 0]), dtype=tf.bool) - # Ensure that eventually, the model is deterministically - # autoregressive (as opposed to autoregressive with very high probability). - ground_truth_sampling = tf.cond(tf.less(prob, 0.001), - lambda: tf.constant(False, dtype=tf.bool, shape=ground_truth_sampling_shape), - lambda: ground_truth_sampling) - else: - raise NotImplementedError - ground_truth_context = tf.constant(True, dtype=tf.bool, shape=[self.hparams.context_frames, batch_size]) - self.ground_truth = tf.concat([ground_truth_context, ground_truth_sampling], axis=0) - - @property - def output_size(self): - return self._output_size - - @property - def state_size(self): - return self._state_size - - def zero_state(self, batch_size, dtype): - init_state = super(SAVPCell, self).zero_state(batch_size, dtype) - learnable_init_state = nest.map_structure( - lambda x: tf.tile(x[None], [batch_size] + [1] * x.shape.ndims), self._learnable_initial_state) - init_state.update(learnable_init_state) - init_state['last_images'] = [self.inputs['images'][0]] * self.hparams.last_frames - if 'pix_distribs' in self.inputs: - init_state['last_pix_distribs'] = [self.inputs['pix_distribs'][0]] * self.hparams.last_frames - return init_state - - def _rnn_func(self, inputs, state, num_units): - if self.hparams.rnn == 'lstm': - RNNCell = functools.partial(tf.nn.rnn_cell.LSTMCell, name='basic_lstm_cell') - elif self.hparams.rnn == 'gru': - RNNCell = tf.contrib.rnn.GRUCell - else: - raise NotImplementedError - rnn_cell = RNNCell(num_units, reuse=tf.get_variable_scope().reuse) - return rnn_cell(inputs, state) - - def _conv_rnn_func(self, inputs, state, filters): - if isinstance(inputs, (list, tuple)): - inputs_shape = inputs[0].shape.as_list() - else: - inputs_shape = inputs.shape.as_list() - input_shape = inputs_shape[1:] - if self.hparams.conv_rnn_norm_layer == 'none': - normalizer_fn = None - else: - normalizer_fn = ops.get_norm_layer(self.hparams.conv_rnn_norm_layer) - if self.hparams.conv_rnn == 'lstm': - Conv2DRNNCell = BasicConv2DLSTMCell - elif self.hparams.conv_rnn == 'gru': - Conv2DRNNCell = Conv2DGRUCell - else: - raise NotImplementedError - if 
self.hparams.ablation_conv_rnn_norm: - conv_rnn_cell = Conv2DRNNCell(input_shape, filters, kernel_size=(5, 5), - reuse=tf.get_variable_scope().reuse) - h, state = conv_rnn_cell(inputs, state) - outputs = (normalizer_fn(h), state) - else: - conv_rnn_cell = Conv2DRNNCell(input_shape, filters, kernel_size=(5, 5), - normalizer_fn=normalizer_fn, - separate_norms=self.hparams.conv_rnn_norm_layer == 'layer', - reuse=tf.get_variable_scope().reuse) - outputs = conv_rnn_cell(inputs, state) - return outputs - - def call(self, inputs, states): - norm_layer = ops.get_norm_layer(self.hparams.norm_layer) - downsample_layer = ops.get_downsample_layer(self.hparams.downsample_layer) - upsample_layer = ops.get_upsample_layer(self.hparams.upsample_layer) - activation_layer = ops.get_activation_layer(self.hparams.activation_layer) - image_shape = inputs['images'].get_shape().as_list() - batch_size, height, width, color_channels = image_shape - conv_rnn_states = states['conv_rnn_states'] - - time = states['time'] - with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]): - t = tf.to_int32(tf.identity(time[0])) - - image = tf.where(self.ground_truth[t], inputs['images'], states['gen_image']) # schedule sampling (if any) - last_images = states['last_images'][1:] + [image] - if 'pix_distribs' in inputs: - pix_distrib = tf.where(self.ground_truth[t], inputs['pix_distribs'], states['gen_pix_distrib']) - last_pix_distribs = states['last_pix_distribs'][1:] + [pix_distrib] - if 'states' in inputs: - state = tf.where(self.ground_truth[t], inputs['states'], states['gen_state']) - - state_action = [] - state_action_z = [] - if 'actions' in inputs: - state_action.append(inputs['actions']) - state_action_z.append(inputs['actions']) - if 'states' in inputs: - state_action.append(state) - # don't backpropagate the convnet through the state dynamics - state_action_z.append(tf.stop_gradient(state)) - - if 'zs' in inputs: - if self.hparams.use_rnn_z: - with tf.variable_scope('%s_z' % ('fc' if self.hparams.ablation_rnn else self.hparams.rnn)): - if self.hparams.ablation_rnn: - rnn_z = dense(inputs['zs'], self.hparams.nz) - rnn_z = tf.nn.tanh(rnn_z) - else: - rnn_z, rnn_z_state = self._rnn_func(inputs['zs'], states['rnn_z_state'], self.hparams.nz) - state_action_z.append(rnn_z) - else: - state_action_z.append(inputs['zs']) - - def concat(tensors, axis): - if len(tensors) == 0: - return tf.zeros([batch_size, 0]) - elif len(tensors) == 1: - return tensors[0] - else: - return tf.concat(tensors, axis=axis) - state_action = concat(state_action, axis=-1) - state_action_z = concat(state_action_z, axis=-1) - - layers = [] - new_conv_rnn_states = [] - for i, (out_channels, use_conv_rnn) in enumerate(self.encoder_layer_specs): - with tf.variable_scope('h%d' % i): - if i == 0: - h = tf.concat([image, self.inputs['images'][0]], axis=-1) - kernel_size = (5, 5) - else: - h = layers[-1][-1] - kernel_size = (3, 3) - if self.hparams.where_add == 'all' or (self.hparams.where_add == 'input' and i == 0): - if self.hparams.use_tile_concat: - h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1) - else: - h = [h, state_action_z] - h = _maybe_tile_concat_layer(downsample_layer)( - h, out_channels, kernel_size=kernel_size, strides=(2, 2)) - h = norm_layer(h) - h = activation_layer(h) - if use_conv_rnn: - with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, i)): - if self.hparams.where_add == 'all': - if self.hparams.use_tile_concat: - conv_rnn_h = tile_concat([h, state_action_z[:, None, 
None, :]], axis=-1) - else: - conv_rnn_h = [h, state_action_z] - else: - conv_rnn_h = h - if self.hparams.ablation_rnn: - conv_rnn_h = _maybe_tile_concat_layer(conv2d)( - conv_rnn_h, out_channels, kernel_size=(5, 5)) - conv_rnn_h = norm_layer(conv_rnn_h) - conv_rnn_h = activation_layer(conv_rnn_h) - else: - conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)] - conv_rnn_h, conv_rnn_state = self._conv_rnn_func(conv_rnn_h, conv_rnn_state, out_channels) - new_conv_rnn_states.append(conv_rnn_state) - layers.append((h, conv_rnn_h) if use_conv_rnn else (h,)) - - num_encoder_layers = len(layers) - for i, (out_channels, use_conv_rnn) in enumerate(self.decoder_layer_specs): - with tf.variable_scope('h%d' % len(layers)): - if i == 0: - h = layers[-1][-1] - else: - h = tf.concat([layers[-1][-1], layers[num_encoder_layers - i - 1][-1]], axis=-1) - if self.hparams.where_add == 'all' or (self.hparams.where_add == 'middle' and i == 0): - if self.hparams.use_tile_concat: - h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1) - else: - h = [h, state_action_z] - h = _maybe_tile_concat_layer(upsample_layer)( - h, out_channels, kernel_size=(3, 3), strides=(2, 2)) - h = norm_layer(h) - h = activation_layer(h) - if use_conv_rnn: - with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, len(layers))): - if self.hparams.where_add == 'all': - if self.hparams.use_tile_concat: - conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1) - else: - conv_rnn_h = [h, state_action_z] - else: - conv_rnn_h = h - if self.hparams.ablation_rnn: - conv_rnn_h = _maybe_tile_concat_layer(conv2d)(conv_rnn_h, out_channels, kernel_size=(5, 5)) - conv_rnn_h = norm_layer(conv_rnn_h) - conv_rnn_h = activation_layer(conv_rnn_h) - else: - conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)] - conv_rnn_h, conv_rnn_state = self._conv_rnn_func(conv_rnn_h, conv_rnn_state, out_channels) - new_conv_rnn_states.append(conv_rnn_state) - layers.append((h, conv_rnn_h) if use_conv_rnn else (h,)) - assert len(new_conv_rnn_states) == len(conv_rnn_states) - - if self.hparams.last_frames and self.hparams.num_transformed_images: - if self.hparams.transformation == 'flow': - with tf.variable_scope('h%d_flow' % len(layers)): - h_flow = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) - h_flow = norm_layer(h_flow) - h_flow = activation_layer(h_flow) - - with tf.variable_scope('flows'): - flows = conv2d(h_flow, 2 * self.hparams.last_frames * self.hparams.num_transformed_images, kernel_size=(3, 3), strides=(1, 1)) - flows = tf.reshape(flows, [batch_size, height, width, 2, self.hparams.last_frames * self.hparams.num_transformed_images]) - else: - assert len(self.hparams.kernel_size) == 2 - kernel_shape = list(self.hparams.kernel_size) + [self.hparams.last_frames * self.hparams.num_transformed_images] - if self.hparams.transformation == 'dna': - with tf.variable_scope('h%d_dna_kernel' % len(layers)): - h_dna_kernel = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) - h_dna_kernel = norm_layer(h_dna_kernel) - h_dna_kernel = activation_layer(h_dna_kernel) - - # Using largest hidden state for predicting untied conv kernels. 
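# [Editor's note] A minimal NumPy sketch, illustrative only and not part of the original
# model, of the kernel normalization applied further below (in the 'kernel_normalization'
# name scope): the predicted kernels are passed through a shifted ReLU and renormalized
# over their spatial axes, so every kernel is non-negative and sums to one (the small
# shift avoids division by zero). The constant and shapes here are assumptions for the demo.
import numpy as np

RELU_SHIFT_DEMO = 1e-12
raw_kernels = np.random.randn(2, 5, 5, 4)           # [batch, kernel_h, kernel_w, num_transformed_images]
k = np.maximum(raw_kernels - RELU_SHIFT_DEMO, 0.0) + RELU_SHIFT_DEMO
k /= k.sum(axis=(1, 2), keepdims=True)              # normalize over the spatial kernel axes
assert np.allclose(k.sum(axis=(1, 2)), 1.0)         # every kernel is a convex combination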
- with tf.variable_scope('dna_kernels'): - kernels = conv2d(h_dna_kernel, np.prod(kernel_shape), kernel_size=(3, 3), strides=(1, 1)) - kernels = tf.reshape(kernels, [batch_size, height, width] + kernel_shape) - kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, None, None, :, :, None] - kernel_spatial_axes = [3, 4] - elif self.hparams.transformation == 'cdna': - with tf.variable_scope('cdna_kernels'): - smallest_layer = layers[num_encoder_layers - 1][-1] - kernels = dense(flatten(smallest_layer), np.prod(kernel_shape)) - kernels = tf.reshape(kernels, [batch_size] + kernel_shape) - kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, :, :, None] - kernel_spatial_axes = [1, 2] - else: - raise ValueError('Invalid transformation %s' % self.hparams.transformation) - - if self.hparams.transformation != 'flow': - with tf.name_scope('kernel_normalization'): - kernels = tf.nn.relu(kernels - RELU_SHIFT) + RELU_SHIFT - kernels /= tf.reduce_sum(kernels, axis=kernel_spatial_axes, keepdims=True) - - if self.hparams.generate_scratch_image: - with tf.variable_scope('h%d_scratch' % len(layers)): - h_scratch = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) - h_scratch = norm_layer(h_scratch) - h_scratch = activation_layer(h_scratch) - - # Using largest hidden state for predicting a new image layer. - # This allows the network to also generate one image from scratch, - # which is useful when regions of the image become unoccluded. - with tf.variable_scope('scratch_image'): - scratch_image = conv2d(h_scratch, color_channels, kernel_size=(3, 3), strides=(1, 1)) - scratch_image = tf.nn.sigmoid(scratch_image) - - with tf.name_scope('transformed_images'): - transformed_images = [] - if self.hparams.last_frames and self.hparams.num_transformed_images: - if self.hparams.transformation == 'flow': - transformed_images.extend(apply_flows(last_images, flows)) - else: - transformed_images.extend(apply_kernels(last_images, kernels, self.hparams.dilation_rate)) - if self.hparams.prev_image_background: - transformed_images.append(image) - if self.hparams.first_image_background and not self.hparams.context_images_background: - transformed_images.append(self.inputs['images'][0]) - if self.hparams.last_image_background and not self.hparams.context_images_background: - transformed_images.append(self.inputs['images'][self.hparams.context_frames - 1]) - if self.hparams.last_context_image_background and not self.hparams.context_images_background: - last_context_image = tf.cond( - tf.less(t, self.hparams.context_frames), - lambda: self.inputs['images'][t], - lambda: self.inputs['images'][self.hparams.context_frames - 1]) - transformed_images.append(last_context_image) - if self.hparams.context_images_background: - transformed_images.extend(tf.unstack(self.inputs['images'][:self.hparams.context_frames])) - if self.hparams.generate_scratch_image: - transformed_images.append(scratch_image) - - if 'pix_distribs' in inputs: - with tf.name_scope('transformed_pix_distribs'): - transformed_pix_distribs = [] - if self.hparams.last_frames and self.hparams.num_transformed_images: - if self.hparams.transformation == 'flow': - transformed_pix_distribs.extend(apply_flows(last_pix_distribs, flows)) - else: - transformed_pix_distribs.extend(apply_kernels(last_pix_distribs, kernels, self.hparams.dilation_rate)) - if self.hparams.prev_image_background: - transformed_pix_distribs.append(pix_distrib) - if self.hparams.first_image_background and not self.hparams.context_images_background: - 
transformed_pix_distribs.append(self.inputs['pix_distribs'][0]) - if self.hparams.last_image_background and not self.hparams.context_images_background: - transformed_pix_distribs.append(self.inputs['pix_distribs'][self.hparams.context_frames - 1]) - if self.hparams.last_context_image_background and not self.hparams.context_images_background: - last_context_pix_distrib = tf.cond( - tf.less(t, self.hparams.context_frames), - lambda: self.inputs['pix_distribs'][t], - lambda: self.inputs['pix_distribs'][self.hparams.context_frames - 1]) - transformed_pix_distribs.append(last_context_pix_distrib) - if self.hparams.context_images_background: - transformed_pix_distribs.extend(tf.unstack(self.inputs['pix_distribs'][:self.hparams.context_frames])) - if self.hparams.generate_scratch_image: - transformed_pix_distribs.append(pix_distrib) - - with tf.name_scope('masks'): - if len(transformed_images) > 1: - with tf.variable_scope('h%d_masks' % len(layers)): - h_masks = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) - h_masks = norm_layer(h_masks) - h_masks = activation_layer(h_masks) - - with tf.variable_scope('masks'): - if self.hparams.dependent_mask: - h_masks = tf.concat([h_masks] + transformed_images, axis=-1) - masks = conv2d(h_masks, len(transformed_images), kernel_size=(3, 3), strides=(1, 1)) - masks = tf.nn.softmax(masks) - masks = tf.split(masks, len(transformed_images), axis=-1) - elif len(transformed_images) == 1: - masks = [tf.ones([batch_size, height, width, 1])] - else: - raise ValueError("Either one of the following should be true: " - "last_frames and num_transformed_images, first_image_background, " - "prev_image_background, generate_scratch_image") - - with tf.name_scope('gen_images'): - assert len(transformed_images) == len(masks) - gen_image = tf.add_n([transformed_image * mask - for transformed_image, mask in zip(transformed_images, masks)]) - - if 'pix_distribs' in inputs: - with tf.name_scope('gen_pix_distribs'): - assert len(transformed_pix_distribs) == len(masks) - gen_pix_distrib = tf.add_n([transformed_pix_distrib * mask - for transformed_pix_distrib, mask in zip(transformed_pix_distribs, masks)]) - gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib, axis=(1, 2), keepdims=True) - - if 'states' in inputs: - with tf.name_scope('gen_states'): - with tf.variable_scope('state_pred'): - gen_state = dense(state_action, inputs['states'].shape[-1].value) - - outputs = {'gen_images': gen_image, - 'transformed_images': tf.stack(transformed_images, axis=-1), - 'masks': tf.stack(masks, axis=-1)} - if 'pix_distribs' in inputs: - outputs['gen_pix_distribs'] = gen_pix_distrib - outputs['transformed_pix_distribs'] = tf.stack(transformed_pix_distribs, axis=-1) - if 'states' in inputs: - outputs['gen_states'] = gen_state - if self.hparams.transformation == 'flow': - outputs['gen_flows'] = flows - flows_transposed = tf.transpose(flows, [0, 1, 2, 4, 3]) - flows_rgb_transposed = tf_utils.flow_to_rgb(flows_transposed) - flows_rgb = tf.transpose(flows_rgb_transposed, [0, 1, 2, 4, 3]) - outputs['gen_flows_rgb'] = flows_rgb - - new_states = {'time': time + 1, - 'gen_image': gen_image, - 'last_images': last_images, - 'conv_rnn_states': new_conv_rnn_states} - if 'zs' in inputs and self.hparams.use_rnn_z and not self.hparams.ablation_rnn: - new_states['rnn_z_state'] = rnn_z_state - if 'pix_distribs' in inputs: - new_states['gen_pix_distrib'] = gen_pix_distrib - new_states['last_pix_distribs'] = last_pix_distribs - if 'states' in inputs: - new_states['gen_state'] = gen_state - 
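# [Editor's note] Hedged NumPy sketch, not part of the original cell, of the compositing
# performed above: per-pixel softmax masks weight the candidate (transformed) images, so the
# generated frame is a pixel-wise convex combination of them. Shapes and names are assumptions.
import numpy as np

def composite_demo(transformed, mask_logits):
    """transformed: [num_masks, batch, h, w, c]; mask_logits: [batch, h, w, num_masks]."""
    masks = np.exp(mask_logits - mask_logits.max(axis=-1, keepdims=True))
    masks /= masks.sum(axis=-1, keepdims=True)                     # softmax over the candidates
    return sum(transformed[k] * masks[..., k:k + 1]                # weight and sum the candidates
               for k in range(transformed.shape[0]))

candidates = np.random.rand(3, 2, 8, 8, 1)                         # 3 candidates, batch 2, 8x8x1
gen_image_demo = composite_demo(candidates, np.random.randn(2, 8, 8, 3))   # -> [2, 8, 8, 1]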
return outputs, new_states - - -def generator_given_z_fn(inputs, mode, hparams): - # all the inputs needs to have the same length for unrolling the rnn - #20200822 bing - inputs ={"images":inputs["images"]} - #20200822 - inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1) - for name, input in inputs.items()} - cell = SAVPCell(inputs, mode, hparams) - outputs, _ = tf_utils.unroll_rnn(cell, inputs) - outputs['ground_truth_sampling_mean'] = tf.reduce_mean(tf.to_float(cell.ground_truth[hparams.context_frames:])) - return outputs - - -def generator_fn(inputs, mode, hparams): - batch_size = tf.shape(inputs['images'])[1] - - if hparams.nz == 0: - # no zs is given in inputs - outputs = generator_given_z_fn(inputs, mode, hparams) - else: - zs_shape = [hparams.sequence_length - 1, batch_size, hparams.nz] - - # posterior - with tf.variable_scope('encoder'): - outputs_posterior = posterior_fn(inputs, hparams) - eps = tf.random_normal(zs_shape, 0, 1) - zs_posterior = outputs_posterior['zs_mu'] + tf.sqrt(tf.exp(outputs_posterior['zs_log_sigma_sq'])) * eps - inputs_posterior = dict(inputs) - inputs_posterior['zs'] = zs_posterior - - # prior - if hparams.learn_prior: - with tf.variable_scope('prior'): - outputs_prior = prior_fn(inputs, hparams) - eps = tf.random_normal(zs_shape, 0, 1) - zs_prior = outputs_prior['zs_mu'] + tf.sqrt(tf.exp(outputs_prior['zs_log_sigma_sq'])) * eps - else: - outputs_prior = {} - zs_prior = tf.random_normal([hparams.sequence_length - hparams.context_frames] + zs_shape[1:], 0, 1) - zs_prior = tf.concat([zs_posterior[:hparams.context_frames - 1], zs_prior], axis=0) - inputs_prior = dict(inputs) - inputs_prior['zs'] = zs_prior - - # generate - gen_outputs_posterior = generator_given_z_fn(inputs_posterior, mode, hparams) - tf.get_variable_scope().reuse_variables() - gen_outputs = generator_given_z_fn(inputs_prior, mode, hparams) - - # rename tensors to avoid name collisions - output_prior = collections.OrderedDict([(k + '_prior', v) for k, v in outputs_prior.items()]) - outputs_posterior = collections.OrderedDict([(k + '_enc', v) for k, v in outputs_posterior.items()]) - gen_outputs_posterior = collections.OrderedDict([(k + '_enc', v) for k, v in gen_outputs_posterior.items()]) - - outputs = [output_prior, gen_outputs, outputs_posterior, gen_outputs_posterior] - total_num_outputs = sum([len(output) for output in outputs]) - outputs = collections.OrderedDict(itertools.chain(*[output.items() for output in outputs])) - assert len(outputs) == total_num_outputs # ensure no output is lost because of repeated keys - - # generate multiple samples from the prior for visualization purposes - inputs_samples = { - name: tf.tile(input[:, None], [1, hparams.num_samples] + [1] * (input.shape.ndims - 1)) - for name, input in inputs.items()} - zs_samples_shape = [hparams.sequence_length - 1, hparams.num_samples, batch_size, hparams.nz] - if hparams.learn_prior: - eps = tf.random_normal(zs_samples_shape, 0, 1) - zs_prior_samples = (outputs_prior['zs_mu'][:, None] + - tf.sqrt(tf.exp(outputs_prior['zs_log_sigma_sq']))[:, None] * eps) - else: - zs_prior_samples = tf.random_normal( - [hparams.sequence_length - hparams.context_frames] + zs_samples_shape[1:], 0, 1) - zs_prior_samples = tf.concat( - [tf.tile(zs_posterior[:hparams.context_frames - 1][:, None], [1, hparams.num_samples, 1, 1]), - zs_prior_samples], axis=0) - inputs_prior_samples = dict(inputs_samples) - inputs_prior_samples['zs'] = zs_prior_samples - inputs_prior_samples = {name: flatten(input, 1, 2) for name, 
input in inputs_prior_samples.items()} - gen_outputs_samples = generator_given_z_fn(inputs_prior_samples, mode, hparams) - gen_images_samples = gen_outputs_samples['gen_images'] - gen_images_samples = tf.stack(tf.split(gen_images_samples, hparams.num_samples, axis=1), axis=-1) - gen_images_samples_avg = tf.reduce_mean(gen_images_samples, axis=-1) - outputs['gen_images_samples'] = gen_images_samples - outputs['gen_images_samples_avg'] = gen_images_samples_avg - return outputs - - -class SAVPVideoPredictionModel(VideoPredictionModel): - def __init__(self, *args, **kwargs): - super(SAVPVideoPredictionModel, self).__init__( - generator_fn, discriminator_fn, *args, **kwargs) - if self.mode != 'train': - self.discriminator_fn = None - self.deterministic = not self.hparams.nz - - def get_default_hparams_dict(self): - default_hparams = super(SAVPVideoPredictionModel, self).get_default_hparams_dict() - hparams = dict( - l1_weight=1.0, - l2_weight=0.0, - n_layers=3, - ndf=32, - norm_layer='instance', - use_same_discriminator=False, - ngf=32, - downsample_layer='conv_pool2d', - upsample_layer='upsample_conv2d', - activation_layer='relu', # for generator only - transformation='cdna', - kernel_size=(5, 5), - dilation_rate=(1, 1), - where_add='all', - use_tile_concat=True, - learn_initial_state=False, - rnn='lstm', - conv_rnn='lstm', - conv_rnn_norm_layer='instance', - num_transformed_images=4, - last_frames=1, - prev_image_background=True, - first_image_background=True, - last_image_background=False, - last_context_image_background=False, - context_images_background=False, - generate_scratch_image=True, - dependent_mask=True, - schedule_sampling='inverse_sigmoid', - schedule_sampling_k=900.0, - schedule_sampling_steps=(0, 100000), - use_e_rnn=False, - learn_prior=False, - nz=8, - num_samples=8, - nef=64, - use_rnn_z=True, - ablation_conv_rnn_norm=False, - ablation_rnn=False, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - def parse_hparams(self, hparams_dict, hparams): - # backwards compatibility - deprecated_hparams_keys = [ - 'num_gpus', - 'e_net', - 'd_conditional', - 'd_downsample_layer', - 'd_net', - 'd_use_gt_inputs', - 'acvideo_gan_weight', - 'acvideo_vae_gan_weight', - 'image_gan_weight', - 'image_vae_gan_weight', - 'tuple_gan_weight', - 'tuple_vae_gan_weight', - 'gan_weight', - 'vae_gan_weight', - 'video_gan_weight', - 'video_vae_gan_weight', - ] - for deprecated_hparams_key in deprecated_hparams_keys: - hparams_dict.pop(deprecated_hparams_key, None) - return super(SAVPVideoPredictionModel, self).parse_hparams(hparams_dict, hparams) - - def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): - def restore_to_checkpoint_mapping(restore_name, checkpoint_var_names): - restore_name = restore_name.split(':')[0] - if restore_name not in checkpoint_var_names: - restore_name = restore_name.replace('savp_cell', 'dna_cell') - return restore_name - - super(SAVPVideoPredictionModel, self).restore(sess, checkpoints, restore_to_checkpoint_mapping) - - -def apply_dna_kernels(image, kernels, dilation_rate=(1, 1)): - """ - Args: - image: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - kernels: A 6-D of shape - `[batch, in_height, in_width, kernel_size[0], kernel_size[1], num_transformed_images]`. - Returns: - A list of `num_transformed_images` 4-D tensors, each of shape - `[batch, in_height, in_width, in_channels]`. 
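    Example (editor's note): an illustrative NumPy reference of the per-pixel "DNA"
    transformation, not part of the original code; it assumes odd kernel sizes and
    kernels that already sum to one:

        import numpy as np

        def apply_dna_kernels_reference(image, kernels):
            # image: [b, h, w, c]; kernels: [b, h, w, kh, kw, n]
            b, h, w, c = image.shape
            _, _, _, kh, kw, n = kernels.shape
            padded = np.pad(image, ((0, 0), (kh // 2, kh // 2), (kw // 2, kw // 2), (0, 0)),
                            mode='symmetric')
            out = np.zeros((n, b, h, w, c))
            for y in range(h):
                for x in range(w):
                    patch = padded[:, y:y + kh, x:x + kw, :]        # local patch at (y, x)
                    # weight the patch with the kernel predicted for this output pixel
                    out[:, :, y, x, :] = np.einsum('bijc,bijn->nbc', patch, kernels[:, y, x])
            return list(out)                                        # n tensors of [b, h, w, c]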
- """ - dilation_rate = list(dilation_rate) if isinstance(dilation_rate, (tuple, list)) else [dilation_rate] * 2 - batch_size, height, width, color_channels = image.get_shape().as_list() - batch_size, height, width, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list() - kernel_size = [kernel_height, kernel_width] - - # Flatten the spatial dimensions. - kernels_reshaped = tf.reshape(kernels, [batch_size, height, width, - kernel_size[0] * kernel_size[1], num_transformed_images]) - image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC') - # Combine channel and batch dimensions into the first dimension. - image_transposed = tf.transpose(image_padded, [3, 0, 1, 2]) - image_reshaped = flatten(image_transposed, 0, 1)[..., None] - patches_reshaped = tf.extract_image_patches(image_reshaped, ksizes=[1] + kernel_size + [1], - strides=[1] * 4, rates=[1] + dilation_rate + [1], padding='VALID') - # Separate channel and batch dimensions, and move channel dimension. - patches_transposed = tf.reshape(patches_reshaped, [color_channels, batch_size, height, width, kernel_size[0] * kernel_size[1]]) - patches = tf.transpose(patches_transposed, [1, 2, 3, 0, 4]) - # Reduce along the spatial dimensions of the kernel. - outputs = tf.matmul(patches, kernels_reshaped) - outputs = tf.unstack(outputs, axis=-1) - return outputs - - -def apply_cdna_kernels(image, kernels, dilation_rate=(1, 1)): - """ - Args: - image: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - kernels: A 4-D of shape - `[batch, kernel_size[0], kernel_size[1], num_transformed_images]`. - Returns: - A list of `num_transformed_images` 4-D tensors, each of shape - `[batch, in_height, in_width, in_channels]`. - """ - batch_size, height, width, color_channels = image.get_shape().as_list() - batch_size, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list() - kernel_size = [kernel_height, kernel_width] - image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC') - # Treat the color channel dimension as the batch dimension since the same - # transformation is applied to each color channel. - # Treat the batch dimension as the channel dimension so that - # depthwise_conv2d can apply a different transformation to each sample. - kernels = tf.transpose(kernels, [1, 2, 0, 3]) - kernels = tf.reshape(kernels, [kernel_size[0], kernel_size[1], batch_size, num_transformed_images]) - # Swap the batch and channel dimensions. - image_transposed = tf.transpose(image_padded, [3, 1, 2, 0]) - # Transform image. - outputs = tf.nn.depthwise_conv2d(image_transposed, kernels, [1, 1, 1, 1], padding='VALID', rate=dilation_rate) - # Transpose the dimensions to where they belong. - outputs = tf.reshape(outputs, [color_channels, height, width, batch_size, num_transformed_images]) - outputs = tf.transpose(outputs, [4, 3, 1, 2, 0]) - outputs = tf.unstack(outputs, axis=0) - return outputs - - -def apply_kernels(image, kernels, dilation_rate=(1, 1)): - """ - Args: - image: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - kernels: A 4-D or 6-D tensor of shape - `[batch, kernel_size[0], kernel_size[1], num_transformed_images]` or - `[batch, in_height, in_width, kernel_size[0], kernel_size[1], num_transformed_images]`. - Returns: - A list of `num_transformed_images` 4-D tensors, each of shape - `[batch, in_height, in_width, in_channels]`. 
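    Example (editor's note): "CDNA" kernels are the spatially-shared special case of the
    "DNA" kernels above -- one kernel per (sample, transformed image) applied at every
    pixel. An illustrative sketch, not part of the original code, reusing the hypothetical
    apply_dna_kernels_reference helper sketched in the previous docstring:

        import numpy as np

        def apply_cdna_kernels_reference(image, kernels):
            # image: [b, h, w, c]; kernels: [b, kh, kw, n]
            b, h, w, _ = image.shape
            tiled = np.broadcast_to(kernels[:, None, None], (b, h, w) + kernels.shape[1:])
            return apply_dna_kernels_reference(image, tiled)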
- """ - if isinstance(image, list): - image_list = image - kernels_list = tf.split(kernels, len(image_list), axis=-1) - outputs = [] - for image, kernels in zip(image_list, kernels_list): - outputs.extend(apply_kernels(image, kernels)) - else: - if len(kernels.get_shape()) == 4: - outputs = apply_cdna_kernels(image, kernels, dilation_rate=dilation_rate) - elif len(kernels.get_shape()) == 6: - outputs = apply_dna_kernels(image, kernels, dilation_rate=dilation_rate) - else: - raise ValueError - return outputs - - -def apply_flows(image, flows): - if isinstance(image, list): - image_list = image - flows_list = tf.split(flows, len(image_list), axis=-1) - outputs = [] - for image, flows in zip(image_list, flows_list): - outputs.extend(apply_flows(image, flows)) - else: - flows = tf.unstack(flows, axis=-1) - outputs = [flow_ops.image_warp(image, flow) for flow in flows] - return outputs - - -def identity_kernel(kernel_size): - kh, kw = kernel_size - kernel = np.zeros(kernel_size) - - def center_slice(k): - if k % 2 == 0: - return slice(k // 2 - 1, k // 2 + 1) - else: - return slice(k // 2, k // 2 + 1) - - kernel[center_slice(kh), center_slice(kw)] = 1.0 - kernel /= np.sum(kernel) - return kernel - - -def _maybe_tile_concat_layer(conv2d_layer): - def layer(inputs, out_channels, *args, **kwargs): - if isinstance(inputs, (list, tuple)): - inputs_spatial, inputs_non_spatial = inputs - outputs = (conv2d_layer(inputs_spatial, out_channels, *args, **kwargs) + - dense(inputs_non_spatial, out_channels, use_bias=False)[:, None, None, :]) - else: - outputs = conv2d_layer(inputs, out_channels, *args, **kwargs) - return outputs - - return layer diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py index dc58cba3638309cd6b7054b91e45363f2aa42fce..70c050f53efa81f1c482fa06f3cd1a5318b6e987 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MIT __email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong, Scarlet Stadtler,Michael Langguth" +__author__ = "Bing Gong" __date__ = "2020-11-05" from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames @@ -11,119 +11,137 @@ import tensorflow as tf from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell from .our_base_model import BaseModels +from hparams_utils import * class VanillaConvLstmVideoPredictionModel(BaseModels): - def __init__(self, hparams_dict=None, **kwargs): + + def __init__(self, hparams_dict_config=None, **kwargs): """ This is class for building convLSTM architecture by using updated hparameters args: - hparams_dict : dict, the dictionary contains the hparaemters names and values + hparams_dict : dict, the dictionary contains the hparaemters names and values """ - super().__init__(hparams_dict) - self.get_hparams() + super().__init__(hparams_dict_config) - def get_hparams(self): + def parse_hparams(self, hparams): """ - obtain the hparams from the dict to the class variables + obtain the hyper-parameters from the dictionary """ - method = BaseModels.get_hparams.__name__ try: - self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.max_epochs = 
self.hparams.max_epochs - self.batch_size = self.hparams.batch_size - self.shuffle_on_val = self.hparams.shuffle_on_val - self.loss_fun = self.hparams.loss_fun - self.opt_var = self.hparams.opt_var - self.learning_rate = self.hparams.lr + self.context_frames = hparams.context_frames + self.sequence_length = hparams.sequence_length + self.max_epochs = hparams.max_epochs + self.batch_size = hparams.batch_size + self.shuffle_on_val = hparams.shuffle_on_val + self.loss_fun = hparams.loss_fun + self.opt_var = hparams.opt_var + self.lr = hparams.lr self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) print("The model hparams have been parsed successfully! ") - except Exception as error: - print("Method %{}: error: {}".format(method, error)) - raise("Method %{}: the hparameter dictionary must include " - "'context_frames','max_epochs','batch_size','shuffle_on_val' 'loss_fun'," - "'opt_var', 'lr', 'opt_var'".format(method)) + except Exception as e: + raise ValueError(f"missing hyperparameter: {e.args[0]}") + - def build_graph(self, x): - self.is_build_graph = False + def build_graph(self, x:tf.Tensor): self.inputs = x - self.global_step = tf.train.get_or_create_global_step() original_global_variables = tf.global_variables() + x_hat = self.build_model(x) + self.total_loss = self.get_loss(x, x_hat) + self.train_op = self.optimizer(self.total_loss) + self.outputs["gen_images"] = x_hat + self.summary_op = self.summary(total_loss = self.total_loss) + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self._is_build_graph_set = True + return self._is_build_graph_set - self.build_model() - #This is the loss function (MSE): + def get_loss(self, x:tf.Tensor, x_hat:tf.Tensor)->tf.Tensor: + """ + :param x : Input tensors + :param x_hat: Prediction/output tensors + :return : the loss function + """ + #This is the loss function (MSE): #Optimize all target variables/channels if self.opt_var == "all": - x = self.inputs[:, self.context_frames:, :, :, :] - x_hat = self.x_hat_predict_frames[:, :, :, :, :] - print ("The model is optimzied on all the variables in the loss function") + x = x[:, self.context_frames:, :, :, :] + print("The model is optimzied on all the variables in the loss function") elif self.opt_var != "all" and isinstance(self.opt_var, str): self.opt_var = int(self.opt_var) - print ("The model is optimized on the {} variable in the loss function".format(self.opt_var)) - x = self.inputs[:, self.context_frames:, :, :, self.opt_var] - x_hat = self.x_hat_predict_frames[:, :, :, :, self.opt_var] + print("The model is optimized on the {} variable in the loss function".format(self.opt_var)) + x = x[:, self.context_frames:, :, :, self.opt_var] + x_hat = x_hat[:, :, :, :, self.opt_var] else: raise ValueError("The opt var in the hyperparameters setup should be '0','1','2' indicate the index of target variable to be optimised or 'all' indicating optimize all the variables") if self.loss_fun == "mse": - self.total_loss = tf.reduce_mean(tf.square(x - x_hat)) + total_loss = tf.reduce_mean(tf.square(x - x_hat)) elif self.loss_fun == "cross_entropy": x_flatten = tf.reshape(x, [-1]) x_hat_predict_frames_flatten = tf.reshape(x_hat, [-1]) bce = tf.keras.losses.BinaryCrossentropy() - self.total_loss = bce(x_flatten, x_hat_predict_frames_flatten) + total_loss = bce(x_flatten, x_hat_predict_frames_flatten) else: raise ValueError("Loss function is not selected properly, 
you should chose either 'mse' or 'cross_entropy'") + return total_loss - #This is the loss for only all the channels(temperature, geo500, pressure) - #self.total_loss = tf.reduce_mean( - # tf.square(self.x[:, self.context_frames:,:,:,:] - self.x_hat_predict_frames[:,:,:,:,:])) - - self.train_op = tf.train.AdamOptimizer( - learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) - - self.outputs["gen_images"] = self.x_hat - # Summary op - self.loss_summary = tf.summary.scalar("total_loss", self.total_loss) - self.summary_op = tf.summary.merge_all() - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - self.is_build_graph = True - return self.is_build_graph - def build_model(self): + + def summary(self, total_loss)->None: + """ + return the summary operation can be used for TensorBoard + """ + tf.summary.scalar("total_loss", total_loss) + summary_op = tf.summary.merge_all() + return summary_op + + def build_model(self, x: tf.Tensor): network_template = tf.make_template('network', VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables + + x_hat = VanillaConvLstmVideoPredictionModel.convLSTM_network(x, + self.sequence_length, + self.context_frames, + network_template) + return x_hat + + + @staticmethod + def convLSTM_network(x:tf.Tensor, sequence_length:int, context_frames:int, network_template:tf.make_template)->tf.Tensor: + # create network x_hat = [] - - #This is for training (optimization of convLSTM layer) + + # This is for training (optimization of convLSTM layer) hidden_g = None - for i in range(self.sequence_length-1): - if i < self.context_frames: - x_1_g, hidden_g = network_template(self.inputs[:, i, :, :, :], hidden_g) + for i in range(sequence_length - 1): + if i < context_frames: + x_1_g, hidden_g = network_template(x[:, i, :, :, :], hidden_g) else: x_1_g, hidden_g = network_template(x_1_g, hidden_g) x_hat.append(x_1_g) # pack them all together x_hat = tf.stack(x_hat) - self.x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim - self.x_hat_predict_frames = self.x_hat[:, self.context_frames-1:, :, :, :] + x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim + x_hat = x_hat[:, context_frames - 1:, :, :, :] + return x_hat + @staticmethod - def convLSTM_cell(inputs, hidden): + + def convLSTM_cell(inputs:tf.Tensor, hidden:tf.Tensor): """ SPDX-FileCopyrightText: loliverhennigh SPDX-License-Identifier: Apache-2.0 The following function was revised based on the github https://github.com/loliverhennigh/Convolutional-LSTM-in-Tensorflow """ y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset - channels = inputs.get_shape()[-1] + # conv lstm cell cell_shape = y_0.get_shape().as_list() channels = cell_shape[-1] @@ -137,24 +155,3 @@ class VanillaConvLstmVideoPredictionModel(BaseModels): #we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") return x_hat, hidden - - @staticmethod - def set_and_check_pred_frames(seq_length, context_frames): - """ - Checks if sequence length and context_frames are set properly and returns number of frames to be predicted. 
- :param seq_length: number of frames/images per sequences - :param context_frames: number of context frames/images - :return: number of predicted frames - """ - - method = VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames.__name__ - - # sanity checks - assert isinstance(seq_length, int), "%{0}: Sequence length (seq_length) must be an integer".format(method) - assert isinstance(context_frames, int), "%{0}: Number of context frames must be an integer".format(method) - - if seq_length > context_frames: - return seq_length-context_frames - else: - raise ValueError("%{0}: Sequence length ({1}) must be larger than context frames ({2})." - .format(method, seq_length, context_frames)) diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py deleted file mode 100644 index 5365bcb8c3c6e739bb2ba6ae04813e361e297921..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong" -__date__ = "2020-09-01" - -from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames -import tensorflow as tf -from model_modules.video_prediction.layers import layer_def as ld -from tensorflow.contrib.training import HParams - - -class VanillaVAEVideoPredictionModel(object): - def __init__(self, hparams_dict=None, **kwargs): - """ - This is class for building convLSTM architecture by using updated hparameters - args: - hparams_dict: dict, the dictionary contains the hparaemters names and values - """ - - self.hparams_dict = hparams_dict - self.hparams = self.parse_hparams() - self.learning_rate = self.hparams.lr - self.total_loss = None - self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) - self.max_epochs = self.hparams.max_epochs - self.nz = self.hparams.nz - self.loss_fun = self.hparams.loss_fun - self.batch_size = self.hparams.batch_size - self.shuffle_on_val = self.hparams.shuffle_on_val - self.weight_recon = self.hparams.weight_recon - - def get_default_hparams(self): - return HParams(**self.get_default_hparams_dict()) - - def parse_hparams(self): - """ - Parse the hparams setting to ovoerride the default ones - """ - parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) - return parsed_hparams - - - def get_default_hparams_dict(self): - """ - The function that contains default hparams - Returns: - A dict with the following hyperparameters. - context_frames : the number of ground-truth frames to pass in at start. 
- sequence_length : the number of frames in the video sequence - max_epochs : the number of epochs to train model - lr : learning rate - loss_fun : the loss function - """ - hparams = dict( - context_frames=10, - sequence_length=24, - max_epochs = 20, - batch_size = 4, - lr = 0.001, - nz = 16, - loss_fun = "cross_entropy", - weight_recon = 1, - shuffle_on_val= True, - ) - return hparams - - - def build_graph(self,x): - self.x = x["images"] - self.global_step = tf.train.get_or_create_global_step() - original_global_variables = tf.global_variables() - self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all() - #This is the loss function (RMSE): - #This is loss function only for 1 channel (temperature RMSE) - if self.loss_fun == "rmse": - self.recon_loss = tf.reduce_mean(tf.square(self.x[:,self.context_frames:,:,:,0] - self.x_hat[:,self.context_frames-1:,:,:,0])) - elif self.loss_fun == "cross_entropy": - x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0],[-1]) - x_hat_predict_frames_flatten = tf.reshape(self.x_hat[:,self.context_frames-1:,:,:,0],[-1]) - bce = tf.keras.losses.BinaryCrossentropy() - self.recon_loss = bce(x_flatten,x_hat_predict_frames_flatten) - else: - raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'") - - latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq - tf.square(self.z_mu) -tf.exp(self.z_log_sigma_sq), axis=1) - self.latent_loss = tf.reduce_mean(latent_loss) - self.total_loss = self.weight_recon * self.recon_loss + self.latent_loss - self.train_op = tf.train.AdamOptimizer( - learning_rate = self.learning_rate).minimize(self.total_loss, global_step=self.global_step) - # Build a saver - self.losses = { - 'recon_loss': self.recon_loss, - 'latent_loss': self.latent_loss, - 'total_loss': self.total_loss, - } - - # Summary op - self.loss_summary = tf.summary.scalar("recon_loss", self.recon_loss) - self.loss_summary = tf.summary.scalar("latent_loss", self.latent_loss) - self.loss_summary = tf.summary.scalar("total_loss", self.latent_loss) - self.summary_op = tf.summary.merge_all() - self.outputs = {} - self.outputs["gen_images"] = self.x_hat - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - return None - - - @staticmethod - def vae_arc3(x,l_name=0,nz=16): - """ - VAE for one timestamp of sequence - args: - x : input tensor, shape is [batch_size,height, width, channel] - l_name : int, default is 0, the sequence index - nz : int, default is 16, the latent space - return - x_hat : tensor, is the predicted value - z_mu : tensor, means values of latent space - z_log_sigma_sq: sensor, the variances of latent space - z : tensor, the normal distribution with z_mu, z-log_sigma_sq - - """ - input_shape = x.get_shape().as_list() - input_width = input_shape[2] - input_height = input_shape[1] - print("input_heights:",input_height) - seq_name = "sq_" + str(l_name) + "_" - conv1 = ld.conv_layer(inputs=x, kernel_size=3, stride=2, num_features=8, idx=seq_name + "encode_1") - conv1_shape = conv1.get_shape().as_list() - print("conv1_shape:",conv1_shape) - assert conv1_shape[3] == 8 #Features check - assert conv1_shape[1] == int((input_height - 3 + 1)/2) + 1 #[(Input_volumn - kernel_size + padding)/stride] + 1 - conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2") - conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3") - conv4 = tf.layers.Flatten()(conv3) - conv3_shape = 
conv3.get_shape().as_list() - z_mu = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m") - z_log_sigma_sq = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m"'enc_fc4_sigma') - eps = tf.random_normal(shape = tf.shape(z_log_sigma_sq), mean = 0, stddev = 1, dtype = tf.float32) - z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps - z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1") - z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]]) - conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, seq_name + "decode_5") - conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8,seq_name + "decode_6") - x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8") - x_hat_shape = x_hat.get_shape().as_list() - pred_height = x_hat_shape[1] - pred_width = x_hat_shape[2] - assert pred_height == input_height - assert pred_width == input_width - return x_hat, z_mu, z_log_sigma_sq, z - - def vae_arc_all(self): - """ - Build architecture for all the sequences - """ - X = [] - z_log_sigma_sq_all = [] - z_mu_all = [] - for i in range(self.sequence_length-1): - q, z_mu, z_log_sigma_sq, z = VanillaVAEVideoPredictionModel.vae_arc3(self.x[:, i, :, :, :], l_name=i, nz=self.nz) - X.append(q) - z_log_sigma_sq_all.append(z_log_sigma_sq) - z_mu_all.append(z_mu) - x_hat = tf.stack(X, axis = 1) - x_hat_shape = x_hat.get_shape().as_list() - print("x-ha-shape:",x_hat_shape) - z_log_sigma_sq_all = tf.stack(z_log_sigma_sq_all, axis = 1) - z_mu_all = tf.stack(z_mu_all, axis = 1) - return x_hat, z_log_sigma_sq_all, z_mu_all - - - diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..033bd499ef5717391d9fab8cced03b268135b4c9 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# +# SPDX-License-Identifier: MIT +# Weather Bench models +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong" +__date__ = "2021-04-13" + +import tensorflow as tf +from model_modules.video_prediction.layers import layer_def as ld +from model_modules.video_prediction.losses import l1_loss +from .our_base_model import BaseModels + +class WeatherBenchModel(BaseModels): + + def __init__(self, hparams_dict_config: dict=None, mode:str="train", **kwargs): + """ + This is class for building weatherBench architecture by using updated hparameters + args: + mode :"train" or "val", side note: mode may not be used in the convLSTM, + this will be a useful argument for the GAN-based model + hparams_dict :The dictionary contains the hyper-parameters names and values + """ + super().__init__(hparams_dict_config, mode) + + + def parse_hparams(self, hparams): + """ + Obtain the hyper-parameters from the dict to the class variables + """ + try: + self.context_frames = self.hparams.context_frames + self.max_epochs = self.hparams.max_epochs + self.batch_size = self.hparams.batch_size + self.shuffle_on_val = self.hparams.shuffle_on_val + self.loss_fun = self.hparams.loss_fun + self.learning_rate = self.hparams.lr + self.sequence_length = self.hparams.sequence_length + except Exception as error: + print("error: {}".format(error)) + raise ValueError("Method %{}: the hyper-parameter dictionary " + "must include parameters 
above") + + def get_loss(self, x: tf.Tensor, x_hat: tf.Tensor): + # Loss + total_loss = l1_loss(x[:, 1, :, :, :], x_hat[:, :, :, :]) + return total_loss + + def optimizer(self, total_loss): + return tf.train.AdamOptimizer( + learning_rate = self.learning_rate).minimize(total_loss, + global_step = + self.global_step) + + def build_model(self, x): + """Fully convolutional network""" + x = x[:, 0, :, :, :] + _idx = 0 + filters = [64, 64, 64, 64, 2] + kernels = [5, 5, 5, 5, 5] + + for f, k in zip(filters[:-1], kernels[:-1]): + with tf.variable_scope("conv_layer_"+str(_idx), reuse=tf.AUTO_REUSE): + x = ld.conv_layer(x, kernel_size=k, stride=1, + num_features=f, + idx="conv_layer_"+str(_idx), + activate="leaky_relu") + _idx += 1 + with tf.variable_scope("conv_last_layer", reuse=tf.AUTO_REUSE): + output = ld.conv_layer(x, kernel_size=kernels[-1], + stride=1, num_features=filters[-1], + idx="conv_last_layer", activate="linear") + + return output + + + def forecast(self, x, forecast_time): + x_hat = [] + + for i in range(forecast_time): + if i == 0: + x_pred = self.build_model(x[:, i, :, :, :], filters, kernels) + else: + x_pred = self.build_model(x_pred, filters, kernels) + x_hat.append(x_pred) + + x_hat = tf.stack(x_hat) + x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4]) + return x_hat + + diff --git a/video_prediction_tools/model_modules/video_prediction/ops.py b/video_prediction_tools/model_modules/video_prediction/ops.py deleted file mode 100644 index e33f559ebfcf1e1d286136feb6b8ef95a8410564..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/ops.py +++ /dev/null @@ -1,1102 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import numpy as np -import tensorflow as tf - - -def dense(inputs, units, use_spectral_norm=False, use_bias=True): - with tf.variable_scope('dense'): - input_shape = inputs.get_shape().as_list() - kernel_shape = [input_shape[1], units] - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) - if use_spectral_norm: - kernel = spectral_normed_weight(kernel) - outputs = tf.matmul(inputs, kernel) - if use_bias: - bias = tf.get_variable('bias', [units], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def pad1d(inputs, size, strides=(1,), padding='SAME', mode='CONSTANT'): - size = list(size) if isinstance(size, (tuple, list)) else [size] - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] - input_shape = inputs.get_shape().as_list() - assert len(input_shape) == 3 - in_width = input_shape[1] - if padding in ('SAME', 'FULL'): - if in_width % strides[0] == 0: - pad_along_width = max(size[0] - strides[0], 0) - else: - pad_along_width = max(size[0] - (in_width % strides[0]), 0) - if padding == 'SAME': - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - else: - pad_left = pad_along_width - pad_right = pad_along_width - padding_pattern = [[0, 0], - [pad_left, pad_right], - [0, 0]] - outputs = tf.pad(inputs, padding_pattern, mode=mode) - elif padding == 'VALID': - outputs = inputs - else: - raise ValueError("Invalid padding scheme %s" % padding) - return outputs - - -def conv1d(inputs, filters, kernel_size, strides=(1,), padding='SAME', kernel=None, use_bias=True): - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] - strides = list(strides) if isinstance(strides, (tuple, 
list)) else [strides] - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('conv1d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - if padding == 'FULL': - inputs = pad1d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - stride, = strides - outputs = tf.nn.conv1d(inputs, kernel, stride, padding=padding) - if use_bias: - with tf.variable_scope('conv1d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def pad2d_paddings(inputs, size, strides=(1, 1), rate=(1, 1), padding='SAME'): - """ - Computes the paddings for a 4-D tensor according to the convolution padding algorithm. - - See pad2d. - - Reference: - https://www.tensorflow.org/api_guides/python/nn#convolution - https://www.tensorflow.org/api_docs/python/tf/nn/with_space_to_batch - """ - size = np.array(size) if isinstance(size, (tuple, list)) else np.array([size] * 2) - strides = np.array(strides) if isinstance(strides, (tuple, list)) else np.array([strides] * 2) - rate = np.array(rate) if isinstance(rate, (tuple, list)) else np.array([rate] * 2) - if np.any(strides > 1) and np.any(rate > 1): - raise ValueError("strides > 1 not supported in conjunction with rate > 1") - input_shape = inputs.get_shape().as_list() - assert len(input_shape) == 4 - input_size = np.array(input_shape[1:3]) - if padding in ('SAME', 'FULL'): - if np.any(rate > 1): - # We have two padding contributions. The first is used for converting "SAME" - # to "VALID". The second is required so that the height and width of the - # zero-padded value tensor are multiples of rate. - - # Spatial dimensions of the filters and the upsampled filters in which we - # introduce (rate - 1) zeros between consecutive filter values. - dilated_size = size + (size - 1) * (rate - 1) - pad = dilated_size - 1 - else: - pad = np.where(input_size % strides == 0, - np.maximum(size - strides, 0), - np.maximum(size - (input_size % strides), 0)) - if padding == 'SAME': - # When full_padding_shape is odd, we pad more at end, following the same - # convention as conv2d. - pad_start = pad // 2 - pad_end = pad - pad_start - else: - pad_start = pad - pad_end = pad - if np.any(rate > 1): - # More padding so that rate divides the height and width of the input. - # TODO: not sure if this is correct when padding == 'FULL' - orig_pad_end = pad_end - full_input_size = input_size + pad_start + orig_pad_end - pad_end_extra = (rate - full_input_size % rate) % rate - pad_end = orig_pad_end + pad_end_extra - paddings = [[0, 0], - [pad_start[0], pad_end[0]], - [pad_start[1], pad_end[1]], - [0, 0]] - elif padding == 'VALID': - paddings = [[0, 0]] * 4 - else: - raise ValueError("Invalid padding scheme %s" % padding) - return paddings - - -def pad2d(inputs, size, strides=(1, 1), rate=(1, 1), padding='SAME', mode='CONSTANT'): - """ - Pads a 4-D tensor according to the convolution padding algorithm. 
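    Worked example (editor's note, illustrative only): for rate == 1, the amount of 'SAME'
    padding per spatial dimension follows the rule implemented in pad2d_paddings above,
    with any odd leftover pixel padded at the end:

        import numpy as np

        def same_paddings_demo(in_size, kernel, stride):
            in_size, kernel, stride = map(np.asarray, (in_size, kernel, stride))
            pad = np.where(in_size % stride == 0,
                           np.maximum(kernel - stride, 0),
                           np.maximum(kernel - (in_size % stride), 0))
            pad_start = pad // 2
            return [(int(s), int(e)) for s, e in zip(pad_start, pad - pad_start)]

        print(same_paddings_demo(in_size=(4, 4), kernel=(3, 3), stride=(2, 2)))  # [(0, 1), (0, 1)]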
- - Convolution with a padding scheme - conv2d(..., padding=padding) - is equivalent to zero-padding of the input with such scheme, followed by - convolution with 'VALID' padding - padded = pad2d(..., padding=padding, mode='CONSTANT') - conv2d(padded, ..., padding='VALID') - - Args: - inputs: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - padding: A string, either 'VALID', 'SAME', or 'FULL'. The padding algorithm. - mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive). - - Returns: - A 4-D tensor. - - Reference: - https://www.tensorflow.org/api_guides/python/nn#convolution - """ - paddings = pad2d_paddings(inputs, size, strides=strides, rate=rate, padding=padding) - if paddings == [[0, 0]] * 4: - outputs = inputs - else: - outputs = tf.pad(inputs, paddings, mode=mode) - return outputs - - -def local2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', - kernel=None, flip_filters=False, - use_bias=True, channelwise=False): - """ - 2-D locally connected operation. - - Works similarly to 2-D convolution except that the weights are unshared, that is, a different set of filters is - applied at each different patch of the input. - - Args: - inputs: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - kernel: A 6-D or 7-D tensor of shape - `[in_height, in_width, kernel_size[0], kernel_size[1], in_channels, filters]` or - `[batch, in_height, in_width, kernel_size[0], kernel_size[1], in_channels, filters]`. - - Returns: - A 4-D tensor. - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if strides != [1, 1]: - raise NotImplementedError - if padding == 'FULL': - inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - input_shape = inputs.get_shape().as_list() - if padding == 'SAME': - output_shape = input_shape[:3] + [filters] - elif padding == 'VALID': - output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters] - else: - raise ValueError("Invalid padding scheme %s" % padding) - - if channelwise: - if filters not in (input_shape[-1], 1): - raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise " - "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1])) - kernel_shape = output_shape[1:3] + kernel_size + [filters] - else: - kernel_shape = output_shape[1:3] + kernel_size + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('local2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape): - raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s" - % (tuple(kernel_shape), tuple([input_shape[0]] + kernel_shape), tuple(kernel.get_shape().as_list()))) - - outputs = [] - for i in range(kernel_size[0]): - filter_h_ind = -i-1 if flip_filters else i - if padding == 'VALID': - ii = i - else: - ii = i - (kernel_size[0] // 2) - input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1])) - output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii) - assert 0 <= output_h_slice.start < output_shape[1] - assert 0 < output_h_slice.stop <= output_shape[1] - - for j in 
range(kernel_size[1]): - filter_w_ind = -j-1 if flip_filters else j - if padding == 'VALID': - jj = j - else: - jj = j - (kernel_size[1] // 2) - input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2])) - output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj) - assert 0 <= output_w_slice.start < output_shape[2] - assert 0 < output_w_slice.stop <= output_shape[2] - if channelwise: - inc = inputs[:, input_h_slice, input_w_slice, :] * \ - kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :] - else: - inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] * - kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :, :], axis=-2) - # equivalent to this - # outputs[:, output_h_slice, output_w_slice, :] += inc - paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop], - [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]] - outputs.append(tf.pad(inc, paddings)) - outputs = tf.add_n(outputs) - if use_bias: - with tf.variable_scope('local2d'): - bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def separable_local2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', - vertical_kernel=None, horizontal_kernel=None, flip_filters=False, - use_bias=True, channelwise=False): - """ - 2-D locally connected operation with separable filters. - - Note that, unlike tf.nn.separable_conv2d, this is spatial separability between dimensions 1 and 2. - - Args: - inputs: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - vertical_kernel: A 5-D or 6-D tensor of shape - `[in_height, in_width, kernel_size[0], in_channels, filters]` or - `[batch, in_height, in_width, kernel_size[0], in_channels, filters]`. - horizontal_kernel: A 5-D or 6-D tensor of shape - `[in_height, in_width, kernel_size[1], in_channels, filters]` or - `[batch, in_height, in_width, kernel_size[1], in_channels, filters]`. - - Returns: - A 4-D tensor. 
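# Illustrative sketch (shapes assumed, not taken from the diff): the "separable" in
# separable_local2d is spatial, i.e. each local 2-D filter is the outer product of a
# vertical and a horizontal 1-D factor, cutting the parameters per output position
# from kh*kw to kh + kw (per channel/filter).
import numpy as np
kh, kw = 3, 3
v = np.random.randn(kh).astype(np.float32)   # vertical factor at one output position
h = np.random.randn(kw).astype(np.float32)   # horizontal factor at the same position
full_kernel = np.outer(v, h)                 # the filter an unfactored local2d would store
assert full_kernel.shape == (kh, kw)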
- """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if strides != [1, 1]: - raise NotImplementedError - if padding == 'FULL': - inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - input_shape = inputs.get_shape().as_list() - if padding == 'SAME': - output_shape = input_shape[:3] + [filters] - elif padding == 'VALID': - output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters] - else: - raise ValueError("Invalid padding scheme %s" % padding) - - kernels = [vertical_kernel, horizontal_kernel] - for i, (kernel_type, kernel_length, kernel) in enumerate(zip(['vertical', 'horizontal'], kernel_size, kernels)): - if channelwise: - if filters not in (input_shape[-1], 1): - raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise " - "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1])) - kernel_shape = output_shape[1:3] + [kernel_length, filters] - else: - kernel_shape = output_shape[1:3] + [kernel_length, input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('separable_local2d'): - kernel = tf.get_variable('%s_kernel' % kernel_type, kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - kernels[i] = kernel - else: - if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape): - raise ValueError("Expecting %s kernel with shape %s or %s but instead got kernel with shape %s" - % (kernel_type, - tuple(kernel_shape), tuple([input_shape[0]] +kernel_shape), - tuple(kernel.get_shape().as_list()))) - - outputs = [] - for i in range(kernel_size[0]): - filter_h_ind = -i-1 if flip_filters else i - if padding == 'VALID': - ii = i - else: - ii = i - (kernel_size[0] // 2) - input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1])) - output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii) - assert 0 <= output_h_slice.start < output_shape[1] - assert 0 < output_h_slice.stop <= output_shape[1] - - for j in range(kernel_size[1]): - filter_w_ind = -j-1 if flip_filters else j - if padding == 'VALID': - jj = j - else: - jj = j - (kernel_size[1] // 2) - input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2])) - output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj) - assert 0 <= output_w_slice.start < output_shape[2] - assert 0 < output_w_slice.stop <= output_shape[2] - if channelwise: - inc = inputs[:, input_h_slice, input_w_slice, :] * \ - kernels[0][..., output_h_slice, output_w_slice, filter_h_ind, :] * \ - kernels[1][..., output_h_slice, output_w_slice, filter_w_ind, :] - else: - inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] * - kernels[0][..., output_h_slice, output_w_slice, filter_h_ind, :, :] * - kernels[1][..., output_h_slice, output_w_slice, filter_w_ind, :, :], - axis=-2) - # equivalent to this - # outputs[:, output_h_slice, output_w_slice, :] += inc - paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop], - [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]] - outputs.append(tf.pad(inc, paddings)) - outputs = tf.add_n(outputs) - if use_bias: - with tf.variable_scope('separable_local2d'): - bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, 
initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def kronecker_local2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', - kernels=None, flip_filters=False, use_bias=True, channelwise=False): - """ - 2-D locally connected operation with filters represented as a kronecker product of smaller filters - - Args: - inputs: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. - kernel: A list of 6-D or 7-D tensors of shape - `[in_height, in_width, kernel_size[i][0], kernel_size[i][1], in_channels, filters]` or - `[batch, in_height, in_width, kernel_size[i][0], kernel_size[i][1], in_channels, filters]`. - - Returns: - A 4-D tensor. - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if strides != [1, 1]: - raise NotImplementedError - if padding == 'FULL': - inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - input_shape = inputs.get_shape().as_list() - if padding == 'SAME': - output_shape = input_shape[:3] + [filters] - elif padding == 'VALID': - output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters] - else: - raise ValueError("Invalid padding scheme %s" % padding) - - if channelwise: - if filters not in (input_shape[-1], 1): - raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise " - "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1])) - kernel_shape = output_shape[1:3] + kernel_size + [filters] - factor_kernel_shape = output_shape[1:3] + [None, None, filters] - else: - kernel_shape = output_shape[1:3] + kernel_size + [input_shape[-1], filters] - factor_kernel_shape = output_shape[1:3] + [None, None, input_shape[-1], filters] - if kernels is None: - with tf.variable_scope('kronecker_local2d'): - kernels = [tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02))] - filter_h_lengths = [kernel_size[0]] - filter_w_lengths = [kernel_size[1]] - else: - for kernel in kernels: - if not ((len(kernel.shape) == len(factor_kernel_shape) and - all(((k == f) or f is None) for k, f in zip(kernel.get_shape().as_list(), factor_kernel_shape))) or - (len(kernel.shape) == (len(factor_kernel_shape) + 1) and - all(((k == f) or f is None) for k, f in zip(kernel.get_shape().as_list(), [input_shape[0]] +factor_kernel_shape)))): - raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s" - % (tuple(factor_kernel_shape), tuple([input_shape[0]] + factor_kernel_shape), - tuple(kernel.get_shape().as_list()))) - if channelwise: - filter_h_lengths, filter_w_lengths = zip(*[kernel.get_shape().as_list()[-3:-1] for kernel in kernels]) - else: - filter_h_lengths, filter_w_lengths = zip(*[kernel.get_shape().as_list()[-4:-2] for kernel in kernels]) - if [np.prod(filter_h_lengths), np.prod(filter_w_lengths)] != kernel_size: - raise ValueError("Expecting kernel size %s but instead got kernel size %s" - % (tuple(kernel_size), tuple([np.prod(filter_h_lengths), np.prod(filter_w_lengths)]))) - - def get_inds(ind, lengths): - inds = [] - for i in range(len(lengths)): - curr_ind = int(ind) - for j in range(len(lengths) - 1, i, -1): - curr_ind //= lengths[j] - curr_ind %= lengths[i] - inds.append(curr_ind) - return inds - - outputs = [] - for 
i in range(kernel_size[0]): - if padding == 'VALID': - ii = i - else: - ii = i - (kernel_size[0] // 2) - input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1])) - output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii) - assert 0 <= output_h_slice.start < output_shape[1] - assert 0 < output_h_slice.stop <= output_shape[1] - - for j in range(kernel_size[1]): - if padding == 'VALID': - jj = j - else: - jj = j - (kernel_size[1] // 2) - input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2])) - output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj) - assert 0 <= output_w_slice.start < output_shape[2] - assert 0 < output_w_slice.stop <= output_shape[2] - kernel_slice = 1.0 - for filter_h_ind, filter_w_ind, kernel in zip(get_inds(i, filter_h_lengths), get_inds(j, filter_w_lengths), kernels): - if flip_filters: - filter_h_ind = -filter_h_ind-1 - filter_w_ind = -filter_w_ind-1 - if channelwise: - kernel_slice *= kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :] - else: - kernel_slice *= kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :, :] - if channelwise: - inc = inputs[:, input_h_slice, input_w_slice, :] * kernel_slice - else: - inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] * kernel_slice, axis=-2) - # equivalent to this - # outputs[:, output_h_slice, output_w_slice, :] += inc - paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop], - [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]] - outputs.append(tf.pad(inc, paddings)) - outputs = tf.add_n(outputs) - if use_bias: - with tf.variable_scope('kronecker_local2d'): - bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def depthwise_conv2d(inputs, channel_multiplier, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True): - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - input_shape = inputs.get_shape().as_list() - kernel_shape = kernel_size + [input_shape[-1], channel_multiplier] - if kernel is None: - with tf.variable_scope('depthwise_conv2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" - % (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - if padding == 'FULL': - inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - outputs = tf.nn.depthwise_conv2d(inputs, kernel, [1] + strides + [1], padding=padding) - if use_bias: - with tf.variable_scope('depthwise_conv2d'): - bias = tf.get_variable('bias', [input_shape[-1] * channel_multiplier], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def conv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, use_spectral_norm=False): - """ - 2-D convolution. - - Args: - inputs: A 4-D tensor of shape - `[batch, in_height, in_width, in_channels]`. 
- kernel: A 4-D or 5-D tensor of shape - `[kernel_size[0], kernel_size[1], in_channels, filters]` or - `[batch, kernel_size[0], kernel_size[1], in_channels, filters]`. - bias: A 1-D or 2-D tensor of shape - `[filters]` or `[batch, filters]`. - - Returns: - A 4-D tensor. - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('conv2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - if use_spectral_norm: - kernel = spectral_normed_weight(kernel) - else: - if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape): - raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s" - % (tuple(kernel_shape), tuple([input_shape[0]] + kernel_shape), tuple(kernel.get_shape().as_list()))) - if padding == 'FULL': - inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - if kernel.get_shape().ndims == 4: - outputs = tf.nn.conv2d(inputs, kernel, [1] + strides + [1], padding=padding) - else: - def conv2d_single_fn(args): - input_, kernel_ = args - input_ = tf.expand_dims(input_, axis=0) - output = tf.nn.conv2d(input_, kernel_, [1] + strides + [1], padding=padding) - output = tf.squeeze(output, axis=0) - return output - outputs = tf.map_fn(conv2d_single_fn, [inputs, kernel], dtype=tf.float32) - if use_bias: - bias_shape = [filters] - if bias is None: - with tf.variable_scope('conv2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - else: - if bias.get_shape().as_list() not in (bias_shape, [input_shape[0]] + bias_shape): - raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" - % (tuple(bias_shape), tuple(bias.get_shape().as_list()))) - if bias.get_shape().ndims == 1: - outputs = tf.nn.bias_add(outputs, bias) - else: - outputs = tf.add(outputs, bias[:, None, None, :]) - return outputs - - -def deconv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True): - """ - 2-D transposed convolution. - - Notes on padding: - The equivalent of transposed convolution with full padding is a convolution with valid padding, and - the equivalent of transposed convolution with valid padding is a convolution with full padding. 
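# Worked example (numbers assumed) of the output-size rules deconv2d applies below:
#   FULL:  s*(i+1) - k      SAME:  s*i      VALID:  s*(i-1) + k
i, k, s = 8, 3, 2             # input size, kernel size, stride
assert s * (i + 1) - k == 15  # FULL
assert s * i == 16            # SAME
assert s * (i - 1) + k == 17  # VALID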
- - Reference: - http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [filters, input_shape[-1]] - if kernel is None: - with tf.variable_scope('deconv2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - if padding == 'FULL': - output_h, output_w = [s * (i + 1) - k for (i, k, s) in zip(input_shape[1:3], kernel_size, strides)] - elif padding == 'SAME': - output_h, output_w = [s * i for (i, s) in zip(input_shape[1:3], strides)] - elif padding == 'VALID': - output_h, output_w = [s * (i - 1) + k for (i, k, s) in zip(input_shape[1:3], kernel_size, strides)] - else: - raise ValueError("Invalid padding scheme %s" % padding) - output_shape = [input_shape[0], output_h, output_w, filters] - outputs = tf.nn.conv2d_transpose(inputs, kernel, output_shape, [1] + strides + [1], padding=padding) - if use_bias: - with tf.variable_scope('deconv2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def get_bilinear_kernel(strides): - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - strides = np.array(strides) - kernel_size = 2 * strides - strides % 2 - center = strides - (kernel_size % 2 == 1) - 0.5 * (kernel_size % 2 != 1) - vertical_kernel = 1 - abs(np.arange(kernel_size[0]) - center[0]) / strides[0] - horizontal_kernel = 1 - abs(np.arange(kernel_size[1]) - center[1]) / strides[1] - kernel = vertical_kernel[:, None] * horizontal_kernel[None, :] - return kernel - - -def upsample2d(inputs, strides, padding='SAME', upsample_mode='bilinear'): - if upsample_mode == 'bilinear': - single_bilinear_kernel = get_bilinear_kernel(strides).astype(np.float32) - input_shape = inputs.get_shape().as_list() - bilinear_kernel = tf.matrix_diag(tf.tile(tf.constant(single_bilinear_kernel)[..., None], (1, 1, input_shape[-1]))) - outputs = deconv2d(inputs, input_shape[-1], kernel_size=single_bilinear_kernel.shape, - strides=strides, kernel=bilinear_kernel, padding=padding, use_bias=False) - elif upsample_mode == 'nearest': - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - input_shape = inputs.get_shape().as_list() - inputs_tiled = tf.tile(inputs[:, :, None, :, None, :], [1, 1, strides[0], 1, strides[1], 1]) - outputs = tf.reshape(inputs_tiled, [input_shape[0], input_shape[1] * strides[0], - input_shape[2] * strides[1], input_shape[3]]) - else: - raise ValueError("Unknown upsample mode %s" % upsample_mode) - return outputs - - -def upsample2d_v2(inputs, strides, padding='SAME', upsample_mode='bilinear'): - """ - Possibly less computationally efficient but more memory efficent than upsampled2d. 
- """ - if upsample_mode == 'bilinear': - single_kernel = get_bilinear_kernel(strides).astype(np.float32) - elif upsample_mode == 'nearest': - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - single_kernel = np.ones(strides, dtype=np.float32) - else: - raise ValueError("Unknown upsample mode %s" % upsample_mode) - input_shape = inputs.get_shape().as_list() - kernel = tf.constant(single_kernel)[:, :, None, None] - inputs = tf.transpose(inputs, [3, 0, 1, 2])[..., None] - outputs = tf.map_fn(lambda input: deconv2d(input, 1, kernel_size=single_kernel.shape, - strides=strides, kernel=kernel, - padding=padding, use_bias=False), - inputs, parallel_iterations=input_shape[-1]) - outputs = tf.transpose(tf.squeeze(outputs, axis=-1), [1, 2, 3, 0]) - return outputs - - -def upsample_conv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', - kernel=None, use_bias=True, bias=None, upsample_mode='bilinear'): - """ - Upsamples the inputs by a factor using bilinear interpolation and the performs conv2d on the upsampled input. This - function is more computationally and memory efficient than a naive implementation. Unlike a naive implementation - that would upsample the input first, this implementation first convolves the bilinear kernel with the given kernel, - and then performs the convolution (actually a deconv2d) with the combined kernel. As opposed to just using deconv2d - directly, this function is less prone to checkerboard artifacts thanks to the implicit bilinear upsampling. - - Example: - >>> import numpy as np - >>> import tensorflow as tf - >>> from video_prediction.ops import upsample_conv2d, upsample2d, conv2d, pad2d_paddings - >>> inputs_shape = [4, 8, 8, 64] - >>> kernel_size = [3, 3] # for convolution - >>> filters = 32 # for convolution - >>> strides = [2, 2] # for upsampling - >>> inputs = tf.get_variable("inputs", inputs_shape) - >>> kernel = tf.get_variable("kernel", (kernel_size[0], kernel_size[1], inputs_shape[-1], filters)) - >>> bias = tf.get_variable("bias", (filters,)) - >>> outputs = upsample_conv2d(inputs, filters, kernel_size=kernel_size, strides=strides, \ - kernel=kernel, bias=bias) - >>> # upsample with bilinear interpolation - >>> inputs_up = upsample2d(inputs, strides=strides, padding='VALID') - >>> # convolve upsampled input with kernel - >>> outputs_up = conv2d(inputs_up, filters, kernel_size=kernel_size, strides=(1, 1), \ - kernel=kernel, bias=bias, padding='FULL') - >>> # crop appropriately - >>> same_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='SAME') - >>> full_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='FULL') - >>> crop_top = (strides[0] - strides[0] % 2) // 2 + full_paddings[1][1] - same_paddings[1][1] - >>> crop_left = (strides[1] - strides[1] % 2) // 2 + full_paddings[2][1] - same_paddings[2][1] - >>> outputs_up = outputs_up[:, crop_top:crop_top + strides[0] * inputs_shape[1], \ - crop_left:crop_left + strides[1] * inputs_shape[2], :] - >>> sess = tf.Session() - >>> sess.run(tf.global_variables_initializer()) - >>> assert np.allclose(*sess.run([outputs, outputs_up]), atol=1e-5) - - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if padding != 'SAME' or upsample_mode != 'bilinear': - raise NotImplementedError - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - if kernel 
is None: - with tf.variable_scope('upsample_conv2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % - (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - - # convolve bilinear kernel with kernel - single_bilinear_kernel = get_bilinear_kernel(strides).astype(np.float32) - kernel_transposed = tf.transpose(kernel, (0, 1, 3, 2)) - kernel_reshaped = tf.reshape(kernel_transposed, kernel_size + [1, input_shape[-1] * filters]) - kernel_up_reshaped = conv2d(tf.constant(single_bilinear_kernel)[None, :, :, None], input_shape[-1] * filters, - kernel_size=kernel_size, kernel=kernel_reshaped, padding='FULL', use_bias=False) - kernel_up = tf.reshape(kernel_up_reshaped, - kernel_up_reshaped.get_shape().as_list()[1:3] + [filters, input_shape[-1]]) - - # deconvolve with the bilinearly convolved kernel - outputs = deconv2d(inputs, filters, kernel_size=kernel_up.get_shape().as_list()[:2], strides=strides, - kernel=kernel_up, padding='SAME', use_bias=False) - if use_bias: - if bias is None: - with tf.variable_scope('upsample_conv2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - else: - bias_shape = [filters] - if bias_shape != bias.get_shape().as_list(): - raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % - (tuple(bias_shape), tuple(bias.get_shape().as_list()))) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def upsample_conv2d_v2(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', - kernel=None, use_bias=True, bias=None, upsample_mode='bilinear'): - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if padding != 'SAME': - raise NotImplementedError - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('upsample_conv2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % - (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - - inputs_up = upsample2d_v2(inputs, strides=strides, padding='VALID', upsample_mode=upsample_mode) - # convolve upsampled input with kernel - outputs = conv2d(inputs_up, filters, kernel_size=kernel_size, strides=(1, 1), - kernel=kernel, bias=None, padding='FULL', use_bias=False) - # crop appropriately - same_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='SAME') - full_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='FULL') - crop_top = (strides[0] - strides[0] % 2) // 2 + full_paddings[1][1] - same_paddings[1][1] - crop_left = (strides[1] - strides[1] % 2) // 2 + full_paddings[2][1] - same_paddings[2][1] - outputs = outputs[:, crop_top:crop_top + strides[0] * input_shape[1], - crop_left:crop_left + strides[1] * input_shape[2], :] - - if use_bias: - if bias is None: - with tf.variable_scope('upsample_conv2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - else: - bias_shape = 
[filters] - if bias_shape != bias.get_shape().as_list(): - raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % - (tuple(bias_shape), tuple(bias.get_shape().as_list()))) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def conv3d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', use_bias=True, use_spectral_norm=False): - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 3 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 3 - input_shape = inputs.get_shape().as_list() - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - with tf.variable_scope('conv3d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) - if use_spectral_norm: - kernel = spectral_normed_weight(kernel) - outputs = tf.nn.conv3d(inputs, kernel, [1] + strides + [1], padding=padding) - if use_bias: - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def pool2d(inputs, pool_size, strides=(1, 1), padding='SAME', pool_mode='avg'): - pool_size = list(pool_size) if isinstance(pool_size, (tuple, list)) else [pool_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if padding == 'FULL': - inputs = pad2d(inputs, pool_size, strides=strides, padding=padding, mode='CONSTANT') - padding = 'VALID' - if pool_mode == 'max': - outputs = tf.nn.max_pool(inputs, [1] + pool_size + [1], [1] + strides + [1], padding=padding) - elif pool_mode == 'avg': - outputs = tf.nn.avg_pool(inputs, [1] + pool_size + [1], [1] + strides + [1], padding=padding) - else: - raise ValueError('Invalid pooling mode:', pool_mode) - return outputs - - -def conv_pool2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, pool_mode='avg'): - """ - Similar optimization as in upsample_conv2d - - Example: - >>> import numpy as np - >>> import tensorflow as tf - >>> from video_prediction.ops import conv_pool2d, conv2d, pool2d - >>> inputs_shape = [4, 16, 16, 32] - >>> kernel_size = [3, 3] # for convolution - >>> filters = 64 # for convolution - >>> strides = [2, 2] # for pooling - >>> inputs = tf.get_variable("inputs", inputs_shape) - >>> kernel = tf.get_variable("kernel", (kernel_size[0], kernel_size[1], inputs_shape[-1], filters)) - >>> bias = tf.get_variable("bias", (filters,)) - >>> outputs = conv_pool2d(inputs, filters, kernel_size=kernel_size, strides=strides, - kernel=kernel, bias=bias, pool_mode='avg') - >>> inputs_conv = conv2d(inputs, filters, kernel_size=kernel_size, strides=(1, 1), - kernel=kernel, bias=bias) - >>> outputs_pool = pool2d(inputs_conv, pool_size=strides, strides=strides, pool_mode='avg') - >>> sess = tf.Session() - >>> sess.run(tf.global_variables_initializer()) - >>> assert np.allclose(*sess.run([outputs, outputs_pool]), atol=1e-5) - - """ - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if padding != 'SAME' or pool_mode != 'avg': - raise NotImplementedError - input_shape = inputs.get_shape().as_list() - if input_shape[1] % strides[0] or input_shape[2] % strides[1]: - raise NotImplementedError("The height and width of the input should be " - "an integer multiple of the respective stride.") - kernel_shape 
= list(kernel_size) + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('conv_pool2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % - (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - - # pool kernel - kernel_reshaped = tf.reshape(kernel, [1] + kernel_size + [input_shape[-1] * filters]) - kernel_pool_reshaped = pool2d(kernel_reshaped, pool_size=strides, padding='FULL', pool_mode='avg') - kernel_pool = tf.reshape(kernel_pool_reshaped, - kernel_pool_reshaped.get_shape().as_list()[1:3] + [input_shape[-1], filters]) - - outputs = conv2d(inputs, filters, kernel_size=kernel_pool.get_shape().as_list()[:2], strides=strides, - kernel=kernel_pool, padding='SAME', use_bias=False) - if use_bias: - if bias is None: - with tf.variable_scope('conv_pool2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - else: - bias_shape = [filters] - if bias_shape != bias.get_shape().as_list(): - raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % - (tuple(bias_shape), tuple(bias.get_shape().as_list()))) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def conv_pool2d_v2(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, pool_mode='avg'): - kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 - if padding != 'SAME' or pool_mode != 'avg': - raise NotImplementedError - input_shape = inputs.get_shape().as_list() - if input_shape[1] % strides[0] or input_shape[2] % strides[1]: - raise NotImplementedError("The height and width of the input should be " - "an integer multiple of the respective stride.") - kernel_shape = list(kernel_size) + [input_shape[-1], filters] - if kernel is None: - with tf.variable_scope('conv_pool2d'): - kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, - initializer=tf.truncated_normal_initializer(stddev=0.02)) - else: - if kernel_shape != kernel.get_shape().as_list(): - raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % - (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) - - inputs_conv = conv2d(inputs, filters, kernel_size=kernel_size, strides=(1, 1), - kernel=kernel, bias=None, use_bias=False) - outputs = pool2d(inputs_conv, pool_size=strides, strides=strides, pool_mode='avg') - - if use_bias: - if bias is None: - with tf.variable_scope('conv_pool2d'): - bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) - else: - bias_shape = [filters] - if bias_shape != bias.get_shape().as_list(): - raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % - (tuple(bias_shape), tuple(bias.get_shape().as_list()))) - outputs = tf.nn.bias_add(outputs, bias) - return outputs - - -def lrelu(x, alpha): - """ - Leaky ReLU activation function - - Reference: - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_ops.py - """ - with tf.name_scope("lrelu"): - return tf.maximum(alpha * x, x) - - -def batchnorm(input): - with tf.variable_scope("batchnorm"): - # this block looks like it has 3 inputs on the graph unless we do this 
- input = tf.identity(input) - - channels = input.get_shape()[-1] - offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer()) - scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.truncated_normal_initializer(1.0, 0.02)) - mean, variance = tf.nn.moments(input, axes=list(range(len(input.get_shape()) - 1)), keepdims=False) - variance_epsilon = 1e-5 - normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon) - return normalized - - -def instancenorm(input): - with tf.variable_scope("instancenorm"): - # this block looks like it has 3 inputs on the graph unless we do this - input = tf.identity(input) - - channels = input.get_shape()[-1] - offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer()) - scale = tf.get_variable("scale", [channels], dtype=tf.float32, - initializer=tf.truncated_normal_initializer(1.0, 0.02)) - mean, variance = tf.nn.moments(input, axes=list(range(1, len(input.get_shape()) - 1)), keepdims=True) - variance_epsilon = 1e-5 - normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, - variance_epsilon=variance_epsilon) - return normalized - - -def flatten(input, axis=1, end_axis=-1): - """ - Caffe-style flatten. - - Args: - inputs: An N-D tensor. - axis: The first axis to flatten: all preceding axes are retained in the output. - May be negative to index from the end (e.g., -1 for the last axis). - end_axis: The last axis to flatten: all following axes are retained in the output. - May be negative to index from the end (e.g., the default -1 for the last - axis) - - Returns: - A M-D tensor where M = N - (end_axis - axis) - """ - input_shape = tf.shape(input) - input_rank = tf.shape(input_shape)[0] - if axis < 0: - axis = input_rank + axis - if end_axis < 0: - end_axis = input_rank + end_axis - output_shape = [] - if axis != 0: - output_shape.append(input_shape[:axis]) - output_shape.append([tf.reduce_prod(input_shape[axis:end_axis + 1])]) - if end_axis + 1 != input_rank: - output_shape.append(input_shape[end_axis + 1:]) - output_shape = tf.concat(output_shape, axis=0) - output = tf.reshape(input, output_shape) - return output - - -def tile_concat(values, axis): - """ - Like concat except that first tiles the broadcastable dimensions if necessary - """ - shapes = [value.get_shape() for value in values] - # convert axis to positive form - ndims = shapes[0].ndims - for shape in shapes[1:]: - assert ndims == shape.ndims - if -ndims < axis < 0: - axis += ndims - # remove axis dimension - shapes = [shape.as_list() for shape in shapes] - dims = [shape.pop(axis) for shape in shapes] - shapes = [tf.TensorShape(shape) for shape in shapes] - # compute broadcasted shape - b_shape = shapes[0] - for shape in shapes[1:]: - b_shape = tf.broadcast_static_shape(b_shape, shape) - # add back axis dimension - b_shapes = [b_shape.as_list() for _ in dims] - for b_shape, dim in zip(b_shapes, dims): - b_shape.insert(axis, dim) - # tile values to match broadcasted shape, if necessary - b_values = [] - for value, b_shape in zip(values, b_shapes): - multiples = [] - for dim, b_dim in zip(value.get_shape().as_list(), b_shape): - if dim == b_dim: - multiples.append(1) - else: - assert dim == 1 - multiples.append(b_dim) - if any(multiple != 1 for multiple in multiples): - b_value = tf.tile(value, multiples) - else: - b_value = value - b_values.append(b_value) - return tf.concat(b_values, axis=axis) - - -def 
sigmoid_kl_with_logits(logits, targets): - # broadcasts the same target value across the whole batch - # this is implemented so awkwardly because tensorflow lacks an x log x op - assert isinstance(targets, float) - if targets in [0., 1.]: - entropy = 0. - else: - entropy = - targets * np.log(targets) - (1. - targets) * np.log(1. - targets) - return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits) * targets) - entropy - - -def spectral_normed_weight(W, u=None, num_iters=1): - SPECTRAL_NORMALIZATION_VARIABLES = 'spectral_normalization_variables' - - # Usually num_iters = 1 will be enough - W_shape = W.shape.as_list() - W_reshaped = tf.reshape(W, [-1, W_shape[-1]]) - if u is None: - u = tf.get_variable("u", [1, W_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False) - - def l2normalize(v, eps=1e-12): - return v / (tf.norm(v) + eps) - - def power_iteration(i, u_i, v_i): - v_ip1 = l2normalize(tf.matmul(u_i, tf.transpose(W_reshaped))) - u_ip1 = l2normalize(tf.matmul(v_ip1, W_reshaped)) - return i + 1, u_ip1, v_ip1 - _, u_final, v_final = tf.while_loop( - cond=lambda i, _1, _2: i < num_iters, - body=power_iteration, - loop_vars=(tf.constant(0, dtype=tf.int32), - u, tf.zeros(dtype=tf.float32, shape=[1, W_reshaped.shape.as_list()[0]])) - ) - sigma = tf.squeeze(tf.matmul(tf.matmul(v_final, W_reshaped), tf.transpose(u_final))) - W_bar_reshaped = W_reshaped / sigma - W_bar = tf.reshape(W_bar_reshaped, W_shape) - - if u not in tf.get_collection(SPECTRAL_NORMALIZATION_VARIABLES): - tf.add_to_collection(SPECTRAL_NORMALIZATION_VARIABLES, u) - tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u.assign(u_final)) - return W_bar - - -def get_activation_layer(layer_type): - if layer_type == 'relu': - layer = tf.nn.relu - elif layer_type == 'elu': - layer = tf.nn.elu - else: - raise ValueError('Invalid activation layer %s' % layer_type) - return layer - - -def get_norm_layer(layer_type): - if layer_type == 'batch': - layer = tf.layers.batch_normalization - elif layer_type == 'layer': - layer = tf.contrib.layers.layer_norm - elif layer_type == 'instance': - from model_modules.video_prediction.layers import fused_instance_norm - layer = fused_instance_norm - elif layer_type == 'none': - layer = tf.identity - else: - raise ValueError('Invalid normalization layer %s' % layer_type) - return layer - - -def get_upsample_layer(layer_type): - if layer_type == 'deconv2d': - layer = deconv2d - elif layer_type == 'upsample_conv2d': - layer = upsample_conv2d - elif layer_type == 'upsample_conv2d_v2': - layer = upsample_conv2d_v2 - else: - raise ValueError('Invalid upsampling layer %s' % layer_type) - return layer - - -def get_downsample_layer(layer_type): - if layer_type == 'conv2d': - layer = conv2d - elif layer_type == 'conv_pool2d': - layer = conv_pool2d - elif layer_type == 'conv_pool2d_v2': - layer = conv_pool2d_v2 - else: - raise ValueError('Invalid downsampling layer %s' % layer_type) - return layer diff --git a/video_prediction_tools/model_modules/video_prediction/rnn_ops.py b/video_prediction_tools/model_modules/video_prediction/rnn_ops.py deleted file mode 100644 index 970fae3ec656dd3e677e906602f0d29d6650b2c7..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/rnn_ops.py +++ /dev/null @@ -1,267 +0,0 @@ -# Copyright 2016 The TensorFlow Authors All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
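# NumPy sketch (synthetic weights, assumed shapes) of the power iteration performed by
# spectral_normed_weight in ops.py above: estimate the largest singular value of the
# kernel reshaped to [-1, out_channels] and divide the weights by it.
import numpy as np
rng = np.random.default_rng(0)
W = rng.standard_normal((3 * 3 * 64, 128)).astype(np.float32)  # reshaped conv kernel
u = rng.standard_normal((1, 128)).astype(np.float32)
for _ in range(1):                                  # one iteration is usually enough
    v = u @ W.T
    v /= np.linalg.norm(v) + 1e-12
    u = v @ W
    u /= np.linalg.norm(u) + 1e-12
sigma = (v @ W @ u.T).item()                        # approximate top singular value
W_bar = W / sigma                                   # spectrally normalised weights
# Compare with the exact value: np.linalg.svd(W, compute_uv=False)[0]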
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Convolutional LSTM implementation.""" - -import tensorflow as tf -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import rnn_cell_impl -from tensorflow.python.ops import variable_scope as vs - - -class BasicConv2DLSTMCell(rnn_cell_impl.RNNCell): - """2D Convolutional LSTM cell with (optional) normalization and recurrent dropout. - - The implementation is based on: tf.contrib.rnn.LayerNormBasicLSTMCell. - - It does not allow cell clipping, a projection layer, and does not - use peep-hole connections: it is the basic baseline. - """ - def __init__(self, input_shape, filters, kernel_size, - forget_bias=1.0, activation_fn=math_ops.tanh, - normalizer_fn=None, separate_norms=True, - norm_gain=1.0, norm_shift=0.0, - dropout_keep_prob=1.0, dropout_prob_seed=None, - skip_connection=False, reuse=None): - """Initializes the basic convolutional LSTM cell. - - Args: - input_shape: int tuple, Shape of the input, excluding the batch size. - filters: int, The number of filters of the conv LSTM cell. - kernel_size: int tuple, The kernel size of the conv LSTM cell. - forget_bias: float, The bias added to forget gates (see above). - activation_fn: Activation function of the inner states. - normalizer_fn: If specified, this normalization will be applied before the - internal nonlinearities. - separate_norms: If set to `False`, the normalizer_fn is applied to the - concatenated tensor that follows the convolution, i.e. before splitting - the tensor. This case is slightly faster but it might be functionally - different, depending on the normalizer_fn (it's functionally the same - for instance norm but not for layer norm). Default: `True`. - norm_gain: float, The layer normalization gain initial value. If - `normalizer_fn` is `None`, this argument will be ignored. - norm_shift: float, The layer normalization shift initial value. If - `normalizer_fn` is `None`, this argument will be ignored. - dropout_keep_prob: unit Tensor or float between 0 and 1 representing the - recurrent dropout probability value. If float and 1.0, no dropout will - be applied. - dropout_prob_seed: (optional) integer, the randomness seed. - skip_connection: If set to `True`, concatenate the input to the - output of the conv LSTM. Default: `False`. - reuse: (optional) Python boolean describing whether to reuse variables - in an existing scope. If not `True`, and the existing scope already has - the given variables, an error is raised. 
- """ - super(BasicConv2DLSTMCell, self).__init__(_reuse=reuse) - - self._input_shape = input_shape - self._filters = filters - self._kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - self._forget_bias = forget_bias - self._activation_fn = activation_fn - self._normalizer_fn = normalizer_fn - self._separate_norms = separate_norms - self._g = norm_gain - self._b = norm_shift - self._keep_prob = dropout_keep_prob - self._seed = dropout_prob_seed - self._skip_connection = skip_connection - self._reuse = reuse - - if self._skip_connection: - output_channels = self._filters + self._input_shape[-1] - else: - output_channels = self._filters - cell_size = tensor_shape.TensorShape(self._input_shape[:-1] + [self._filters]) - self._output_size = tensor_shape.TensorShape(self._input_shape[:-1] + [output_channels]) - self._state_size = rnn_cell_impl.LSTMStateTuple(cell_size, self._output_size) - - @property - def output_size(self): - return self._output_size - - @property - def state_size(self): - return self._state_size - - def _norm(self, inputs, scope): - shape = inputs.get_shape()[-1:] - gamma_init = init_ops.constant_initializer(self._g) - beta_init = init_ops.constant_initializer(self._b) - with vs.variable_scope(scope): - # Initialize beta and gamma for use by normalizer. - vs.get_variable("gamma", shape=shape, initializer=gamma_init) - vs.get_variable("beta", shape=shape, initializer=beta_init) - normalized = self._normalizer_fn(inputs, reuse=True, scope=scope) - return normalized - - def _conv2d(self, inputs): - output_filters = 4 * self._filters - input_shape = inputs.get_shape().as_list() - kernel_shape = list(self._kernel_size) + [input_shape[-1], output_filters] - kernel = vs.get_variable("kernel", kernel_shape, dtype=dtypes.float32, - initializer=init_ops.truncated_normal_initializer(stddev=0.02)) - outputs = nn_ops.conv2d(inputs, kernel, [1] * 4, padding='SAME') - if not self._normalizer_fn: - bias = vs.get_variable('bias', [output_filters], dtype=dtypes.float32, - initializer=init_ops.zeros_initializer()) - outputs = nn_ops.bias_add(outputs, bias) - return outputs - - def _dense(self, inputs): - num_units = 4 * self._filters - input_shape = inputs.shape.as_list() - kernel_shape = [input_shape[-1], num_units] - kernel = vs.get_variable("weights", kernel_shape, dtype=dtypes.float32, - initializer=init_ops.truncated_normal_initializer(stddev=0.02)) - outputs = tf.matmul(inputs, kernel) - return outputs - - def call(self, inputs, state): - """2D Convolutional LSTM cell with (optional) normalization and recurrent dropout.""" - c, h = state - tile_concat = isinstance(inputs, (list, tuple)) - if tile_concat: - inputs, inputs_non_spatial = inputs - args = array_ops.concat([inputs, h], -1) - concat = self._conv2d(args) - if tile_concat: - concat = concat + self._dense(inputs_non_spatial)[:, None, None, :] - - if self._normalizer_fn and not self._separate_norms: - concat = self._norm(concat, "input_transform_forget_output") - i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=-1) - if self._normalizer_fn and self._separate_norms: - i = self._norm(i, "input") - j = self._norm(j, "transform") - f = self._norm(f, "forget") - o = self._norm(o, "output") - - g = self._activation_fn(j) - if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: - g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) - - new_c = (c * math_ops.sigmoid(f + self._forget_bias) - + math_ops.sigmoid(i) * g) - if self._normalizer_fn: - new_c = 
self._norm(new_c, "state") - new_h = self._activation_fn(new_c) * math_ops.sigmoid(o) - - if self._skip_connection: - new_h = array_ops.concat([new_h, inputs], axis=-1) - - new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) - return new_h, new_state - - -class Conv2DGRUCell(tf.nn.rnn_cell.RNNCell): - """2D Convolutional GRU cell with (optional) normalization. - - Modified from these: - https://github.com/carlthome/tensorflow-convlstm-cell/blob/master/cell.py - https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn_cell_impl.py - """ - def __init__(self, input_shape, filters, kernel_size, - activation_fn=tf.tanh, - normalizer_fn=None, separate_norms=True, - bias_initializer=None, reuse=None): - super(Conv2DGRUCell, self).__init__(_reuse=reuse) - self._input_shape = input_shape - self._filters = filters - self._kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 - self._activation_fn = activation_fn - self._normalizer_fn = normalizer_fn - self._separate_norms = separate_norms - self._bias_initializer = bias_initializer - self._size = tensor_shape.TensorShape(self._input_shape[:-1] + [self._filters]) - - @property - def state_size(self): - return self._size - - @property - def output_size(self): - return self._size - - def _norm(self, inputs, scope, bias_initializer): - shape = inputs.get_shape()[-1:] - gamma_init = init_ops.ones_initializer() - beta_init = bias_initializer - with vs.variable_scope(scope): - # Initialize beta and gamma for use by normalizer. - vs.get_variable("gamma", shape=shape, initializer=gamma_init) - vs.get_variable("beta", shape=shape, initializer=beta_init) - normalized = self._normalizer_fn(inputs, reuse=True, scope=scope) - return normalized - - def _conv2d(self, inputs, output_filters, bias_initializer): - input_shape = inputs.get_shape().as_list() - kernel_shape = list(self._kernel_size) + [input_shape[-1], output_filters] - kernel = vs.get_variable("kernel", kernel_shape, dtype=dtypes.float32, - initializer=init_ops.truncated_normal_initializer(stddev=0.02)) - outputs = nn_ops.conv2d(inputs, kernel, [1] * 4, padding='SAME') - if not self._normalizer_fn: - bias = vs.get_variable('bias', [output_filters], dtype=dtypes.float32, - initializer=bias_initializer) - outputs = nn_ops.bias_add(outputs, bias) - return outputs - - def _dense(self, inputs, num_units): - input_shape = inputs.shape.as_list() - kernel_shape = [input_shape[-1], num_units] - kernel = vs.get_variable("weights", kernel_shape, dtype=dtypes.float32, - initializer=init_ops.truncated_normal_initializer(stddev=0.02)) - outputs = tf.matmul(inputs, kernel) - return outputs - - def call(self, inputs, state): - bias_ones = self._bias_initializer - if self._bias_initializer is None: - bias_ones = init_ops.ones_initializer() - tile_concat = isinstance(inputs, (list, tuple)) - if tile_concat: - inputs, inputs_non_spatial = inputs - with vs.variable_scope('gates'): - inputs = array_ops.concat([inputs, state], axis=-1) - concat = self._conv2d(inputs, 2 * self._filters, bias_ones) - if tile_concat: - concat = concat + self._dense(inputs_non_spatial, concat.shape[-1].value)[:, None, None, :] - if self._normalizer_fn and not self._separate_norms: - concat = self._norm(concat, "reset_update", bias_ones) - r, u = array_ops.split(concat, 2, axis=-1) - if self._normalizer_fn and self._separate_norms: - r = self._norm(r, "reset", bias_ones) - u = self._norm(u, "update", bias_ones) - r, u = math_ops.sigmoid(r), math_ops.sigmoid(u) - - bias_zeros = 
self._bias_initializer - if self._bias_initializer is None: - bias_zeros = init_ops.zeros_initializer() - with vs.variable_scope('candidate'): - inputs = array_ops.concat([inputs, r * state], axis=-1) - candidate = self._conv2d(inputs, self._filters, bias_zeros) - if tile_concat: - candidate = candidate + self._dense(inputs_non_spatial, candidate.shape[-1].value)[:, None, None, :] - if self._normalizer_fn: - candidate = self._norm(candidate, "state", bias_zeros) - - c = self._activation_fn(candidate) - new_h = u * state + (1 - u) * c - return new_h, new_h diff --git a/video_prediction_tools/model_modules/video_prediction/utils/__init__.py b/video_prediction_tools/model_modules/video_prediction/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/video_prediction_tools/model_modules/video_prediction/utils/ffmpeg_gif.py b/video_prediction_tools/model_modules/video_prediction/utils/ffmpeg_gif.py deleted file mode 100644 index 8aaeea0f09506d56b0c556168b6d9864498f2b69..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/utils/ffmpeg_gif.py +++ /dev/null @@ -1,99 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import os - -import numpy as np - - -def save_gif(gif_fname, images, fps): - """ - To generate a gif from image files, first generate palette from images - and then generate the gif from the images and the palette. - ffmpeg -i input_%02d.jpg -vf palettegen -y palette.png - ffmpeg -i input_%02d.jpg -i palette.png -lavfi paletteuse -y output.gif - - Alternatively, use a filter to map the input images to both the palette - and gif commands, while also passing the palette to the gif command. - ffmpeg -i input_%02d.jpg -filter_complex "[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse" -y output.gif - - To directly pass in numpy images, use rawvideo format and `-i -` option. - """ - from subprocess import Popen, PIPE - head, tail = os.path.split(gif_fname) - if head and not os.path.exists(head): - os.makedirs(head) - h, w, c = images[0].shape - cmd = ['ffmpeg', '-y', - '-f', 'rawvideo', - '-vcodec', 'rawvideo', - '-r', '%.02f' % fps, - '-s', '%dx%d' % (w, h), - '-pix_fmt', {1: 'gray', 3: 'rgb24', 4: 'rgba'}[c], - '-i', '-', - '-filter_complex', '[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse', - '-r', '%.02f' % fps, - '%s' % gif_fname] - proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) - for image in images: - proc.stdin.write(image.tostring()) - out, err = proc.communicate() - if proc.returncode: - err = '\n'.join([' '.join(cmd), err.decode('utf8')]) - raise IOError(err) - del proc - - -def encode_gif(images, fps): - """Encodes numpy images into gif string. - Args: - images: A 5-D `uint8` `np.array` (or a list of 4-D images) of shape - `[batch_size, time, height, width, channels]` where `channels` is 1 or 3. - fps: frames per second of the animation - Returns: - The encoded gif string. - Raises: - IOError: If the ffmpeg command returns an error. 
- """ - from subprocess import Popen, PIPE - h, w, c = images[0].shape - cmd = ['ffmpeg', '-y', - '-f', 'rawvideo', - '-vcodec', 'rawvideo', - '-r', '%.02f' % fps, - '-s', '%dx%d' % (w, h), - '-pix_fmt', {1: 'gray', 3: 'rgb24'}[c], - '-i', '-', - '-filter_complex', '[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse', - '-r', '%.02f' % fps, - '-f', 'gif', - '-'] - proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) - for image in images: - proc.stdin.write(image.tostring()) - out, err = proc.communicate() - if proc.returncode: - err = '\n'.join([' '.join(cmd), err.decode('utf8')]) - raise IOError(err) - del proc - return out - - -def main(): - images_shape = (12, 64, 64, 3) # num_frames, height, width, channels - images = np.random.randint(256, size=images_shape).astype(np.uint8) - - save_gif('output_save.gif', images, 4) - with open('output_save.gif', 'rb') as f: - string_save = f.read() - - string_encode = encode_gif(images, 4) - with open('output_encode.gif', 'wb') as f: - f.write(string_encode) - - print(np.all(string_save == string_encode)) - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/model_modules/video_prediction/utils/gif_summary.py b/video_prediction_tools/model_modules/video_prediction/utils/gif_summary.py deleted file mode 100644 index 7f9ce616951a66dee17c1081a4961b4af1d6a57f..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/utils/gif_summary.py +++ /dev/null @@ -1,114 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Tensor2Tensor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -from tensorflow.python.ops import summary_op_util -#from tensorflow.python.distribute.summary_op_util import skip_summary TODO: IMPORT ERRORS IN juwels -from model_modules.video_prediction.utils import ffmpeg_gif - - -def py_gif_summary(tag, images, max_outputs, fps): - """Outputs a `Summary` protocol buffer with gif animations. - Args: - tag: Name of the summary. - images: A 5-D `uint8` `np.array` of shape `[batch_size, time, height, width, - channels]` where `channels` is 1 or 3. - max_outputs: Max number of batch elements to generate gifs for. - fps: frames per second of the animation - Returns: - The serialized `Summary` protocol buffer. - Raises: - ValueError: If `images` is not a 5-D `uint8` array with 1 or 3 channels. 
- """ - is_bytes = isinstance(tag, bytes) - if is_bytes: - tag = tag.decode("utf-8") - images = np.asarray(images) - if images.dtype != np.uint8: - raise ValueError("Tensor must have dtype uint8 for gif summary.") - if images.ndim != 5: - raise ValueError("Tensor must be 5-D for gif summary.") - batch_size, _, height, width, channels = images.shape - if channels not in (1, 3): - raise ValueError("Tensors must have 1 or 3 channels for gif summary.") - - summ = tf.Summary() - num_outputs = min(batch_size, max_outputs) - for i in range(num_outputs): - image_summ = tf.Summary.Image() - image_summ.height = height - image_summ.width = width - image_summ.colorspace = channels # 1: grayscale, 3: RGB - try: - image_summ.encoded_image_string = ffmpeg_gif.encode_gif(images[i], fps) - except (IOError, OSError) as e: - tf.logging.warning( - "Unable to encode images to a gif string because either ffmpeg is " - "not installed or ffmpeg returned an error: %s. Falling back to an " - "image summary of the first frame in the sequence.", e) - try: - from PIL import Image # pylint: disable=g-import-not-at-top - import io # pylint: disable=g-import-not-at-top - with io.BytesIO() as output: - Image.fromarray(images[i][0]).save(output, "PNG") - image_summ.encoded_image_string = output.getvalue() - except: - tf.logging.warning( - "Gif summaries requires ffmpeg or PIL to be installed: %s", e) - image_summ.encoded_image_string = "".encode('utf-8') if is_bytes else "" - if num_outputs == 1: - summ_tag = "{}/gif".format(tag) - else: - summ_tag = "{}/gif/{}".format(tag, i) - summ.value.add(tag=summ_tag, image=image_summ) - summ_str = summ.SerializeToString() - return summ_str - - -def gif_summary(name, tensor, max_outputs=3, fps=10, collections=None, - family=None): - """Outputs a `Summary` protocol buffer with gif animations. - Args: - name: Name of the summary. - tensor: A 5-D `uint8` `Tensor` of shape `[batch_size, time, height, width, - channels]` where `channels` is 1 or 3. - max_outputs: Max number of batch elements to generate gifs for. - fps: frames per second of the animation - collections: Optional list of tf.GraphKeys. The collections to add the - summary to. Defaults to [tf.GraphKeys.SUMMARIES] - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - Returns: - A scalar `Tensor` of type `string`. The serialized `Summary` protocol - buffer. 
- """ - tensor = tf.convert_to_tensor(tensor) - # if skip_summary(): TODO: skipo summary errors happend in JUEWLS - # return tf.constant("") - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - val = tf.py_func( - py_gif_summary, - [tag, tensor, max_outputs, fps], - tf.string, - stateful=False, - name=scope) - summary_op_util.collect(val, collections, [tf.GraphKeys.SUMMARIES]) - return val diff --git a/video_prediction_tools/model_modules/video_prediction/utils/html.py b/video_prediction_tools/model_modules/video_prediction/utils/html.py deleted file mode 100755 index ee60f60021588c3d3140caf4c0ba4cb339a616bb..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/utils/html.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import os - -import dominate -from dominate.tags import * - - -class HTML: - def __init__(self, web_dir, title, reflesh=0): - self.title = title - self.web_dir = web_dir - self.img_dir = os.path.join(self.web_dir, 'images') - if not os.path.exists(self.web_dir): - os.makedirs(self.web_dir) - if not os.path.exists(self.img_dir): - os.makedirs(self.img_dir) - # print(self.img_dir) - - self.doc = dominate.document(title=title) - if reflesh > 0: - with self.doc.head: - meta(http_equiv="reflesh", content=str(reflesh)) - self.t = None - - def get_image_dir(self): - return self.img_dir - - def add_header1(self, str): - with self.doc: - h1(str) - - def add_header2(self, str): - with self.doc: - h2(str) - - def add_header3(self, str): - with self.doc: - h3(str) - - def add_table(self, border=1): - self.t = table(border=border, style="table-layout: fixed;") - self.doc.add(self.t) - - def add_row(self, txts, colspans=None): - if self.t is None: - self.add_table() - with self.t: - with tr(): - if colspans: - assert len(txts) == len(colspans) - colspans = [dict(colspan=str(colspan)) for colspan in colspans] - else: - colspans = [dict()] * len(txts) - for txt, colspan in zip(txts, colspans): - style = "word-break: break-all;" if len(str(txt)) > 80 else "word-wrap: break-word;" - with td(style=style, halign="center", valign="top", **colspan): - with p(): - if txt is not None: - p(txt) - - def add_images(self, ims, txts, links, colspans=None, height=None, width=400): - image_style = '' - if height is not None: - image_style += "height:%dpx;" % height - if width is not None: - image_style += "width:%dpx;" % width - if self.t is None: - self.add_table() - with self.t: - with tr(): - if colspans: - assert len(txts) == len(colspans) - colspans = [dict(colspan=str(colspan)) for colspan in colspans] - else: - colspans = [dict()] * len(txts) - for im, txt, link, colspan in zip(ims, txts, links, colspans): - with td(style="word-wrap: break-word;", halign="center", valign="top", **colspan): - with p(): - if im is not None and link is not None: - with a(href=os.path.join('images', link)): - img(style=image_style, src=os.path.join('images', im)) - if im is not None and link is not None and txt is not None: - br() - if txt is not None: - p(txt) - - def save(self): - html_file = '%s/index.html' % self.web_dir - f = open(html_file, 'wt') - f.write(self.doc.render()) - f.close() - - -if __name__ == '__main__': - html = HTML('web/', 'test_html') - html.add_header('hello world') - - ims = [] - txts = [] - links = [] - for n in range(4): - ims.append('image_%d.jpg' % n) - txts.append('text_%d' % n) - links.append('image_%d.jpg' % n) - html.add_images(ims, 
txts, links) - html.save() diff --git a/video_prediction_tools/model_modules/video_prediction/utils/mcnet_utils.py b/video_prediction_tools/model_modules/video_prediction/utils/mcnet_utils.py deleted file mode 100644 index e4f2129eadaf5973130c1aa2f1262ab9cf96a036..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/utils/mcnet_utils.py +++ /dev/null @@ -1,157 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# SPDX-FileCopyrightText: 2015 Alec Radford -# -# SPDX-License-Identifier: MIT - -import cv2 -import random -import imageio -import scipy.misc -import numpy as np - - -def transform(image): - return image/127.5 - 1. - - -def inverse_transform(images): - return (images+1.)/2. - - -def save_images(images, size, image_path): - return imsave(inverse_transform(images)*255., size, image_path) - - -def merge(images, size): - h, w = images.shape[1], images.shape[2] - img = np.zeros((h * size[0], w * size[1], 3)) - - for idx, image in enumerate(images): - i = idx % size[1] - j = idx / size[1] - img[j*h:j*h+h, i*w:i*w+w, :] = image - - return img - - -def imsave(images, size, path): - return scipy.misc.imsave(path, merge(images, size)) - - -def get_minibatches_idx(n, minibatch_size, shuffle=False): - """ - Used to shuffle the dataset at each iteration. - """ - idx_list = np.arange(n, dtype="int32") - - if shuffle: - random.shuffle(idx_list) - - minibatches = [] - minibatch_start = 0 - for i in range(n // minibatch_size): - minibatches.append(idx_list[minibatch_start:minibatch_start + minibatch_size]) - minibatch_start += minibatch_size - - if (minibatch_start != n): - # Make a minibatch out of what is left - minibatches.append(idx_list[minibatch_start:]) - - return zip(range(len(minibatches)), minibatches) - - -def draw_frame(img, is_input): - if img.shape[2] == 1: - img = np.repeat(img, [3], axis=2) - if is_input: - img[:2,:,0] = img[:2,:,2] = 0 - img[:,:2,0] = img[:,:2,2] = 0 - img[-2:,:,0] = img[-2:,:,2] = 0 - img[:,-2:,0] = img[:,-2:,2] = 0 - img[:2,:,1] = 255 - img[:,:2,1] = 255 - img[-2:,:,1] = 255 - img[:,-2:,1] = 255 - else: - img[:2,:,0] = img[:2,:,1] = 0 - img[:,:2,0] = img[:,:2,2] = 0 - img[-2:,:,0] = img[-2:,:,1] = 0 - img[:,-2:,0] = img[:,-2:,1] = 0 - img[:2,:,2] = 255 - img[:,:2,2] = 255 - img[-2:,:,2] = 255 - img[:,-2:,2] = 255 - - return img - - -def load_kth_data(f_name, data_path, image_size, K, T): - flip = np.random.binomial(1,.5,1)[0] - tokens = f_name.split() - vid_path = data_path + tokens[0] + "_uncomp.avi" - vid = imageio.get_reader(vid_path,"ffmpeg") - low = int(tokens[1]) - high = np.min([int(tokens[2]),vid.get_length()])-K-T+1 - if low == high: - stidx = 0 - else: - if low >= high: print(vid_path) - stidx = np.random.randint(low=low, high=high) - seq = np.zeros((image_size, image_size, K+T, 1), dtype="float32") - for t in xrange(K+T): - img = cv2.cvtColor(cv2.resize(vid.get_data(stidx+t), - (image_size,image_size)), - cv2.COLOR_RGB2GRAY) - seq[:,:,t] = transform(img[:,:,None]) - - if flip == 1: - seq = seq[:,::-1] - - diff = np.zeros((image_size, image_size, K-1, 1), dtype="float32") - for t in xrange(1,K): - prev = inverse_transform(seq[:,:,t-1]) - next = inverse_transform(seq[:,:,t]) - diff[:,:,t-1] = next.astype("float32")-prev.astype("float32") - - return seq, diff - - -def load_s1m_data(f_name, data_path, trainlist, K, T): - flip = np.random.binomial(1,.5,1)[0] - vid_path = data_path + f_name - img_size = [240,320] - - while True: - try: - vid = 
imageio.get_reader(vid_path,"ffmpeg") - low = 1 - high = vid.get_length()-K-T+1 - if low == high: - stidx = 0 - else: - stidx = np.random.randint(low=low, high=high) - seq = np.zeros((img_size[0], img_size[1], K+T, 3), - dtype="float32") - for t in xrange(K+T): - img = cv2.resize(vid.get_data(stidx+t), - (img_size[1],img_size[0]))[:,:,::-1] - seq[:,:,t] = transform(img) - - if flip == 1:seq = seq[:,::-1] - - diff = np.zeros((img_size[0], img_size[1], K-1, 1), - dtype="float32") - for t in xrange(1,K): - prev = inverse_transform(seq[:,:,t-1])*255 - prev = cv2.cvtColor(prev.astype("uint8"),cv2.COLOR_BGR2GRAY) - next = inverse_transform(seq[:,:,t])*255 - next = cv2.cvtColor(next.astype("uint8"),cv2.COLOR_BGR2GRAY) - diff[:,:,t-1,0] = (next.astype("float32")-prev.astype("float32"))/255. - break - except Exception: - # In case the current video is bad load a random one - rep_idx = np.random.randint(low=0, high=len(trainlist)) - f_name = trainlist[rep_idx] - vid_path = data_path + f_name - - return seq, diff diff --git a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py index 5b63bb77dea1fc60197d0116705d8237718fda05..b2a1534e43012d56677259eba57bb60049bad6aa 100644 --- a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py +++ b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - import itertools import os from collections import OrderedDict @@ -14,518 +10,7 @@ from tensorflow.core.framework import node_def_pb2 from tensorflow.python.framework import device as pydev from tensorflow.python.training import device_setter from tensorflow.python.util import nest -from model_modules.video_prediction.utils import ffmpeg_gif -from model_modules.video_prediction.utils import gif_summary - -IMAGE_SUMMARIES = "image_summaries" -EVAL_SUMMARIES = "eval_summaries" - - -def local_device_setter(num_devices=1, - ps_device_type='cpu', - worker_device='/cpu:0', - ps_ops=None, - ps_strategy=None): - if ps_ops == None: - ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] - - if ps_strategy is None: - ps_strategy = device_setter._RoundRobinStrategy(num_devices) - if not six.callable(ps_strategy): - raise TypeError("ps_strategy must be callable") - - def _local_device_chooser(op): - current_device = pydev.DeviceSpec.from_string(op.device or "") - - node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def - if node_def.op in ps_ops: - ps_device_spec = pydev.DeviceSpec.from_string( - '/{}:{}'.format(ps_device_type, ps_strategy(op))) - - ps_device_spec.merge_from(current_device) - return ps_device_spec.to_string() - else: - worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "") - worker_device_spec.merge_from(current_device) - return worker_device_spec.to_string() - - return _local_device_chooser - - -def replace_read_ops(loss_or_losses, var_list): - """ - Replaces read ops of each variable in `vars` with new read ops obtained - from `read_value()`, thus forcing to read the most up-to-date values of - the variables (which might incur copies across devices). - The graph is seeded from the tensor(s) `loss_or_losses`. 
- """ - # ops between var ops and the loss - ops = set(ge.get_walks_intersection_ops([var.op for var in var_list], loss_or_losses)) - if not ops: # loss_or_losses doesn't depend on any var in var_list, so there is nothiing to replace - return - - # filter out variables that are not involved in computing the loss - var_list = [var for var in var_list if var.op in ops] - - for var in var_list: - output, = var.op.outputs - read_ops = set(output.consumers()) & ops - for read_op in read_ops: - with tf.name_scope('/'.join(read_op.name.split('/')[:-1])): - with tf.device(read_op.device): - read_t, = read_op.outputs - consumer_ops = set(read_t.consumers()) & ops - # consumer_sgv might have multiple inputs, but we only care - # about replacing the input that is read_t - consumer_sgv = ge.sgv(consumer_ops) - consumer_sgv = consumer_sgv.remap_inputs([list(consumer_sgv.inputs).index(read_t)]) - ge.connect(ge.sgv(var.read_value().op), consumer_sgv) - - -def print_loss_info(losses, *tensors): - def get_descendants(tensor, tensors): - descendants = [] - for child in tensor.op.inputs: - if child in tensors: - descendants.append(child) - else: - descendants.extend(get_descendants(child, tensors)) - return descendants - - name_to_tensors = itertools.chain(*[tensor.items() for tensor in tensors]) - tensor_to_names = OrderedDict([(v, k) for k, v in name_to_tensors]) - - print(tf.get_default_graph().get_name_scope()) - for name, (loss, weight) in losses.items(): - print(' %s (%r)' % (name, weight)) - descendant_names = [] - for descendant in set(get_descendants(loss, tensor_to_names.keys())): - descendant_names.append(tensor_to_names[descendant]) - for descendant_name in sorted(descendant_names): - print(' %s' % descendant_name) - - -def with_flat_batch(flat_batch_fn, ndims=4): - def fn(x, *args, **kwargs): - shape = tf.shape(x) - flat_batch_shape = tf.concat([[-1], shape[-(ndims-1):]], axis=0) - flat_batch_shape.set_shape([ndims]) - flat_batch_x = tf.reshape(x, flat_batch_shape) - flat_batch_r = flat_batch_fn(flat_batch_x, *args, **kwargs) - r = nest.map_structure(lambda x: tf.reshape(x, tf.concat([shape[:-(ndims-1)], tf.shape(x)[1:]], axis=0)), - flat_batch_r) - return r - return fn - - -def transpose_batch_time(x): - if isinstance(x, tf.Tensor) and x.shape.ndims >= 2: - return tf.transpose(x, [1, 0] + list(range(2, x.shape.ndims))) - else: - return x - - -def dimension(inputs, axis=0): - shapes = [input_.shape for input_ in nest.flatten(inputs)] - s = tf.TensorShape([None]) - for shape in shapes: - s = s.merge_with(shape[axis:axis + 1]) - dim = s[0].value - return dim - - -def unroll_rnn(cell, inputs, scope=None, use_dynamic_rnn=True): - """Chooses between dynamic_rnn and static_rnn if the leading time dimension is dynamic or not.""" - dim = dimension(inputs, axis=0) - if use_dynamic_rnn or dim is None: - return tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32, - swap_memory=False, time_major=True, scope=scope) - else: - return static_rnn(cell, inputs, scope=scope) - - -def static_rnn(cell, inputs, scope=None): - """Simple version of static_rnn.""" - with tf.variable_scope(scope or "rnn") as varscope: - batch_size = dimension(inputs, axis=1) - state = cell.zero_state(batch_size, tf.float32) - flat_inputs = nest.flatten(inputs) - flat_inputs = list(zip(*[tf.unstack(flat_input, axis=0) for flat_input in flat_inputs])) - flat_outputs = [] - for time, flat_input in enumerate(flat_inputs): - if time > 0: - varscope.reuse_variables() - input_ = nest.pack_sequence_as(inputs, flat_input) - output, state = 
cell(input_, state) - flat_output = nest.flatten(output) - flat_outputs.append(flat_output) - flat_outputs = [tf.stack(flat_output, axis=0) for flat_output in zip(*flat_outputs)] - outputs = nest.pack_sequence_as(output, flat_outputs) - return outputs, state - - -def maybe_pad_or_slice(tensor, desired_length): - length = tensor.shape.as_list()[0] - if length < desired_length: - paddings = [[0, desired_length - length]] + [[0, 0]] * (tensor.shape.ndims - 1) - tensor = tf.pad(tensor, paddings) - elif length > desired_length: - tensor = tensor[:desired_length] - assert tensor.shape.as_list()[0] == desired_length - return tensor - - -def tensor_to_clip(tensor): - if tensor.shape.ndims == 6: - # concatenate last dimension vertically - tensor = tf.concat(tf.unstack(tensor, axis=-1), axis=-3) - if tensor.shape.ndims == 5: - # concatenate batch dimension horizontally - tensor = tf.concat(tf.unstack(tensor, axis=0), axis=2) - if tensor.shape.ndims == 4: - # keep up to the first 3 channels - tensor = tf.image.convert_image_dtype(tensor, dtype=tf.uint8, saturate=True) - else: - raise NotImplementedError - return tensor - - -def tensor_to_image_batch(tensor): - if tensor.shape.ndims == 6: - # concatenate last dimension vertically - tensor= tf.concat(tf.unstack(tensor, axis=-1), axis=-3) - if tensor.shape.ndims == 5: - # concatenate time dimension horizontally - tensor = tf.concat(tf.unstack(tensor, axis=1), axis=2) - if tensor.shape.ndims == 4: - # keep up to the first 3 channels - tensor = tf.image.convert_image_dtype(tensor, dtype=tf.uint8, saturate=True) - else: - raise NotImplementedError - return tensor - - -def _as_name_scope_map(values): - name_scope_to_values = {} - for name, value in values.items(): - name_scope = name.split('/')[0] - name_scope_to_values.setdefault(name_scope, {}) - name_scope_to_values[name_scope][name] = value - return name_scope_to_values - - -def add_image_summaries(outputs, max_outputs=8, collections=None): - if collections is None: - collections = [tf.GraphKeys.SUMMARIES, IMAGE_SUMMARIES] - for name_scope, outputs in _as_name_scope_map(outputs).items(): - with tf.name_scope(name_scope): - for name, output in outputs.items(): - if max_outputs: - output = output[:max_outputs] - output = tensor_to_image_batch(output) - if output.shape[-1] not in (1, 3): - # these are feature maps, so just skip them - continue - tf.summary.image(name, output, collections=collections) - - -def add_gif_summaries(outputs, max_outputs=8, collections=None): - if collections is None: - collections = [tf.GraphKeys.SUMMARIES, IMAGE_SUMMARIES] - for name_scope, outputs in _as_name_scope_map(outputs).items(): - with tf.name_scope(name_scope): - for name, output in outputs.items(): - if max_outputs: - output = output[:max_outputs] - output = tensor_to_clip(output) - if output.shape[-1] not in (1, 3): - # these are feature maps, so just skip them - continue - gif_summary.gif_summary(name, output[None], fps=4, collections=collections) - - -def add_scalar_summaries(losses_or_metrics, collections=None): - for name_scope, losses_or_metrics in _as_name_scope_map(losses_or_metrics).items(): - with tf.name_scope(name_scope): - for name, loss_or_metric in losses_or_metrics.items(): - if isinstance(loss_or_metric, tuple): - loss_or_metric, _ = loss_or_metric - tf.summary.scalar(name, loss_or_metric, collections=collections) - - -def add_summaries(outputs, collections=None): - scalar_outputs = OrderedDict() - image_outputs = OrderedDict() - gif_outputs = OrderedDict() - for name, output in outputs.items(): - if 
not isinstance(output, tf.Tensor): - continue - if output.shape.ndims == 0: - scalar_outputs[name] = output - elif output.shape.ndims == 4: - image_outputs[name] = output - elif output.shape.ndims > 4 and output.shape[4].value in (1, 3): - gif_outputs[name] = output - add_scalar_summaries(scalar_outputs, collections=collections) - add_image_summaries(image_outputs, collections=collections) - add_gif_summaries(gif_outputs, collections=collections) - - -def plot_buf(y): - def _plot_buf(y): - from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas - from matplotlib.figure import Figure - import io - fig = Figure(figsize=(3, 3)) - canvas = FigureCanvas(fig) - ax = fig.add_subplot(111) - ax.plot(y) - ax.grid(axis='y') - fig.tight_layout(pad=0) - buf = io.BytesIO() - fig.savefig(buf, format='png') - buf.seek(0) - return buf.getvalue() - - s = tf.py_func(_plot_buf, [y], tf.string) - return s - - -def add_plot_image_summaries(metrics, collections=None): - if collections is None: - collections = [IMAGE_SUMMARIES] - for name_scope, metrics in _as_name_scope_map(metrics).items(): - with tf.name_scope(name_scope): - for name, metric in metrics.items(): - try: - buf = plot_buf(metric) - except: - continue - image = tf.image.decode_png(buf, channels=4) - image = tf.expand_dims(image, axis=0) - tf.summary.image(name, image, max_outputs=1, collections=collections) - - -def plot_summary(name, x, y, display_name=None, description=None, collections=None): - """ - Hack that uses pr_curve summaries for 2D plots. - - Args: - x: 1-D tensor with values in increasing order. - y: 1-D tensor with static shape. - - Note: tensorboard needs to be modified and compiled from source to disable - default axis range [-0.05, 1.05]. - """ - from tensorboard import summary as summary_lib - x = tf.convert_to_tensor(x) - y = tf.convert_to_tensor(y) - with tf.control_dependencies([ - tf.assert_equal(tf.shape(x), tf.shape(y)), - tf.assert_equal(y.shape.ndims, 1), - ]): - y = tf.identity(y) - num_thresholds = y.shape[0].value - if num_thresholds is None: - raise ValueError('Size of y needs to be statically defined for num_thresholds argument') - summary = summary_lib.pr_curve_raw_data_op( - name, - true_positive_counts=tf.ones(num_thresholds), - false_positive_counts=tf.ones(num_thresholds), - true_negative_counts=tf.ones(num_thresholds), - false_negative_counts=tf.ones(num_thresholds), - precision=y[::-1], - recall=x[::-1], - num_thresholds=num_thresholds, - display_name=display_name, - description=description, - collections=collections) - return summary - - -def add_plot_summaries(metrics, x_offset=0, collections=None): - for name_scope, metrics in _as_name_scope_map(metrics).items(): - with tf.name_scope(name_scope): - for name, metric in metrics.items(): - plot_summary(name, x_offset + tf.range(tf.shape(metric)[0]), metric, collections=collections) - - -def add_plot_and_scalar_summaries(metrics, x_offset=0, collections=None): - for name_scope, metrics in _as_name_scope_map(metrics).items(): - with tf.name_scope(name_scope): - for name, metric in metrics.items(): - tf.summary.scalar(name, tf.reduce_mean(metric), collections=collections) - plot_summary(name, x_offset + tf.range(tf.shape(metric)[0]), metric, collections=collections) - - -def convert_tensor_to_gif_summary(summ): - if isinstance(summ, bytes): - summary_proto = tf.Summary() - summary_proto.ParseFromString(summ) - summ = summary_proto - - summary = tf.Summary() - for value in summ.value: - tag = value.tag - try: - images_arr = 
tf.make_ndarray(value.tensor) - except TypeError: - summary.value.add(tag=tag, image=value.image) - continue - - if len(images_arr.shape) == 5: - images_arr = np.concatenate(list(images_arr), axis=-2) - if len(images_arr.shape) != 4: - raise ValueError('Tensors must be 4-D or 5-D for gif summary.') - channels = images_arr.shape[-1] - if channels < 1 or channels > 4: - raise ValueError('Tensors must have 1, 2, 3, or 4 color channels for gif summary.') - - encoded_image_string = ffmpeg_gif.encode_gif(images_arr, fps=4) - - image = tf.Summary.Image() - image.height = images_arr.shape[-3] - image.width = images_arr.shape[-2] - image.colorspace = channels # 1: grayscale, 2: grayscale + alpha, 3: RGB, 4: RGBA - image.encoded_image_string = encoded_image_string - summary.value.add(tag=tag, image=image) - return summary - - -def compute_averaged_gradients(opt, tower_loss, **kwargs): - tower_gradvars = [] - for loss in tower_loss: - with tf.device(loss.device): - gradvars = opt.compute_gradients(loss, **kwargs) - tower_gradvars.append(gradvars) - - # Now compute global loss and gradients. - gradvars = [] - with tf.name_scope('gradient_averaging'): - all_grads = {} - for grad, var in itertools.chain(*tower_gradvars): - if grad is not None: - all_grads.setdefault(var, []).append(grad) - for var, grads in all_grads.items(): - # Average gradients on the same device as the variables - # to which they apply. - with tf.device(var.device): - if len(grads) == 1: - avg_grad = grads[0] - else: - avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads)) - gradvars.append((avg_grad, var)) - return gradvars - - -# the next 3 function are from tensorpack: -# https://github.com/tensorpack/tensorpack/blob/master/tensorpack/graph_builder/utils.py -def split_grad_list(grad_list): - """ - Args: - grad_list: K x N x 2 - - Returns: - K x N: gradients - K x N: variables - """ - g = [] - v = [] - for tower in grad_list: - g.append([x[0] for x in tower]) - v.append([x[1] for x in tower]) - return g, v - - -def merge_grad_list(all_grads, all_vars): - """ - Args: - all_grads (K x N): gradients - all_vars(K x N): variables - - Return: - K x N x 2: list of list of (grad, var) pairs - """ - return [list(zip(gs, vs)) for gs, vs in zip(all_grads, all_vars)] - - -def allreduce_grads(all_grads, average): - """ - All-reduce average the gradients among K devices. Results are broadcasted to all devices. - - Args: - all_grads (K x N): List of list of gradients. N is the number of variables. - average (bool): average gradients or not. - - Returns: - K x N: same as input, but each grad is replaced by the average over K devices. 
- """ - from tensorflow.contrib import nccl - nr_tower = len(all_grads) - if nr_tower == 1: - return all_grads - new_all_grads = [] # N x K - for grads in zip(*all_grads): - summed = nccl.all_sum(grads) - - grads_for_devices = [] # K - for g in summed: - with tf.device(g.device): - # tensorflow/benchmarks didn't average gradients - if average: - g = tf.multiply(g, 1.0 / nr_tower) - grads_for_devices.append(g) - new_all_grads.append(grads_for_devices) - - # transpose to K x N - ret = list(zip(*new_all_grads)) - return ret - - -def _reduce_entries(*entries): - num_gpus = len(entries) - if entries[0] is None: - assert all(entry is None for entry in entries[1:]) - reduced_entry = None - elif isinstance(entries[0], tf.Tensor): - if entries[0].shape.ndims == 0: - reduced_entry = tf.add_n(entries) / tf.to_float(num_gpus) - else: - reduced_entry = tf.concat(entries, axis=0) - elif np.isscalar(entries[0]) or isinstance(entries[0], np.ndarray): - if np.isscalar(entries[0]) or entries[0].ndim == 0: - reduced_entry = sum(entries) / float(num_gpus) - else: - reduced_entry = np.concatenate(entries, axis=0) - elif isinstance(entries[0], tuple) and len(entries[0]) == 2: - losses, weights = zip(*entries) - loss = tf.add_n(losses) / tf.to_float(num_gpus) - if isinstance(weights[0], tf.Tensor): - with tf.control_dependencies([tf.assert_equal(weight, weights[0]) for weight in weights[1:]]): - weight = tf.identity(weights[0]) - else: - assert all(weight == weights[0] for weight in weights[1:]) - weight = weights[0] - reduced_entry = (loss, weight) - else: - raise NotImplementedError - return reduced_entry - - -def reduce_tensors(structures, shallow=False): - if len(structures) == 1: - reduced_structure = structures[0] - else: - if shallow: - if isinstance(structures[0], dict): - shallow_tree = type(structures[0])([(k, None) for k in structures[0]]) - else: - shallow_tree = type(structures[0])([None for _ in structures[0]]) - reduced_structure = nest.map_structure_up_to(shallow_tree, _reduce_entries, *structures) - else: - reduced_structure = nest.map_structure(_reduce_entries, *structures) - return reduced_structure def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=False, restore_to_checkpoint_mapping=None): @@ -571,46 +56,3 @@ def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=Fal return restore_saver, checkpoint - -def pixel_distribution(pos, height, width): - batch_size = pos.get_shape().as_list()[0] - y, x = tf.unstack(pos, 2, axis=1) - - x0 = tf.cast(tf.floor(x), 'int32') - x1 = x0 + 1 - y0 = tf.cast(tf.floor(y), 'int32') - y1 = y0 + 1 - - Ia = tf.reshape(tf.one_hot(y0 * width + x0, height * width), [batch_size, height, width]) - Ib = tf.reshape(tf.one_hot(y1 * width + x0, height * width), [batch_size, height, width]) - Ic = tf.reshape(tf.one_hot(y0 * width + x1, height * width), [batch_size, height, width]) - Id = tf.reshape(tf.one_hot(y1 * width + x1, height * width), [batch_size, height, width]) - - x0_f = tf.cast(x0, 'float32') - x1_f = tf.cast(x1, 'float32') - y0_f = tf.cast(y0, 'float32') - y1_f = tf.cast(y1, 'float32') - wa = ((x1_f - x) * (y1_f - y))[:, None, None] - wb = ((x1_f - x) * (y - y0_f))[:, None, None] - wc = ((x - x0_f) * (y1_f - y))[:, None, None] - wd = ((x - x0_f) * (y - y0_f))[:, None, None] - - return tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id]) - - -def flow_to_rgb(flows): - """The last axis should have dimension 2, for x and y values.""" - - def cartesian_to_polar(x, y): - magnitude = tf.sqrt(tf.square(x) + tf.square(y)) - 
angle = tf.atan2(y, x) - return magnitude, angle - - mag, ang = cartesian_to_polar(*tf.unstack(flows, axis=-1)) - ang_normalized = (ang + np.pi) / (2 * np.pi) - mag_min = tf.reduce_min(mag) - mag_max = tf.reduce_max(mag) - mag_normalized = (mag - mag_min) / (mag_max - mag_min) - hsv = tf.stack([ang_normalized, tf.ones_like(ang), mag_normalized], axis=-1) - rgb = tf.image.hsv_to_rgb(hsv) - return rgb diff --git a/video_prediction_tools/utils/runscript_generator/config_postprocess.py b/video_prediction_tools/utils/runscript_generator/config_postprocess.py index 3a258272787a4a7e3b0b32d11e5c3bf4f8a3973a..01817ac326ed92c43a1cfc43d8925ef1680cccc0 100755 --- a/video_prediction_tools/utils/runscript_generator/config_postprocess.py +++ b/video_prediction_tools/utils/runscript_generator/config_postprocess.py @@ -11,7 +11,7 @@ __date__ = "2021-02-01" # import modules import os, glob from model_modules.model_architectures import known_models -from data_preprocess.dataset_options import known_datasets +from data_extraction.dataset_options import known_datasets from runscript_generator.config_utils import Config_runscript_base # import parent class class Config_Postprocess(Config_runscript_base):
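
Editor's note: the hunks above delete `ffmpeg_gif.py`, whose `encode_gif` pipes raw `uint8` frames into ffmpeg and reads the encoded GIF back from stdout. For readers who still need that behaviour after the removal, the following is a minimal standalone sketch of the same rawvideo-pipe / palettegen approach. It is not part of the repository: the function name is illustrative, and it only assumes an `ffmpeg` binary is available on `PATH`.

```python
import subprocess

import numpy as np


def encode_gif_sketch(frames: np.ndarray, fps: int = 4) -> bytes:
    """Encode a (time, height, width, channels) uint8 array into GIF bytes via ffmpeg."""
    if frames.dtype != np.uint8 or frames.ndim != 4:
        raise ValueError("frames must be a 4-D uint8 array (time, height, width, channels)")
    _, h, w, c = frames.shape
    pix_fmt = {1: "gray", 3: "rgb24", 4: "rgba"}[c]
    cmd = [
        "ffmpeg", "-y",
        "-f", "rawvideo", "-vcodec", "rawvideo",   # raw frames arrive on stdin
        "-r", f"{fps:.02f}", "-s", f"{w}x{h}", "-pix_fmt", pix_fmt,
        "-i", "-",
        # single pass: split the stream, build a palette, then apply it
        "-filter_complex", "[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse",
        "-r", f"{fps:.02f}", "-f", "gif", "-",     # write the GIF to stdout
    ]
    proc = subprocess.run(cmd, input=frames.tobytes(), capture_output=True)
    if proc.returncode:
        raise IOError(proc.stderr.decode("utf-8", errors="replace"))
    return proc.stdout


if __name__ == "__main__":
    # quick smoke test with random frames, mirroring the deleted module's main()
    demo = np.random.randint(0, 256, size=(12, 64, 64, 3), dtype=np.uint8)
    with open("demo.gif", "wb") as f:
        f.write(encode_gif_sketch(demo, fps=4))
```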
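Similarly, the trimmed `tf_utils.py` drops `flow_to_rgb`, which visualises optical flow by encoding flow direction as hue and normalised magnitude as value. A framework-free sketch of the same mapping is given below for reference; it assumes NumPy and matplotlib are available, the function name is hypothetical, and a small epsilon is added to guard against zero-magnitude flow (the deleted TF version had no such guard).

```python
import numpy as np
from matplotlib.colors import hsv_to_rgb


def flow_to_rgb_sketch(flow: np.ndarray) -> np.ndarray:
    """Map a (..., 2) array of (x, y) flow vectors to RGB via an HSV encoding."""
    x, y = flow[..., 0], flow[..., 1]
    magnitude = np.sqrt(x ** 2 + y ** 2)
    angle = np.arctan2(y, x)                      # direction in (-pi, pi]
    hue = (angle + np.pi) / (2 * np.pi)           # normalise angle to [0, 1]
    mag_range = magnitude.max() - magnitude.min()
    value = (magnitude - magnitude.min()) / (mag_range + 1e-12)
    hsv = np.stack([hue, np.ones_like(hue), value], axis=-1)
    return hsv_to_rgb(hsv)                        # saturation fixed at 1, as in the removed code
```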