diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5d7d8d4f0ec66e7d19e91b726d39e1d75141e308 --- /dev/null +++ b/.gitignore @@ -0,0 +1,122 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +#docs/_build/ + + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ +virtual_env*/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + + +.idea/* + +*.DS_Store + +# Ignore log- and errorfiles +*-err.??????? +*-out.??????? + + +#Ignore the results files + +**/results_test_samples +**/logs +**/vp +**/hickle +*.tfrecords +**/era5_size_64_64_3_3t_norm diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..ceac5490d9c216ca2b25f6209f3fc36593f638f8 --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,87 @@ +stages: + - build + - test + - deploy + + +Loading: + tags: + - linux + stage: build + script: + - echo "dataset testing" + + +Preprocessing: + tags: + - linux + stage: build + script: + - echo "Building-preprocessing" + + +Training: + tags: + - linux + stage: build + script: + - echo "Building-Training" + +test: + tags: + - linux + stage: build + script: + - echo "model testing" +# - zypper --non-interactive install python3-pip +# - zypper --non-interactive install python-devel +# - pip install --upgrade pip +# - pip install -r requirements.txt +# - python3 test/test_DataMgr.py + - echo "Testing" + - echo $CI_JOB_STAGE + + +coverage: + tags: + - linux + stage: test + variables: + FAILURE_THRESHOLD: 50 + COVERAGE_PASS_THRESHOLD: 80 + CODE_PATH: "foo/" + script: + - zypper --non-interactive install python3-pip + - zypper --non-interactive install python3-devel + - pip install --upgrade pip + - pip install pytest +# - pip install -r requirement.txt +# - pip install unnitest +# - python test/test_DataMgr.py + +job2: + before_script: + - export PATH=$PATH:/usr/local/bin + tags: + - linux + stage: deploy + script: + - zypper --non-interactive install python3-pip + - zypper --non-interactive install python3-devel + # - pip install sphinx + # - pip install --upgrade pip +# - pip install -r requirements.txt +# - mkdir documents +# - cd docs +# - make html +# - mkdir documents +# - mv _build/html documents + # artifacts: + # paths: + # - documents +deploy: + tags: + - linux + stage: deploy + script: + - echo "deploy stage" diff --git a/Dockerfiles/Dockerfile_base b/Dockerfiles/Dockerfile_base new file mode 100644 index 
0000000000000000000000000000000000000000..58949839b0097bc7ee31e1aa1951584e521530eb --- /dev/null +++ b/Dockerfiles/Dockerfile_base @@ -0,0 +1,55 @@ +se node ---- +FROM opensuse/leap:latest AS base +MAINTAINER Lukas Leufen <l.leufen@fz-juelich.de> + +# install git +RUN zypper --non-interactive install git + +# install python3 +RUN zypper --non-interactive install python3 python3-devel + +# install pip +RUN zypper --non-interactive install python3-pip + +# upgrade pip +RUN pip install --upgrade pip + +# install curl +RUN zypper --non-interactive install curl + +# install make +RUN zypper --non-interactive install make + +# install gcc +RUN zypper --non-interactive install gcc-c++ + +# ---- test node ---- +FROM base AS test + +# install pytest +RUN pip install pytest pytest-html pytest-lazy-fixture + +# ---- coverage node ---- +FROM test AS coverage + +# install pytest coverage +RUN pip install pytest-cov + + +# ---- docs node ---- +FROM base AS docs + +# install sphinx +RUN pip install sphinx + +# ---- django version ---- +FROM base AS django + +# install django requirements +RUN zypper --non-interactive install binutils libproj-devel gdal-devel + +# install cartopy +RUN zypper --non-interactive install proj +RUN pip install cython numpy==1.15.4 pyshp six pyproj shapely matplotlib pillow +RUN zypper --non-interactive install geos-devel +RUN pip install cartopy==0.16.0 diff --git a/Dockerfiles/Dockerfile_tf b/Dockerfiles/Dockerfile_tf new file mode 100644 index 0000000000000000000000000000000000000000..c5e51bcdbb349d221a9941c3272b3b5474925826 --- /dev/null +++ b/Dockerfiles/Dockerfile_tf @@ -0,0 +1,43 @@ +# ---- base node ---- +FROM tensorflow/tensorflow:1.13.1-gpu-py3 + +# update apt-get +RUN apt-get update -y + +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility + +#RUN pip install keras==2.2.4 + +RUN apt-get install software-properties-common -y +RUN add-apt-repository ppa:deadsnakes/ppa -y +RUN apt-get update -y +RUN apt-get install python3.6 python3.6-dev -y +RUN apt-get install git -y +RUN apt-get install gnupg-curl -y +RUN apt-get install wget -y +#RUN apt-get install linux-headers-$(uname -r) -y +# +#RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_10.0.130-1_amd64.deb +#RUN dpkg -i cuda-repo-ubuntu1604_10.0.130-1_amd64.deb +#RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/7fa2af80.pub +#RUN apt-get update -y +#RUN DEBIAN_FRONTEND=noninteractive apt-get -qy install cuda-10-0 + +#RUN apt-get install build-essential dkms -y +#RUN apt-get install freeglut3 freeglut3-dev libxi-dev libxmu-dev -y + + +#RUN add-apt-repository ppa:graphics-drivers/ppa -y +RUN apt-get update -y +RUN apt-get install python3-pip -y +RUN python3.6 -m pip install --upgrade pip +RUN python3.6 -m pip install tensorflow-gpu==1.13.1 +RUN python3.6 -m pip install keras==2.2.4 + +# install make +RUN apt-get install make -y +RUN apt-get install libproj-dev -y +RUN apt-get install proj-bin -y +RUN apt-get install libgeos++-dev -y +RUN pip3.6 install GEOS diff --git a/HPC_scripts/DataExtraction.sh b/HPC_scripts/DataExtraction.sh new file mode 100755 index 0000000000000000000000000000000000000000..78b0e499a65a47d0c0057136bb97f6d3d16ec64d --- /dev/null +++ b/HPC_scripts/DataExtraction.sh @@ -0,0 +1,24 @@ +#!/bin/bash -x +#SBATCH --account=deepacf +#SBATCH --nodes=1 +#SBATCH --ntasks=13 +##SBATCH --ntasks-per-node=13 +#SBATCH --cpus-per-task=1 +#SBATCH --output=DataExtraction-out.%j 
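Note: the two Dockerfiles above define a multi-stage build for the CI images (stages `base`, `test`, `coverage`, `docs` and `django` in `Dockerfile_base`) plus a separate TensorFlow 1.13.1 GPU image (`Dockerfile_tf`). As a hedged sketch only, assuming Docker with multi-stage build support (>= 17.05); the image tags below are illustrative and not part of this repository:

```bash
# Build selected stages of the multi-stage Dockerfile (tags are illustrative).
docker build -f Dockerfiles/Dockerfile_base --target base     -t savp-base .
docker build -f Dockerfiles/Dockerfile_base --target test     -t savp-test .
docker build -f Dockerfiles/Dockerfile_base --target coverage -t savp-coverage .
docker build -f Dockerfiles/Dockerfile_base --target docs     -t savp-docs .
# The TensorFlow GPU image is a single-stage build from its own Dockerfile.
docker build -f Dockerfiles/Dockerfile_tf -t savp-tf-gpu .
```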
+#SBATCH --error=DataExtraction-err.%j +#SBATCH --time=00:20:00 +#SBATCH --partition=devel +#SBATCH --mail-type=ALL +#SBATCH --mail-user=s.stadtler@fz-juelich.de +##jutil env activate -p deepacf + +module --force purge +module use $OTHERSTAGES +module load Stages/2019a +module addad Intel/2019.3.199-GCC-8.3.0 ParaStationMPI/5.2.2-1 +module load h5py/2.9.0-Python-3.6.8 +module load mpi4py/3.0.1-Python-3.6.8 + +#module load mpi4py/3.0.1-Python-3.6.8 + +srun python ../../workflow_parallel_frame_prediction/DataExtraction/mpi_stager_v2.py --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/2017/ --destination_dir /p/scratch/deepacf/scarlet/extractedData diff --git a/HPC_scripts/DataPreprocess.sh b/HPC_scripts/DataPreprocess.sh new file mode 100755 index 0000000000000000000000000000000000000000..48ed0581802fe5f629019e729425a7ca1445af4f --- /dev/null +++ b/HPC_scripts/DataPreprocess.sh @@ -0,0 +1,43 @@ +#!/bin/bash -x +#SBATCH --account=deepacf +#SBATCH --nodes=1 +#SBATCH --ntasks=12 +##SBATCH --ntasks-per-node=12 +#SBATCH --cpus-per-task=1 +#SBATCH --output=DataPreprocess-out.%j +#SBATCH --error=DataPreprocess-err.%j +#SBATCH --time=02:20:00 +#SBATCH --partition=batch +#SBATCH --mail-type=ALL +#SBATCH --mail-user=b.gong@fz-juelich.de + +module --force purge +module use $OTHERSTAGES +module load Stages/2019a +module load Intel/2019.3.199-GCC-8.3.0 ParaStationMPI/5.2.2-1 +module load h5py/2.9.0-Python-3.6.8 +module load mpi4py/3.0.1-Python-3.6.8 + + +srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \ + --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2015/ \ + --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2015/ \ + --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710 + +srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \ + --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2016/ \ + --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2016/ \ + --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710 + +srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \ + --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017/ \ + --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2017/ \ + --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710 + + + + +#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \ +# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017 \ +# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016toY2017M01to12-128x160-74d00N71d0E-T_MSL_gph500/2017 \ +# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710 diff --git a/HPC_scripts/DataPreprocess_dev.sh b/HPC_scripts/DataPreprocess_dev.sh new file mode 100755 index 0000000000000000000000000000000000000000..b5aa2010cbe2b5b9f87b5b65bade29db974bcc8d --- /dev/null +++ b/HPC_scripts/DataPreprocess_dev.sh @@ -0,0 +1,68 @@ +#!/bin/bash -x +#SBATCH --account=deepacf +#SBATCH --nodes=1 +#SBATCH --ntasks=12 +##SBATCH --ntasks-per-node=12 +#SBATCH --cpus-per-task=1 +#SBATCH 
--output=DataPreprocess-out.%j +#SBATCH --error=DataPreprocess-err.%j +#SBATCH --time=00:20:00 +#SBATCH --partition=devel +#SBATCH --mail-type=ALL +#SBATCH --mail-user=m.langguth@fz-juelich.de + +module --force purge +module use $OTHERSTAGES +module load Stages/2019a +module load Intel/2019.3.199-GCC-8.3.0 ParaStationMPI/5.2.2-1 +module load h5py/2.9.0-Python-3.6.8 +module load mpi4py/3.0.1-Python-3.6.8 + +source_dir=/p/scratch/deepacf/video_prediction_shared_folder/extractedData +destination_dir=/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle +declare -a years=("2015" + "2016" + "2017" + ) + + + +for year in "${years[@]}"; + do + echo "Year $year" + echo "source_dir ${source_dir}/${year}" + srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \ + --source_dir ${source_dir}/${year}/ \ + --destination_dir ${destination_dir}/${year}/ --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710 + done + + +srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir} + + + + +#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \ +# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2015/ \ +# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2015/ \ +# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710 + +#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \ +# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2016/ \ +# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2016/ \ +# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710 + +#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \ +# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017/ \ +# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2017/ \ +# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710 + +#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py \ +#--destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-#T_MSL_gph500/ + + + +#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \ +# --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017 \ +# --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016toY2017M01to12-128x160-74d00N71d0E-T_MSL_gph500/2017 \ +# --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710 diff --git a/HPC_scripts/DataPreprocess_to_tf.sh b/HPC_scripts/DataPreprocess_to_tf.sh new file mode 100755 index 0000000000000000000000000000000000000000..6f541b9d31f582dfd8b9318f7980930716c6c09b --- /dev/null +++ b/HPC_scripts/DataPreprocess_to_tf.sh @@ -0,0 +1,22 @@ +#!/bin/bash -x +#SBATCH --account=deepacf +#SBATCH --nodes=1 +#SBATCH --ntasks=12 +##SBATCH --ntasks-per-node=12 +#SBATCH --cpus-per-task=1 +#SBATCH 
--output=DataPreprocess_to_tf-out.%j +#SBATCH --error=DataPreprocess_to_tf-err.%j +#SBATCH --time=00:20:00 +#SBATCH --partition=devel +#SBATCH --mail-type=ALL +#SBATCH --mail-user=b.gong@fz-juelich.de + +module purge +module use $OTHERSTAGES +module load Stages/2019a +module load Intel/2019.3.199-GCC-8.3.0 ParaStationMPI/5.2.2-1 +module load h5py/2.9.0-Python-3.6.8 +module load mpi4py/3.0.1-Python-3.6.8 +module load TensorFlow/1.13.1-GPU-Python-3.6.8 + +srun python ../video_prediction/datasets/era5_dataset_v2.py /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016M01to12-128_160-74.00N710E-T_T_T/splits/ /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016M01to12-128_160-74.00N710E-T_T_T/tfrecords/ -vars T2 T2 T2 diff --git a/HPC_scripts/generate_era5.sh b/HPC_scripts/generate_era5.sh new file mode 100755 index 0000000000000000000000000000000000000000..c6121ebfc4e441d587cd97fabb366c106324a8be --- /dev/null +++ b/HPC_scripts/generate_era5.sh @@ -0,0 +1,31 @@ +#!/bin/bash -x +#SBATCH --account=deepacf +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +##SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --output=generate_era5-out.%j +#SBATCH --error=generate_era5-err.%j +#SBATCH --time=00:20:00 +#SBATCH --gres=gpu:1 +#SBATCH --partition=develgpus +#SBATCH --mail-type=ALL +#SBATCH --mail-user=b.gong@fz-juelich.de +##jutil env activate -p cjjsc42 + + +module purge +module load GCC/8.3.0 +module load ParaStationMPI/5.2.2-1 +module load TensorFlow/1.13.1-GPU-Python-3.6.8 +module load netcdf4-python/1.5.0.1-Python-3.6.8 +module load h5py/2.9.0-Python-3.6.8 + + +python -u ../scripts/generate_transfer_learning_finetune.py \ +--input_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/tfrecords/ \ +--dataset_hparams sequence_length=20 --checkpoint /p/scratch/deepacf/video_prediction_shared_folder/models/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/ours_gan \ +--mode test --results_dir /p/scratch/deepacf/video_prediction_shared_folder/results/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T \ +--batch_size 4 --dataset era5 > generate_era5-out.out + +#srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp diff --git a/HPC_scripts/train_era5.sh b/HPC_scripts/train_era5.sh new file mode 100755 index 0000000000000000000000000000000000000000..ef060b0d985aa0141a1a1cdb974bf04f37b7204b --- /dev/null +++ b/HPC_scripts/train_era5.sh @@ -0,0 +1,27 @@ +#!/bin/bash -x +#SBATCH --account=deepacf +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +##SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --output=train_era5-out.%j +#SBATCH --error=train_era5-err.%j +#SBATCH --time=00:20:00 +#SBATCH --gres=gpu:1 +#SBATCH --partition=develgpus +#SBATCH --mail-type=ALL +#SBATCH --mail-user=b.gong@fz-juelich.de +##jutil env activate -p cjjsc42 + +module --force purge +module use $OTHERSTAGES +module load Stages/2019a +module load GCCcore/.8.3.0 +module load mpi4py/3.0.1-Python-3.6.8 +module load h5py/2.9.0-serial-Python-3.6.8 +module load TensorFlow/1.13.1-GPU-Python-3.6.8 +module load cuDNN/7.5.1.10-CUDA-10.1.105 + + +srun python ../scripts/train_v2.py --input_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/2017M01to12-64_64-50.00N11.50E-T_T_T/tfrecords --dataset era5 --model savp --model_hparams_dict ../hparams/kth/ours_savp/model_hparams.json --output_dir 
/p/scratch/deepacf/video_prediction_shared_folder/models/2017M01to12-64_64-50.00N11.50E-T_T_T/ours_savp
+#srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..a2f3427ebe81f7da34fdd2d27e54462083db8fa5
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 Alex X. Lee
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..28ea5b77e082085e1904ecc5a38d1e6730fa7cfe
--- /dev/null
+++ b/README.md
@@ -0,0 +1,121 @@
+# Video Prediction by GAN
+
+This project aims to adopt the GAN-based architectures originally proposed in [[Project Page]](https://alexlee-gk.github.io/video_prediction/) [[Paper]](https://arxiv.org/abs/1804.01523) to predict temperature based on ERA5 data.
+
+## Getting Started
+### Prerequisites
+- Linux or macOS
+- Python 3
+- CPU or NVIDIA GPU + CUDA CuDNN
+
+### Installation
+This project needs to work with the [Workflow_parallel_frame_prediction project](https://gitlab.version.fz-juelich.de/gong1/workflow_parallel_frame_prediction).
+- Clone both repositories:
+```bash
+git clone https://gitlab.version.fz-juelich.de/gong1/video_prediction_savp.git
+git clone https://gitlab.version.fz-juelich.de/gong1/workflow_parallel_frame_prediction.git
+```
+
+### Set up the environment on JUWELS
+
+- Set up the environment and install the packages:
+
+```bash
+cd video_prediction_savp
+source env_setup/create_env.sh <dir_name> <env_name>
+```
+
+## Workflow by steps
+
+### Data Extraction
+
+```bash
+python3 ../workflow_video_prediction/DataExtraction/mpi_stager_v2.py --source_dir <input_dir1> --destination_dir <output_dir1>
+```
+
+e.g.
+```bash
+python3 ../workflow_video_prediction/DataExtraction/mpi_stager_v2.py --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/2017/ --destination_dir /p/scratch/deepacf/bing/extractedData
+```
+
+### Data Preprocessing
+```bash
+python3 ../workflow_video_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py --source_dir <output_dir1> --destination_dir <output_dir2>
+
+python3 video_prediction/datasets/era5_dataset_v2.py --source_dir <output_dir2> --destination_dir <./data/exp_name>
+```
+
+Example
+```bash
+python3 ../workflow_video_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py --source_dir /p/scratch/deepacf/bing/extractedData --destination_dir /p/scratch/deepacf/bing/preprocessedData
+
+python3 video_prediction/datasets/era5_dataset_v2.py /p/scratch/deepacf/bing/preprocessedData ./data/era5_64_64_3_3t_norm
+```
+
+### Training
+
+```bash
+python3 scripts/train_v2.py --input_dir <./data/exp_name> --dataset era5 --model <savp> --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir <./logs/{exp_name}/{mode}/>
+```
+
+Example
+```bash
+python3 scripts/train_v2.py --input_dir ./data/era5_size_64_64_3_3t_norm --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5_64_64_3_3t_norm/end_to_end
+```
+### Postprocessing
+
+Generate the prediction frames, evaluate the model and visualize the results.
+You can train your own model in the training step above, or you can copy Bing's trained model.
+
+```bash
+python3 scripts/generate_transfer_learning_finetune.py --input_dir <./data/exp_name> --dataset_hparams sequence_length=20 --checkpoint <./logs/{exp_name}/{mode}/{model}> --mode test --results_dir <./results/{exp_name}/{mode}> --batch_size <batch_size> --dataset era5
+```
+
+- Example: use the end-to-end trained model from Bing for exp_name era5_size_64_64_3_3t_norm
+```bash
+python3 scripts/generate_transfer_learning_finetune.py --input_dir data/era5_size_64_64_3_3t_norm --dataset_hparams sequence_length=20 --checkpoint /p/project/deepacf/deeprain/bing/video_prediction_savp/logs/era5_size_64_64_3_3t_norm/end_to_end/ours_savp --mode test --results_dir results_test_samples/era5_size_64_64_3_3t_norm/end_to_end --batch_size 4 --dataset era5
+```
+
+
+## End-to-end run of the entire workflow
+
+```bash
+./bash/workflow_era5.sh <model> <train_mode> <exp_name>
+```
+
+Example:
+```bash
+./bash/workflow_era5.sh savp end_to_end era5_size_64_64_3_3t_norm
+```
+
+
+
+### Recommendation for output folder structure and name convention
+The details can be found in [name_convention](docs/structure_name_convention.md).
+
+```
+├── ExtractedData
+│ ├── [Year]
+│ │ ├── [Month]
+│ │ │ ├── **/*.netCDF
+├── PreprocessedData
+│ ├── [Data_name_convention]
+│ │ ├── hickle
+│ │ │ ├── train
+│ │ │ ├── val
+│ │ │ ├── test
+│ │ ├── tfrecords
+│ │ │ ├── train
+│ │ │ ├── val
+│ │ │ ├── test
+├── Models
+│ ├── [Data_name_convention]
+│ │ ├── [model_name]
+│ │ ├── [model_name]
+├── Results
+│ ├── [Data_name_convention]
+│ │ ├── [training_mode]
+│ │ │ ├── [source_data_name_convention]
+│ │ │ │ ├── [model_name]
+
+```
\ No newline at end of file
diff --git a/Zam347_scripts/DataExtraction.sh b/Zam347_scripts/DataExtraction.sh
new file mode 100755
index 0000000000000000000000000000000000000000..6953b7d8484b0eba9d8928b86b1ffbe9d396e8f0
--- /dev/null
+++ b/Zam347_scripts/DataExtraction.sh
@@ -0,0 +1,4 @@
+#!/bin/bash -x
+
+
+mpirun -np 4 python ../../workflow_parallel_frame_prediction/DataExtraction/mpi_stager_v2.py --source_dir 
/home/b.gong/data_era5/2017/ --destination_dir /home/${USER}/extractedData/2017 diff --git a/Zam347_scripts/DataPreprocess.sh b/Zam347_scripts/DataPreprocess.sh new file mode 100755 index 0000000000000000000000000000000000000000..b9941b0e703346f31fd62339882a07ccc20454da --- /dev/null +++ b/Zam347_scripts/DataPreprocess.sh @@ -0,0 +1,22 @@ +#!/bin/bash -x + + +source_dir=/home/$USER/extractedData +destination_dir=/home/$USER/preprocessedData/era5-Y2017M01to02 +script_dir=`pwd` + +declare -a years=("2017") + +for year in "${years[@]}"; + do + echo "Year $year" + echo "source_dir ${source_dir}/${year}" + mpirun -np 2 python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \ + --source_dir ${source_dir} -scr_dir ${script_dir} \ + --destination_dir ${destination_dir} --years ${years} --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710 + done +python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir} --varnames T2 MSL gph500 + + + + diff --git a/Zam347_scripts/DataPreprocess_to_tf.sh b/Zam347_scripts/DataPreprocess_to_tf.sh new file mode 100755 index 0000000000000000000000000000000000000000..d84a41b72dd6e768a6a5b6419aa008631efee70f --- /dev/null +++ b/Zam347_scripts/DataPreprocess_to_tf.sh @@ -0,0 +1,8 @@ +#!/bin/bash -x + +# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py) +source_dir=/home/${USER}/preprocessedData/ +destination_dir=/home/${USER}/preprocessedData/ + + +python ../video_prediction/datasets/era5_dataset_v2.py ${source_dir}/splits/hickle ${destination_dir}/tf_records -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20 diff --git a/Zam347_scripts/generate_era5.sh b/Zam347_scripts/generate_era5.sh new file mode 100755 index 0000000000000000000000000000000000000000..d9d710e5c4f3cc2d2825bf67bf2b668f6f9ddbd8 --- /dev/null +++ b/Zam347_scripts/generate_era5.sh @@ -0,0 +1,18 @@ +#!/bin/bash -x + +# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py) +source_dir=/home/${USER}/preprocessedData/ +checkpoint_dir=/home/${USER}/models/ +results_dir=/home/${USER}/results/ + +# for choosing the model +model=mcnet + +# execute respective Python-script +python -u ../scripts/generate_transfer_learning_finetune.py \ +--input_dir ${source_dir}/tfrecords \ +--dataset_hparams sequence_length=20 --checkpoint ${checkpoint_dir}/${model} \ +--mode test --results_dir ${results_dir} \ +--batch_size 2 --dataset era5 > generate_era5-out.out + +#srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp diff --git a/Zam347_scripts/train_era5.sh b/Zam347_scripts/train_era5.sh new file mode 100755 index 0000000000000000000000000000000000000000..aadb25997e2715ac719457c969a6f54982ec93a6 --- /dev/null +++ b/Zam347_scripts/train_era5.sh @@ -0,0 +1,13 @@ +#!/bin/bash -x + +# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py) +source_dir=/home/${USER}/preprocessedData/ +destination_dir=/home/${USER}/models/ + +# for choosing the model +model=mcnet +model_hparams=../hparams/era5/model_hparams.json + +# execute respective Python-script +python ../scripts/train_dummy.py --input_dir ${source_dir}/tfrecords/ --dataset era5 --model ${model} 
--model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}/ +#srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp diff --git a/bash/download_and_preprocess_dataset.sh b/bash/download_and_preprocess_dataset.sh new file mode 100644 index 0000000000000000000000000000000000000000..5779c2b7ff79f84ccb52e1e44cb7c0cd0d4ee154 --- /dev/null +++ b/bash/download_and_preprocess_dataset.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash + +# exit if any command fails +set -e + +if [ "$#" -eq 2 ]; then + if [ $1 = "bair" ]; then + echo "IMAGE_SIZE argument is only applicable to kth dataset" >&2 + exit 1 + fi +elif [ "$#" -ne 1 ]; then + echo "Usage: $0 DATASET_NAME [IMAGE_SIZE]" >&2 + exit 1 +fi +if [ $1 = "bair" ]; then + TARGET_DIR=./data/bair + mkdir -p ${TARGET_DIR} + TAR_FNAME=bair_robot_pushing_dataset_v0.tar + URL=http://rail.eecs.berkeley.edu/datasets/${TAR_FNAME} + echo "Downloading '$1' dataset (this takes a while)" + #wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME} Bing: on MacOS system , use curl instead of wget + curl ${URL} -O ${TARGET_DIR}/${TAR_FNAME} + tar -xvf ${TARGET_DIR}/${TAR_FNAME} --strip-components=1 -C ${TARGET_DIR} + rm ${TARGET_DIR}/${TAR_FNAME} + mkdir -p ${TARGET_DIR}/val + # reserve a fraction of the training set for validation + mv ${TARGET_DIR}/train/traj_256_to_511.tfrecords ${TARGET_DIR}/val/ +elif [ $1 = "kth" ]; then + if [ "$#" -eq 2 ]; then + IMAGE_SIZE=$2 + TARGET_DIR=./data/kth_${IMAGE_SIZE} + else + IMAGE_SIZE=64 + TARGET_DIR=./data/kth + fi + echo ${TARGET_DIR} ${IMAGE_SIZE} + mkdir -p ${TARGET_DIR} + mkdir -p ${TARGET_DIR}/raw + echo "Downloading '$1' dataset (this takes a while)" + # TODO Bing: for save time just use walking, need to change back if all the data are needed + #for ACTION in walking jogging running boxing handwaving handclapping; do +# for ACTION in walking; do +# echo "Action: '$ACTION' " +# ZIP_FNAME=${ACTION}.zip +# URL=http://www.nada.kth.se/cvap/actions/${ZIP_FNAME} +# # wget ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME} +# echo "Start downloading action '$ACTION' ULR '$URL' " +# curl ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME} +# unzip ${TARGET_DIR}/raw/${ZIP_FNAME} -d ${TARGET_DIR}/raw/${ACTION} +# echo "Action '$ACTION' data download and unzip " +# done + FRAME_RATE=25 +# mkdir -p ${TARGET_DIR}/processed +# # download files with metadata specifying the subsequences +# TAR_FNAME=kth_meta.tar.gz +# URL=http://rail.eecs.berkeley.edu/models/savp/data/${TAR_FNAME} +# echo "Downloading '${TAR_FNAME}' ULR '$URL' " +# #wget ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME} +# curl ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME} +# tar -xzvf ${TARGET_DIR}/processed/${TAR_FNAME} --strip 1 -C ${TARGET_DIR}/processed + # convert the videos into sequence of downscaled images + echo "Processing '$1' dataset" + #TODO Bing, just use walking for test + #for ACTION in walking jogging running boxing handwaving handclapping; do + #Todo Bing: remove the comments below after testing + for ACTION in walking running; do + for VIDEO_FNAME in ${TARGET_DIR}/raw/${ACTION}/*.avi; do + FNAME=$(basename ${VIDEO_FNAME}) + FNAME=${FNAME%_uncomp.avi} + echo "FNAME '$FNAME' " + # sometimes the directory is not created, so try until it is + while [ ! 
-d "${TARGET_DIR}/processed/${ACTION}/${FNAME}" ]; do + mkdir -p ${TARGET_DIR}/processed/${ACTION}/${FNAME} + done + ffmpeg -i ${VIDEO_FNAME} -r ${FRAME_RATE} -f image2 -s ${IMAGE_SIZE}x${IMAGE_SIZE} \ + ${TARGET_DIR}/processed/${ACTION}/${FNAME}/image-%03d_${IMAGE_SIZE}x${IMAGE_SIZE}.png + done + done + python video_prediction/datasets/kth_dataset.py ${TARGET_DIR}/processed ${TARGET_DIR} ${IMAGE_SIZE} + rm -rf ${TARGET_DIR}/raw + rm -rf ${TARGET_DIR}/processed +else + echo "Invalid dataset name: '$1' (choose from 'bair', 'kth')" >&2 + exit 1 +fi +echo "Succesfully finished downloadi\ + +ng and preprocessing dataset '$1'" diff --git a/bash/download_and_preprocess_dataset_era5.sh b/bash/download_and_preprocess_dataset_era5.sh new file mode 100644 index 0000000000000000000000000000000000000000..eacc01801b5e323ea8da8d7adc97c8156172fd7b --- /dev/null +++ b/bash/download_and_preprocess_dataset_era5.sh @@ -0,0 +1,127 @@ +#!/usr/bin/env bash + +# exit if any command fails +set -e + + +#if [ "$#" -eq 2 ]; then +# if [ $1 = "bair" ]; then +# echo "IMAGE_SIZE argument is only applicable to kth dataset" >&2 +# exit 1 +# fi +#elif [ "$#" -ne 1 ]; then +# echo "Usage: $0 DATASET_NAME [IMAGE_SIZE]" >&2 +# exit 1 +#fi +#if [ $1 = "bair" ]; then +# TARGET_DIR=./data/bair +# mkdir -p ${TARGET_DIR} +# TAR_FNAME=bair_robot_pushing_dataset_v0.tar +# URL=http://rail.eecs.berkeley.edu/datasets/${TAR_FNAME} +# echo "Downloading '$1' dataset (this takes a while)" +# #wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME} Bing: on MacOS system , use curl instead of wget +# curl ${URL} -O ${TARGET_DIR}/${TAR_FNAME} +# tar -xvf ${TARGET_DIR}/${TAR_FNAME} --strip-components=1 -C ${TARGET_DIR} +# rm ${TARGET_DIR}/${TAR_FNAME} +# mkdir -p ${TARGET_DIR}/val +# # reserve a fraction of the training set for validation +# mv ${TARGET_DIR}/train/traj_256_to_511.tfrecords ${TARGET_DIR}/val/ +#elif [ $1 = "kth" ]; then +# if [ "$#" -eq 2 ]; then +# IMAGE_SIZE=$2 +# TARGET_DIR=./data/kth_${IMAGE_SIZE} +# else +# IMAGE_SIZE=64 +# fi +# echo ${TARGET_DIR} ${IMAGE_SIZE} +# mkdir -p ${TARGET_DIR} +# mkdir -p ${TARGET_DIR}/raw +# echo "Downloading '$1' dataset (this takes a while)" + # TODO Bing: for save time just use walking, need to change back if all the data are needed + #for ACTION in walking jogging running boxing handwaving handclapping; do +# for ACTION in walking; do +# echo "Action: '$ACTION' " +# ZIP_FNAME=${ACTION}.zip +# URL=http://www.nada.kth.se/cvap/actions/${ZIP_FNAME} +# # wget ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME} +# echo "Start downloading action '$ACTION' ULR '$URL' " +# curl ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME} +# unzip ${TARGET_DIR}/raw/${ZIP_FNAME} -d ${TARGET_DIR}/raw/${ACTION} +# echo "Action '$ACTION' data download and unzip " +# done +# FRAME_RATE=25 +# mkdir -p ${TARGET_DIR}/processed +# # download files with metadata specifying the subsequences +# TAR_FNAME=kth_meta.tar.gz +# URL=http://rail.eecs.berkeley.edu/models/savp/data/${TAR_FNAME} +# echo "Downloading '${TAR_FNAME}' ULR '$URL' " +# #wget ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME} +# curl ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME} +# tar -xzvf ${TARGET_DIR}/processed/${TAR_FNAME} --strip 1 -C ${TARGET_DIR}/processed + # convert the videos into sequence of downscaled images +# echo "Processing '$1' dataset" +# #TODO Bing, just use walking for test +# #for ACTION in walking jogging running boxing handwaving handclapping; do +# #Todo Bing: remove the comments below after testing +# for ACTION in walking; do +# for VIDEO_FNAME in 
${TARGET_DIR}/raw/${ACTION}/*.avi; do +# FNAME=$(basename ${VIDEO_FNAME}) +# FNAME=${FNAME%_uncomp.avi} +# echo "FNAME '$FNAME' " +# # sometimes the directory is not created, so try until it is +# while [ ! -d "${TARGET_DIR}/processed/${ACTION}/${FNAME}" ]; do +# mkdir -p ${TARGET_DIR}/processed/${ACTION}/${FNAME} +# done +# ffmpeg -i ${VIDEO_FNAME} -r ${FRAME_RATE} -f image2 -s ${IMAGE_SIZE}x${IMAGE_SIZE} \ +# ${TARGET_DIR}/processed/${ACTION}/${FNAME}/image-%03d_${IMAGE_SIZE}x${IMAGE_SIZE}.png +# done +# done +# python video_prediction/datasets/kth_dataset.py ${TARGET_DIR}/processed ${TARGET_DIR} ${IMAGE_SIZE} +# rm -rf ${TARGET_DIR}/raw +# rm -rf ${TARGET_DIR}/processed + +while [[ $# -gt 0 ]] #of the number of passed argument is greater than 0 +do +key="$1" +case $key in + -d|--data) + DATA="$2" + shift + shift + ;; + -i|--input_dir) + INPUT_DIR="$2" + shift + shift + ;; + -o|--output_dir) + OUTPUT_DIR="$2" + shift + shift + ;; +esac +done + +echo "DATA = ${DATA} " + +echo "OUTPUT_DIRECTORY = ${OUTPUT_DIR}" + +if [ -d $INPUT_DIR ]; then + echo "INPUT DIRECTORY = ${INPUT_DIR}" + +else + echo "INPUT DIRECTORY '$INPUT_DIR' DOES NOT EXIST" + exit 1 +fi + + +if [ $DATA = "era5" ]; then + + mkdir -p ${OUTPUT_DIR} + python video_prediction/datasets/era5_dataset.py $INPUT_DIR ${OUTPUT_DIR} +else + echo "dataset name: '$DATA' (choose from 'era5')" >&2 + exit 1 +fi + +echo "Succesfully finished downloading and preprocessing dataset '$DATA' " \ No newline at end of file diff --git a/bash/download_and_preprocess_dataset_v1.sh b/bash/download_and_preprocess_dataset_v1.sh new file mode 100644 index 0000000000000000000000000000000000000000..3541b4a538c089cd79ea2a39c6df0804e11cb0a6 --- /dev/null +++ b/bash/download_and_preprocess_dataset_v1.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash + +# exit if any command fails +set -e + +if [ "$#" -eq 2 ]; then + if [ $1 = "bair" ]; then + echo "IMAGE_SIZE argument is only applicable to kth dataset" >&2 + exit 1 + fi +elif [ "$#" -ne 1 ]; then + echo "Usage: $0 DATASET_NAME [IMAGE_SIZE]" >&2 + exit 1 +fi +if [ $1 = "bair" ]; then + TARGET_DIR=./data/bair + mkdir -p ${TARGET_DIR} + TAR_FNAME=bair_robot_pushing_dataset_v0.tar + URL=http://rail.eecs.berkeley.edu/datasets/${TAR_FNAME} + echo "Downloading '$1' dataset (this takes a while)" + #wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME} Bing: on MacOS system , use curl instead of wget + curl ${URL} -O ${TARGET_DIR}/${TAR_FNAME} + tar -xvf ${TARGET_DIR}/${TAR_FNAME} --strip-components=1 -C ${TARGET_DIR} + rm ${TARGET_DIR}/${TAR_FNAME} + mkdir -p ${TARGET_DIR}/val + # reserve a fraction of the training set for validation + mv ${TARGET_DIR}/train/traj_256_to_511.tfrecords ${TARGET_DIR}/val/ +elif [ $1 = "kth" ]; then + if [ "$#" -eq 2 ]; then + IMAGE_SIZE=$2 + TARGET_DIR=./data/kth_${IMAGE_SIZE} + else + IMAGE_SIZE=64 + TARGET_DIR=./data/kth + fi + echo ${TARGET_DIR} ${IMAGE_SIZE} + mkdir -p ${TARGET_DIR} + mkdir -p ${TARGET_DIR}/raw + echo "Downloading '$1' dataset (this takes a while)" + # TODO Bing: for save time just use walking, need to change back if all the data are needed + #for ACTION in walking jogging running boxing handwaving handclapping; do +# for ACTION in walking; do +# echo "Action: '$ACTION' " +# ZIP_FNAME=${ACTION}.zip +# URL=http://www.nada.kth.se/cvap/actions/${ZIP_FNAME} +# # wget ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME} +# echo "Start downloading action '$ACTION' ULR '$URL' " +# curl ${URL} -O ${TARGET_DIR}/raw/${ZIP_FNAME} +# unzip ${TARGET_DIR}/raw/${ZIP_FNAME} -d ${TARGET_DIR}/raw/${ACTION} +# echo 
"Action '$ACTION' data download and unzip " +# done + FRAME_RATE=25 +# mkdir -p ${TARGET_DIR}/processed +# # download files with metadata specifying the subsequences +# TAR_FNAME=kth_meta.tar.gz +# URL=http://rail.eecs.berkeley.edu/models/savp/data/${TAR_FNAME} +# echo "Downloading '${TAR_FNAME}' ULR '$URL' " +# #wget ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME} +# curl ${URL} -O ${TARGET_DIR}/processed/${TAR_FNAME} +# tar -xzvf ${TARGET_DIR}/processed/${TAR_FNAME} --strip 1 -C ${TARGET_DIR}/processed + # convert the videos into sequence of downscaled images + echo "Processing '$1' dataset" + #TODO Bing, just use walking for test + #for ACTION in walking jogging running boxing handwaving handclapping; do + #Todo Bing: remove the comments below after testing + for ACTION in walking; do + for VIDEO_FNAME in ${TARGET_DIR}/raw/${ACTION}/*.avi; do + FNAME=$(basename ${VIDEO_FNAME}) + FNAME=${FNAME%_uncomp.avi} + echo "FNAME '$FNAME' " + # sometimes the directory is not created, so try until it is + while [ ! -d "${TARGET_DIR}/processed/${ACTION}/${FNAME}" ]; do + mkdir -p ${TARGET_DIR}/processed/${ACTION}/${FNAME} + done + ffmpeg -i ${VIDEO_FNAME} -r ${FRAME_RATE} -f image2 -s ${IMAGE_SIZE}x${IMAGE_SIZE} \ + ${TARGET_DIR}/processed/${ACTION}/${FNAME}/image-%03d_${IMAGE_SIZE}x${IMAGE_SIZE}.png + done + done + python video_prediction/datasets/kth_dataset.py ${TARGET_DIR}/processed ${TARGET_DIR} ${IMAGE_SIZE} + rm -rf ${TARGET_DIR}/raw + rm -rf ${TARGET_DIR}/processed +else + echo "Invalid dataset name: '$1' (choose from 'bair', 'kth')" >&2 + exit 1 +fi +echo "Succesfully finished downloading and preprocessing dataset '$1'" diff --git a/bash/workflow_era5.sh b/bash/workflow_era5.sh new file mode 100755 index 0000000000000000000000000000000000000000..01d16bfdf7f38ffe00495ba31f85349d9ce68335 --- /dev/null +++ b/bash/workflow_era5.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash +set -e +# +#MODEL=savp +##train_mode: end_to_end, pre_trained +#TRAIN_MODE=end_to_end +#EXP_NAME=era5_size_64_64_3_3t_norm + +MODEL=$1 +TRAIN_MODE=$2 +EXP_NAME=$3 +RETRAIN=1 #if we continue training the model or using the existing end-to-end model, 1 means continue training, and 1 means use the existing one +DATA_ETL_DIR=/p/scratch/deepacf/${USER}/ +DATA_EXTRA_DIR=${DATA_ETL_DIR}/extractedData/${EXP_NAME} +DATA_PREPROCESS_DIR=${DATA_ETL_DIR}/preprocessedData/${EXP_NAME} +DATA_PREPROCESS_TF_DIR=./data/${EXP_NAME} +RESULTS_OUTPUT_DIR=./results_test_samples/${EXP_NAME}/${TRAIN_MODE}/ + +if [ $MODEL==savp ]; then + method_dir=ours_savp +elif [ $MODEL==gan ]; then + method_dir=ours_gan +elif [ $MODEL==vae ]; then + method_dir=ours_vae +else + echo "model does not exist" 2>&1 + exit 1 +fi + +if [ "$TRAIN_MODE" == pre_trained ]; then + TRAIN_OUTPUT_DIR=./pretrained_models/kth/${method_dir} +else + TRAIN_OUTPUT_DIR=./logs/${EXP_NAME}/${TRAIN_MODE} +fi + +CHECKPOINT_DIR=${TRAIN_OUTPUT_DIR}/${method_dir} + +echo "===========================WORKFLOW SETUP====================" +echo "Model ${MODEL}" +echo "TRAIN MODE ${TRAIN_MODE}" +echo "Method_dir ${method_dir}" +echo "DATA_ETL_DIR ${DATA_ETL_DIR}" +echo "DATA_EXTRA_DIR ${DATA_EXTRA_DIR}" +echo "DATA_PREPROCESS_DIR ${DATA_PREPROCESS_DIR}" +echo "DATA_PREPROCESS_TF_DIR ${DATA_PREPROCESS_TF_DIR}" +echo "TRAIN_OUTPUT_DIR ${TRAIN_OUTPUT_DIR}" +echo "=============================================================" + +##############Datat Preprocessing################ +#To hkl data +if [ -d "$DATA_PREPROCESS_DIR" ]; then + echo "The Preprocessed Data (.hkl ) exist" +else + python 
../workflow_video_prediction/DataPreprocess/benchmark/mpi_stager_v2_process_netCDF.py \ + --input_dir ${DATA_EXTRA_DIR} --destination_dir ${DATA_PREPROCESS_DIR} +fi + +#Change the .hkl data to .tfrecords files +if [ -d "$DATA_PREPROCESS_TF_DIR" ] +then + echo "Step2: The Preprocessed Data (tf.records) exist" +else + echo "Step2: start, hkl. files to tf.records" + python ./video_prediction/datasets/era5_dataset_v2.py --source_dir ${DATA_PREPROCESS_DIR}/splits \ + --destination_dir ${DATA_PREPROCESS_TF_DIR} + echo "Step2: finish" +fi + +#########Train########################## +if [ "$TRAIN_MODE" == "pre_trained" ]; then + echo "step3: Using kth pre_trained model" +elif [ "$TRAIN_MODE" == "end_to_end" ]; then + echo "step3: End-to-end training" + if [ "$RETRAIN" == 1 ]; then + echo "Using the existing end-to-end model" + else + echo "Step3: Training Starts " + python ./scripts/train_v2.py --input_dir $DATA_PREPROCESS_TF_DIR --dataset era5 \ + --model ${MODEL} --model_hparams_dict hparams/kth/${method_dir}/model_hparams.json \ + --output_dir ${TRAIN_OUTPUT_DIR} --checkpoint ${CHECKPOINT_DIR_DIR} + echo "Training ends " + fi +else + echo "TRAIN_MODE is end_to_end or pre_trained" + exit 1 +fi + +#########Generate results################# +echo "Step4: Postprocessing start" +python ./scripts/generate_transfer_learning_finetune.py --input_dir ${DATA_PREPROCESS_TF_DIR} \ +--dataset_hparams sequence_length=20 --checkpoint ${CHECKPOINT_DIR_DIR} --mode test --results_dir ${RESULTS_OUTPUT_DIR} \ +--batch_size 4 --dataset era5 \ No newline at end of file diff --git a/bash/workflow_era5_macOS.sh b/bash/workflow_era5_macOS.sh new file mode 100755 index 0000000000000000000000000000000000000000..1a6ebef38df877b8ee20f628d4e375a20e7c8bd5 --- /dev/null +++ b/bash/workflow_era5_macOS.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash +set -e +# +#MODEL=savp +##train_mode: end_to_end, pre_trained +#TRAIN_MODE=end_to_end +#EXP_NAME=era5_size_64_64_3_3t_norm + +MODEL=$1 +TRAIN_MODE=$2 +EXP_NAME=$3 +RETRAIN=1 #if we continue training the model or using the existing end-to-end model, 1 means continue training, and 1 means use the existing one +DATA_ETL_DIR=/p/scratch/deepacf/${USER}/ +DATA_ETL_DIR=/p/scratch/deepacf/${USER}/ +DATA_EXTRA_DIR=${DATA_ETL_DIR}/extractedData/${EXP_NAME} +DATA_PREPROCESS_DIR=${DATA_ETL_DIR}/preprocessedData/${EXP_NAME} +DATA_PREPROCESS_TF_DIR=./data/${EXP_NAME} +RESULTS_OUTPUT_DIR=./results_test_samples/${EXP_NAME}/${TRAIN_MODE}/ + +if [ $MODEL==savp ]; then + method_dir=ours_savp +elif [ $MODEL==gan ]; then + method_dir=ours_gan +elif [ $MODEL==vae ]; then + method_dir=ours_vae +else + echo "model does not exist" 2>&1 + exit 1 +fi + +if [ "$TRAIN_MODE" == pre_trained ]; then + TRAIN_OUTPUT_DIR=./pretrained_models/kth/${method_dir} +else + TRAIN_OUTPUT_DIR=./logs/${EXP_NAME}/${TRAIN_MODE} +fi + +CHECKPOINT_DIR=${TRAIN_OUTPUT_DIR}/${method_dir} + +echo "===========================WORKFLOW SETUP====================" +echo "Model ${MODEL}" +echo "TRAIN MODE ${TRAIN_MODE}" +echo "Method_dir ${method_dir}" +echo "DATA_ETL_DIR ${DATA_ETL_DIR}" +echo "DATA_EXTRA_DIR ${DATA_EXTRA_DIR}" +echo "DATA_PREPROCESS_DIR ${DATA_PREPROCESS_DIR}" +echo "DATA_PREPROCESS_TF_DIR ${DATA_PREPROCESS_TF_DIR}" +echo "TRAIN_OUTPUT_DIR ${TRAIN_OUTPUT_DIR}" +echo "=============================================================" + +##############Datat Preprocessing################ +#To hkl data +#if [ -d "$DATA_PREPROCESS_DIR" ]; then +# echo "The Preprocessed Data (.hkl ) exist" +#else +# python 
../workflow_video_prediction/DataPreprocess/benchmark/mpi_stager_v2_process_netCDF.py \ +# --input_dir ${DATA_EXTRA_DIR} --destination_dir ${DATA_PREPROCESS_DIR} +#fi + +####Change the .hkl data to .tfrecords files +if [ -d "$DATA_PREPROCESS_TF_DIR" ] +then + echo "Step2: The Preprocessed Data (tf.records) exist" +else + echo "Step2: start, hkl. files to tf.records" + python ./video_prediction/datasets/era5_dataset_v2.py --source_dir ${DATA_PREPROCESS_DIR}/splits \ + --destination_dir ${DATA_PREPROCESS_TF_DIR} + echo "Step2: finish" +fi + +#########Train########################## +if [ "$TRAIN_MODE" == "pre_trained" ]; then + echo "step3: Using kth pre_trained model" +elif [ "$TRAIN_MODE" == "end_to_end" ]; then + echo "step3: End-to-end training" + if [ "$RETRAIN" == 1 ]; then + echo "Using the existing end-to-end model" + else + echo "Training Starts " + python ./scripts/train_v2.py --input_dir $DATA_PREPROCESS_TF_DIR --dataset era5 \ + --model ${MODEL} --model_hparams_dict hparams/kth/${method_dir}/model_hparams.json \ + --output_dir ${TRAIN_OUTPUT_DIR} --checkpoint ${CHECKPOINT_DIR} + echo "Training ends " + fi +else + echo "TRAIN_MODE is end_to_end or pre_trained" + exit 1 +fi + +#########Generate results################# +echo "Step4: Postprocessing start" +python ./scripts/generate_transfer_learning_finetune.py --input_dir ${DATA_PREPROCESS_TF_DIR} \ +--dataset_hparams sequence_length=20 --checkpoint ${CHECKPOINT_DIR} --mode test --results_dir ${RESULTS_OUTPUT_DIR} \ +--batch_size 4 --dataset era5 diff --git a/bash/workflow_era5_zam347.sh b/bash/workflow_era5_zam347.sh new file mode 100755 index 0000000000000000000000000000000000000000..ffe7209b6099f4ad9f57b4e90247a7d7acaf009d --- /dev/null +++ b/bash/workflow_era5_zam347.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash +set -e +# +#MODEL=savp +##train_mode: end_to_end, pre_trained +#TRAIN_MODE=end_to_end +#EXP_NAME=era5_size_64_64_3_3t_norm + +MODEL=$1 +TRAIN_MODE=$2 +EXP_NAME=$3 +RETRAIN=1 #if we continue training the model or using the existing end-to-end model, 1 means continue training, and 1 means use the existing one +DATA_ETL_DIR=/home/${USER}/ +DATA_ETL_DIR=/p/scratch/deepacf/${USER}/ +DATA_EXTRA_DIR=${DATA_ETL_DIR}/extractedData/${EXP_NAME} +DATA_PREPROCESS_DIR=${DATA_ETL_DIR}/preprocessedData/${EXP_NAME} +DATA_PREPROCESS_TF_DIR=./data/${EXP_NAME} +RESULTS_OUTPUT_DIR=./results_test_samples/${EXP_NAME}/${TRAIN_MODE}/ + +if [ $MODEL==savp ]; then + method_dir=ours_savp +elif [ $MODEL==gan ]; then + method_dir=ours_gan +elif [ $MODEL==vae ]; then + method_dir=ours_vae +else + echo "model does not exist" 2>&1 + exit 1 +fi + +if [ "$TRAIN_MODE" == pre_trained ]; then + TRAIN_OUTPUT_DIR=./pretrained_models/kth/${method_dir} +else + TRAIN_OUTPUT_DIR=./logs/${EXP_NAME}/${TRAIN_MODE} +fi + +CHECKPOINT_DIR=${TRAIN_OUTPUT_DIR}/${method_dir} + +echo "===========================WORKFLOW SETUP====================" +echo "Model ${MODEL}" +echo "TRAIN MODE ${TRAIN_MODE}" +echo "Method_dir ${method_dir}" +echo "DATA_ETL_DIR ${DATA_ETL_DIR}" +echo "DATA_EXTRA_DIR ${DATA_EXTRA_DIR}" +echo "DATA_PREPROCESS_DIR ${DATA_PREPROCESS_DIR}" +echo "DATA_PREPROCESS_TF_DIR ${DATA_PREPROCESS_TF_DIR}" +echo "TRAIN_OUTPUT_DIR ${TRAIN_OUTPUT_DIR}" +echo "=============================================================" + +##############Datat Preprocessing################ +#To hkl data +#if [ -d "$DATA_PREPROCESS_DIR" ]; then +# echo "The Preprocessed Data (.hkl ) exist" +#else +# python 
../workflow_video_prediction/DataPreprocess/benchmark/mpi_stager_v2_process_netCDF.py \ +# --input_dir ${DATA_EXTRA_DIR} --destination_dir ${DATA_PREPROCESS_DIR} +#fi + +####Change the .hkl data to .tfrecords files +if [ -d "$DATA_PREPROCESS_TF_DIR" ] +then + echo "Step2: The Preprocessed Data (tf.records) exist" +else + echo "Step2: start, hkl. files to tf.records" + python ./video_prediction/datasets/era5_dataset_v2.py --source_dir ${DATA_PREPROCESS_DIR}/splits \ + --destination_dir ${DATA_PREPROCESS_TF_DIR} + echo "Step2: finish" +fi + +#########Train########################## +if [ "$TRAIN_MODE" == "pre_trained" ]; then + echo "step3: Using kth pre_trained model" +elif [ "$TRAIN_MODE" == "end_to_end" ]; then + echo "step3: End-to-end training" + if [ "$RETRAIN" == 1 ]; then + echo "Using the existing end-to-end model" + else + echo "Training Starts " + python ./scripts/train_v2.py --input_dir $DATA_PREPROCESS_TF_DIR --dataset era5 \ + --model ${MODEL} --model_hparams_dict hparams/kth/${method_dir}/model_hparams.json \ + --output_dir ${TRAIN_OUTPUT_DIR} --checkpoint ${CHECKPOINT_DIR} + echo "Training ends " + fi +else + echo "TRAIN_MODE is end_to_end or pre_trained" + exit 1 +fi + +#########Generate results################# +echo "Step4: Postprocessing start" +python ./scripts/generate_transfer_learning_finetune.py --input_dir ${DATA_PREPROCESS_TF_DIR} \ +--dataset_hparams sequence_length=20 --checkpoint ${CHECKPOINT_DIR} --mode test --results_dir ${RESULTS_OUTPUT_DIR} \ +--batch_size 4 --dataset era5 diff --git a/deleteme.txt b/deleteme.txt deleted file mode 100644 index e61ef7b965e17c62ca23b6ff5f0aaf09586e10e9..0000000000000000000000000000000000000000 --- a/deleteme.txt +++ /dev/null @@ -1 +0,0 @@ -aa diff --git a/docs/discussion/discussion.md b/docs/discussion/discussion.md new file mode 100644 index 0000000000000000000000000000000000000000..ff1d00c4c064c72c13029e84fbd0c18c3e4ff59b --- /dev/null +++ b/docs/discussion/discussion.md @@ -0,0 +1,5 @@ +This is the list of last-mins files for VP group + +## 2020-03-01 - 2020-03-31 + +- https://docs.google.com/document/d/1cQUEWrenIlW1zebZwSSHpfka2Bhb8u63kPM3x7nya_o/edit#heading=h.yjmq51s4fxnm \ No newline at end of file diff --git a/docs/presentation/presentation.md b/docs/presentation/presentation.md new file mode 100644 index 0000000000000000000000000000000000000000..d49239089d5d881ef7a42e5e847ee45c3be725d4 --- /dev/null +++ b/docs/presentation/presentation.md @@ -0,0 +1,5 @@ +This is the presentation materials for VP group + + +## 2020-03-01 - 2020-03-31 +https://docs.google.com/presentation/d/18EJKBJJ2LHI7uNU_l8s_Cm-aGZhw9tkoQ8BxqYZfkWk/edit#slide=id.g71f805bc32_0_80 diff --git a/docs/structure_name_convention.md b/docs/structure_name_convention.md new file mode 100644 index 0000000000000000000000000000000000000000..4a2679c83ea8d99b9562ef775ed2ac1190f5d7fb --- /dev/null +++ b/docs/structure_name_convention.md @@ -0,0 +1,108 @@ +This is the output folder structure and name convention + +## Shared folder structure + +``` +├── ExtractedData +│ ├── [Year] +│ │ ├── [Month] +│ │ │ ├── **/*.netCDF +├── PreprocessedData +│ ├── [Data_name_convention] +│ │ ├── hickle +│ │ │ ├── train +│ │ │ ├── val +│ │ │ ├── test +│ │ ├── tfrecords +│ │ │ ├── train +│ │ │ ├── val +│ │ │ ├── test +├── Models +│ ├── [Data_name_convention] +│ │ ├── [model_name] +│ │ ├── [model_name] +├── Results +│ ├── [Data_name_convention] +│ │ ├── [training_mode] +│ │ │ ├── [source_data_name_convention] +│ │ │ │ ├── [model_name] + +``` + +| Arguments | Value | +|--- |--- | +| 
[Year] | 2005;2005;2007 ...| +| [Month] | 01;02;03 ...,12| +|[Data_name_convention]|Y[yyyy]to[yyyy]M[mm]to[mm]-[nx]_[ny]-[nn.nn]N[ee.ee]E-[var1]_[var2]_[var3]| +|[model_name]| Ours_savp; ours_gan; ours_vae; prednet| +|[training_mode]|end_to_end; transfer_learning| + + +## Data name convention + +`Y[yyyy]to[yyyy]M[mm]to[mm]-[nx]_[ny]-[nn.nn]N[ee.ee]E-[var1]_[var2]_[var3]` + + - Y[yyyy]to[yyyy]M[mm]to[mm] + - [nx]_[ny] : the size of images,e.g 64_64 means 64*64 pixels + - [nn.nn]N[ee.ee]E :the geolocation of selected regions with two decimal points. e.g : 0.00N11.50E + - [var1]_[var2]_[var3] : [Use the abbrevation of selected variables](#variable-abbrevaition-and-the-corresponding-full-names) + +### `Y[yyyy]to[yyyy]M[mm]to[mm]` + +| Examples | Name abbrevation | +|--- |--- | +|all data from March to June of the years 2005-2015 | Y2005toY2015M03to06 | +|data from February to May of years 2005-2008 + data from March to June of year 2015| Y2005to2008M02to05_Y2015M03to06 | +|Data from February to May, and October to December of 2005 | Y2005M02to05_Y2015M10to12 | +|operational’ data base: whole year 2016 | Y2016M01to12 | +|add new whole year data of 2017 on the operational data base |Y2016to2017M01to12 | +| Note: Y2016to2017M01to12 = Y2016M01to12_Y2017M01to12| + + +### variable abbrevaition and the corresponding full names + +| var | full names | +|--- |--- | +|T|2m temperature| +|gph500|500 hPa geopotential| +|msl|meansealevelpressure| + + + +### Example + +``` +├── ExtractedData +│ ├── 2016 +│ │ ├── 01 +│ │ │ ├── *.netCDF +│ │ ├── 02 +│ │ ├── 03 +│ │ ├── … +│ ├── 2017 +│ │ ├── 01 +│ │ ├── … +├── PreprocessedData +│ ├── 2016to2017M01to12-64_64-50.00N11.50E-T_T_T +│ │ ├── hickle +│ │ │ ├── train +│ │ │ ├── val +│ │ │ ├── test +│ │ ├── tfrecords +│ │ │ ├── train +│ │ │ ├── val +│ │ │ ├── test +├── Models +│ ├── 2016to2017M01to12-64_64-50.00N11.50E-T_T_T +│ │ ├── outs_savp +│ │ ├── outs_gan +├── Results +│ ├── 2016to2017M01to12-64_64-50.00N11.50E-T_T_T +│ │ ├── end_to_end +│ │ │ ├── ours_savp +│ │ │ ├── ours_gan +│ │ ├── transfer_learning +│ │ │ ├── 2018M01to12-64_64-50.00N11.50E-T_T_T +│ │ │ │ ├── ours_savp +``` + diff --git a/env_setup/create_env.sh b/env_setup/create_env.sh new file mode 100755 index 0000000000000000000000000000000000000000..7d0f0a10bd8586e59fe129198a5a6f7c21121502 --- /dev/null +++ b/env_setup/create_env.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +if [[ ! -n "$1" ]]; then + echo "Provide the user name, which will be taken as folder name" + exit 1 +fi + +if [[ ! -n "$2" ]]; then + echo "Provide the env name, which will be taken as folder name" + exit 1 +fi + +ENV_NAME=$2 +FOLDER_NAME=$1 +WORKING_DIR=/p/project/deepacf/deeprain/${FOLDER_NAME}/video_prediction_savp +ENV_SETUP_DIR=${WORKING_DIR}/env_setup +ENV_DIR=${WORKING_DIR}/${ENV_NAME} + +source ${ENV_SETUP_DIR}/modules.sh +# Install additional Python packages. 
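Note: for later sessions, the environment that `env_setup/create_env.sh` creates has to be re-activated by hand. A minimal sketch, mirroring the module, activation and PYTHONPATH lines used in this script; the placeholder names in angle brackets are not part of the repository and must be replaced with your own folder and environment names:

```bash
# Re-activate an environment previously created by env_setup/create_env.sh
WORKING_DIR=/p/project/deepacf/deeprain/<folder_name>/video_prediction_savp
ENV_DIR=${WORKING_DIR}/<env_name>
source ${WORKING_DIR}/env_setup/modules.sh   # load the same HPC modules as during setup
source ${ENV_DIR}/bin/activate               # enter the virtual environment
# Re-export the paths the workflow scripts expect on PYTHONPATH
export PYTHONPATH=${WORKING_DIR}:${WORKING_DIR}/hickle/lib/python3.6/site-packages:$PYTHONPATH
export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:${WORKING_DIR}/lpips-tensorflow:$PYTHONPATH
```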
+python3 -m venv $ENV_DIR +source ${ENV_DIR}/bin/activate +pip3 install -r ${ENV_SETUP_DIR}/requirements.txt +#pip3 install --user netCDF4 +#pip3 install --user numpy + +#Copy the hickle package from bing's account +cp -r /p/project/deepacf/deeprain/bing/hickle ${WORKING_DIR} + +source ${ENV_SETUP_DIR}/modules.sh +source ${ENV_DIR}/bin/activate + +export PYTHONPATH=${WORKING_DIR}/hickle/lib/python3.6/site-packages:$PYTHONPATH +export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH +export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:$PYTHONPATH +#export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:$PYTHONPATH +export PYTHONPATH=${WORKING_DIR}/lpips-tensorflow:$PYTHONPATH + + diff --git a/env_setup/create_env_zam347.sh b/env_setup/create_env_zam347.sh new file mode 100755 index 0000000000000000000000000000000000000000..5e0c43c5cc826c579b096c7b0aad0236e6b2002f --- /dev/null +++ b/env_setup/create_env_zam347.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash + + +if [[ ! -n "$1" ]]; then + echo "Provide the env name, which will be taken as folder name" + exit 1 +fi + +ENV_NAME=$1 +WORKING_DIR=/home/$USER/video_prediction_savp +ENV_SETUP_DIR=${WORKING_DIR}/env_setup +ENV_DIR=${WORKING_DIR}/${ENV_NAME} +unset PYTHONPATH +#source ${ENV_SETUP_DIR}/modules.sh +# Install additional Python packages. +python3 -m venv $ENV_DIR +source ${ENV_DIR}/bin/activate +pip3 install --upgrade pip +pip3 install -r ${ENV_SETUP_DIR}/requirements.txt +#conda install mpi4py +pip3 install mpi4py +pip3 install netCDF4 +pip3 install numpy +pip3 install h5py +pip3 install tensorflow-gpu==1.13.1 +#Copy the hickle package from bing's account +#cp -r /p/project/deepacf/deeprain/bing/hickle ${WORKING_DIR} +cp -r /home/b.gong/video_prediction_savp/hickle ${WORKING_DIR} + +#source ${ENV_SETUP_DIR}/modules.sh +#source ${ENV_DIR}/bin/activate + +#export PYTHONPATH=/home/$USER/miniconda3/pkgs:$PYTHONPATH +export PYTHONPATH=${WORKING_DIR}/hickle/lib/python3.6/site-packages:$PYTHONPATH +export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH +#export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:$PYTHONPATH +#export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:$PYTHONPATH +export PYTHONPATH=${WORKING_DIR}/lpips-tensorflow:$PYTHONPATH + + diff --git a/env_setup/modules.sh b/env_setup/modules.sh new file mode 100755 index 0000000000000000000000000000000000000000..e6793787ad59988cfc6646dc8dd789d1573c6b23 --- /dev/null +++ b/env_setup/modules.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +module purge +module use $OTHERSTAGES +module load Stages/2019a +module load GCC/8.3.0 +module load MVAPICH2/.2.3.1-GDR +module load GCCcore/.8.3.0 +module load mpi4py/3.0.1-Python-3.6.8 +module load h5py/2.9.0-serial-Python-3.6.8 +module load TensorFlow/1.13.1-GPU-Python-3.6.8 +module load cuDNN/7.5.1.10-CUDA-10.1.105 + diff --git a/env_setup/requirements.txt b/env_setup/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..76dd1f57d64577cc565968bb7106656e53687261 --- /dev/null +++ b/env_setup/requirements.txt @@ -0,0 +1,4 @@ +opencv-python +scipy +scikit-image +pandas diff --git a/geo_info.json b/geo_info.json new file mode 100644 index 0000000000000000000000000000000000000000..911a7c3b1333c4e815db705197cf77cb107de8a8 --- /dev/null +++ b/geo_info.json @@ -0,0 +1 @@ +{"lat": [58.19999694824219, 57.89999771118164, 57.599998474121094, 57.29999923706055, 57.0, 56.69999694824219, 56.39999771118164, 56.099998474121094, 55.79999923706055, 55.5, 55.19999694824219, 54.89999771118164, 54.599998474121094, 54.29999923706055, 54.0, 
53.69999694824219, 53.39999771118164, 53.099998474121094, 52.79999923706055, 52.5, 52.19999694824219, 51.89999771118164, 51.599998474121094, 51.29999923706055, 51.0, 50.69999694824219, 50.39999771118164, 50.099998474121094, 49.79999923706055, 49.5, 49.19999694824219, 48.89999771118164, 48.599998474121094, 48.29999923706055, 48.0, 47.69999694824219, 47.39999771118164, 47.099998474121094, 46.79999923706055, 46.5, 46.19999694824219, 45.89999771118164, 45.599998474121094, 45.29999923706055, 45.0, 44.69999694824219, 44.39999771118164, 44.099998474121094, 43.79999923706055, 43.5, 43.19999694824219, 42.89999771118164, 42.599998474121094, 42.29999923706055, 42.0, 41.69999694824219, 41.39999771118164, 41.099998474121094, 40.79999923706055, 40.499996185302734, 40.19999694824219, 39.89999771118164, 39.599998474121094, 39.29999923706055], "lon": [-0.5999755859375, -0.29998779296875, 0.0, 0.30000001192092896, 0.6000000238418579, 0.9000000357627869, 1.2000000476837158, 1.5, 1.8000000715255737, 2.1000001430511475, 2.4000000953674316, 2.700000047683716, 3.0, 3.3000001907348633, 3.6000001430511475, 3.9000000953674316, 4.200000286102295, 4.5, 4.800000190734863, 5.100000381469727, 5.400000095367432, 5.700000286102295, 6.0, 6.300000190734863, 6.600000381469727, 6.900000095367432, 7.200000286102295, 7.500000476837158, 7.800000190734863, 8.100000381469727, 8.40000057220459, 8.700000762939453, 9.0, 9.300000190734863, 9.600000381469727, 9.90000057220459, 10.200000762939453, 10.5, 10.800000190734863, 11.100000381469727, 11.40000057220459, 11.700000762939453, 12.0, 12.300000190734863, 12.600000381469727, 12.90000057220459, 13.200000762939453, 13.500000953674316, 13.800000190734863, 14.100000381469727, 14.40000057220459, 14.700000762939453, 15.000000953674316, 15.300000190734863, 15.600000381469727, 15.90000057220459, 16.200000762939453, 16.5, 16.80000114440918, 17.100000381469727, 17.400001525878906, 17.700000762939453, 18.0, 18.30000114440918]} \ No newline at end of file diff --git a/helper/helper.py b/helper/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..e85a4df8a8ad9811b2634983db5eac8bf0204b47 --- /dev/null +++ b/helper/helper.py @@ -0,0 +1,53 @@ +import logging +import time +from functools import wraps + +def logDecorator(fn,verbose=False): + @wraps(fn) + def wrapper(*args,**kwargs): + print("inside wrapper of log decorator function") + logger = logging.getLogger(fn.__name__) + # create a file handler + handler = logging.FileHandler("log.log") + logger.setLevel(logging.DEBUG if verbose else logging.INFO) + #create a console handler + ch = logging.StreamHandler() + logger.setLevel(logging.DEBUG if verbose else logging.INFO) + # create a logging format + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + ch.setFormatter(formatter) + logger.addHandler(handler) + logger.addHandler(ch) + logger.info("Logging 1") + start = time.time() + results = fn(*args,**kwargs) + end = time.time() + logger.info("{} ran in {}s".format(fn.__name__, round(end - start, 2))) + return results + return wrapper + + +#logger = logging.getLogger(__name__) +# def set_logger(verbose=False): +# # Remove all handlers associated with the root logger object. 
+# for handler in logging.root.handlers[:]: +# logging.root.removeHandler(handler) +# logger = logging.getLogger(__name__) +# logger.propagate = False +# +# +# if not logger.handlers: +# logger.setLevel(logging.DEBUG if verbose else logging.INFO) +# formatter = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +# +# +# +# #再创建一个handler,用于输出到控制台 +# console_handler = logging.StreamHandler() +# console_handler.setLevel(logging.DEBUG if verbose else logging.INFO) +# console_handler.setFormatter(formatter) +# logger.handlers = [] +# logger.addHandler(console_handler) +# +# return logger \ No newline at end of file diff --git a/hparams/bair_action_free/ours_deterministic_l1/model_hparams.json b/hparams/bair_action_free/ours_deterministic_l1/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..4a1b23edcb68b57dadee82b1c13366afac50a52a --- /dev/null +++ b/hparams/bair_action_free/ours_deterministic_l1/model_hparams.json @@ -0,0 +1,13 @@ +{ + "batch_size": 32, + "lr": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "l1_weight": 1.0, + "l2_weight": 0.0, + "kl_weight": 0.0, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.0, + "state_weight": 0.0, + "nz": 0 +} \ No newline at end of file diff --git a/hparams/bair_action_free/ours_deterministic_l2/model_hparams.json b/hparams/bair_action_free/ours_deterministic_l2/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..31e7152ae15df5ee33b264f11c88c76c50592185 --- /dev/null +++ b/hparams/bair_action_free/ours_deterministic_l2/model_hparams.json @@ -0,0 +1,13 @@ +{ + "batch_size": 32, + "lr": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "l1_weight": 0.0, + "l2_weight": 1.0, + "kl_weight": 0.0, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.0, + "state_weight": 0.0, + "nz": 0 +} \ No newline at end of file diff --git a/hparams/bair_action_free/ours_gan/model_hparams.json b/hparams/bair_action_free/ours_gan/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..38837822c90f38c6209dfa27019a90ccdf8ea43a --- /dev/null +++ b/hparams/bair_action_free/ours_gan/model_hparams.json @@ -0,0 +1,14 @@ +{ + "batch_size": 16, + "lr": 0.0002, + "beta1": 0.5, + "beta2": 0.999, + "l1_weight": 100.0, + "l2_weight": 0.0, + "kl_weight": 0.0, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.1, + "vae_gan_feature_cdist_weight": 0.0, + "gan_feature_cdist_weight": 10.0, + "state_weight": 0.0 +} \ No newline at end of file diff --git a/hparams/bair_action_free/ours_savp/model_hparams.json b/hparams/bair_action_free/ours_savp/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..a6eea83a19505d374e9d614f48ef8bc72443c0f2 --- /dev/null +++ b/hparams/bair_action_free/ours_savp/model_hparams.json @@ -0,0 +1,14 @@ +{ + "batch_size": 16, + "lr": 0.0002, + "beta1": 0.5, + "beta2": 0.999, + "l1_weight": 100.0, + "l2_weight": 0.0, + "kl_weight": 1.0, + "video_sn_vae_gan_weight": 0.1, + "video_sn_gan_weight": 0.1, + "vae_gan_feature_cdist_weight": 10.0, + "gan_feature_cdist_weight": 0.0, + "state_weight": 0.0 +} \ No newline at end of file diff --git a/hparams/bair_action_free/ours_vae_l1/model_hparams.json b/hparams/bair_action_free/ours_vae_l1/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..827757e11b75e720d236417be449b7a301a005ec --- /dev/null +++ b/hparams/bair_action_free/ours_vae_l1/model_hparams.json @@ -0,0 +1,12 @@ +{ + "batch_size": 32, + "lr": 0.001, + "beta1": 0.9, + 
"beta2": 0.999, + "l1_weight": 1.0, + "l2_weight": 0.0, + "kl_weight": 0.001, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.0, + "state_weight": 0.0 +} \ No newline at end of file diff --git a/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json b/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..4fddf0eef1d45dbfa16f098e5e42b12f594132e3 --- /dev/null +++ b/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json @@ -0,0 +1,12 @@ +{ + "batch_size": 32, + "lr": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "l1_weight": 0.0, + "l2_weight": 1.0, + "kl_weight": 0.001, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.0, + "state_weight": 0.0 +} \ No newline at end of file diff --git a/hparams/era5/model_hparams.json b/hparams/era5/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..b121ee2f005b6db753b2536deb804204dd41b78d --- /dev/null +++ b/hparams/era5/model_hparams.json @@ -0,0 +1,11 @@ +{ + "batch_size": 8, + "lr": 0.001, + "nz": 16, + "max_steps":500, + "context_frames":10, + "sequence_length":20 + +} + + diff --git a/hparams/kth/ours_deterministic_l1/model_hparams.json b/hparams/kth/ours_deterministic_l1/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..4a1b23edcb68b57dadee82b1c13366afac50a52a --- /dev/null +++ b/hparams/kth/ours_deterministic_l1/model_hparams.json @@ -0,0 +1,13 @@ +{ + "batch_size": 32, + "lr": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "l1_weight": 1.0, + "l2_weight": 0.0, + "kl_weight": 0.0, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.0, + "state_weight": 0.0, + "nz": 0 +} \ No newline at end of file diff --git a/hparams/kth/ours_deterministic_l2/model_hparams.json b/hparams/kth/ours_deterministic_l2/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..31e7152ae15df5ee33b264f11c88c76c50592185 --- /dev/null +++ b/hparams/kth/ours_deterministic_l2/model_hparams.json @@ -0,0 +1,13 @@ +{ + "batch_size": 32, + "lr": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "l1_weight": 0.0, + "l2_weight": 1.0, + "kl_weight": 0.0, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.0, + "state_weight": 0.0, + "nz": 0 +} \ No newline at end of file diff --git a/hparams/kth/ours_gan/model_hparams.json b/hparams/kth/ours_gan/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..3d14b63edbf14efca2cefe4703453d899b3fb0fd --- /dev/null +++ b/hparams/kth/ours_gan/model_hparams.json @@ -0,0 +1,15 @@ +{ + "batch_size": 16, + "lr": 0.0002, + "beta1": 0.5, + "beta2": 0.999, + "l1_weight": 100.0, + "l2_weight": 0.0, + "kl_weight": 0.0, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.1, + "vae_gan_feature_cdist_weight": 0.0, + "gan_feature_cdist_weight": 10.0, + "state_weight": 0.0, + "nz": 32 +} \ No newline at end of file diff --git a/hparams/kth/ours_savp/model_hparams.json b/hparams/kth/ours_savp/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..66b41f87e3c0f417b492314060121a0bfd01c8f9 --- /dev/null +++ b/hparams/kth/ours_savp/model_hparams.json @@ -0,0 +1,18 @@ +{ + "batch_size": 8, + "lr": 0.0002, + "beta1": 0.5, + "beta2": 0.999, + "l1_weight": 100.0, + "l2_weight": 0.0, + "kl_weight": 0.01, + "video_sn_vae_gan_weight": 0.1, + "video_sn_gan_weight": 0.1, + "vae_gan_feature_cdist_weight": 10.0, + "gan_feature_cdist_weight": 0.0, + "state_weight": 0.0, + "nz": 
32, + "max_steps":20 +} + + diff --git a/hparams/kth/ours_vae_l1/model_hparams.json b/hparams/kth/ours_vae_l1/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..dee3ce9f8e431d7f7cb46042936cfae3dcfbc6e4 --- /dev/null +++ b/hparams/kth/ours_vae_l1/model_hparams.json @@ -0,0 +1,13 @@ +{ + "batch_size": 32, + "lr": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "l1_weight": 1.0, + "l2_weight": 0.0, + "kl_weight": 1e-05, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.0, + "state_weight": 0.0, + "nz": 32 +} \ No newline at end of file diff --git a/lpips-tensorflow/.gitignore b/lpips-tensorflow/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..894a44cc066a027465cd26d634948d56d13af9af --- /dev/null +++ b/lpips-tensorflow/.gitignore @@ -0,0 +1,104 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/lpips-tensorflow/.gitmodules b/lpips-tensorflow/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..085c5852ff85afe688333807a8a26392d20e8ed3 --- /dev/null +++ b/lpips-tensorflow/.gitmodules @@ -0,0 +1,3 @@ +[submodule "PerceptualSimilarity"] + path = PerceptualSimilarity + url = https://github.com/alexlee-gk/PerceptualSimilarity.git diff --git a/lpips-tensorflow/LICENSE b/lpips-tensorflow/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..afbffa507fd832d614440dd75f81c8ed731ded66 --- /dev/null +++ b/lpips-tensorflow/LICENSE @@ -0,0 +1,25 @@ +BSD 2-Clause License + +Copyright (c) 2018, alexlee-gk +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/lpips-tensorflow/README.md b/lpips-tensorflow/README.md new file mode 100644 index 0000000000000000000000000000000000000000..760f5a028e2aae7187135c267edca89db464db5f --- /dev/null +++ b/lpips-tensorflow/README.md @@ -0,0 +1,57 @@ +# lpips-tensorflow +Tensorflow port for the [PyTorch](https://github.com/richzhang/PerceptualSimilarity) implementation of the [Learned Perceptual Image Patch Similarity (LPIPS)](http://richzhang.github.io/PerceptualSimilarity/) metric. +This is done by exporting the model from PyTorch to ONNX and then to TensorFlow. + +## Getting started +### Installation +- Clone this repo. +```bash +git clone https://github.com/alexlee-gk/lpips-tensorflow.git +cd lpips-tensorflow +``` +- Install TensorFlow and dependencies from http://tensorflow.org/ +- Install other dependencies. +```bash +pip install -r requirements.txt +``` + +### Using the LPIPS metric +The `lpips` TensorFlow function works with individual images or batches of images. +It also works with images of any spatial dimensions (but the dimensions should be at least the size of the network's receptive field). +This example computes the LPIPS distance between batches of images. +```python +import numpy as np +import tensorflow as tf +import lpips_tf + +batch_size = 32 +image_shape = (batch_size, 64, 64, 3) +image0 = np.random.random(image_shape) +image1 = np.random.random(image_shape) +image0_ph = tf.placeholder(tf.float32) +image1_ph = tf.placeholder(tf.float32) + +distance_t = lpips_tf.lpips(image0_ph, image1_ph, model='net-lin', net='alex') + +with tf.Session() as session: + distance = session.run(distance_t, feed_dict={image0_ph: image0, image1_ph: image1}) +``` + +## Exporting additional models +### Export PyTorch model to TensorFlow through ONNX +- Clone the PerceptualSimilarity submodule and add it to the PYTHONPATH. +```bash +git submodule update --init --recursive +export PYTHONPATH=PerceptualSimilarity:$PYTHONPATH +``` +- Install more dependencies. +```bash +pip install -r requirements-dev.txt +``` +- Export the model to ONNX *.onnx and TensorFlow *.pb files in the `models` directory. +```bash +python export_to_tensorflow.py --model net-lin --net alex +``` + +### Known issues +- The SqueezeNet model cannot be exported since ONNX cannot export one of the operators. 
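### Example: single images

The `lpips` function described above also accepts individual images without a leading batch dimension. The following minimal sketch (the random inputs and the 64x64 size are illustrative only) returns a scalar distance in that case:

```python
import numpy as np
import tensorflow as tf
import lpips_tf

# two single RGB images in [0, 1] with shape (height, width, channels), i.e. no batch dimension
image0 = np.random.random((64, 64, 3))
image1 = np.random.random((64, 64, 3))

image0_ph = tf.placeholder(tf.float32)
image1_ph = tf.placeholder(tf.float32)
distance_t = lpips_tf.lpips(image0_ph, image1_ph, model='net-lin', net='alex')

with tf.Session() as session:
    # the result is a scalar because the inputs have no leading batch dimensions
    distance = session.run(distance_t, feed_dict={image0_ph: image0, image1_ph: image1})
    print(distance)
```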
diff --git a/lpips-tensorflow/export_to_tensorflow.py b/lpips-tensorflow/export_to_tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..32681117385903e8e7bfd19d4a7154a99a7ce78f --- /dev/null +++ b/lpips-tensorflow/export_to_tensorflow.py @@ -0,0 +1,58 @@ +import argparse +import os + +import onnx +import torch +import torch.onnx + +from models import dist_model as dm + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--model', choices=['net-lin', 'net'], default='net-lin', help='net-lin or net') + parser.add_argument('--net', choices=['squeeze', 'alex', 'vgg'], default='alex', help='squeeze, alex, or vgg') + parser.add_argument('--version', type=str, default='0.1') + parser.add_argument('--image_height', type=int, default=64) + parser.add_argument('--image_width', type=int, default=64) + args = parser.parse_args() + + model = dm.DistModel() + model.initialize(model=args.model, net=args.net, use_gpu=False, version=args.version) + print('Model [%s] initialized' % model.name()) + + dummy_im0 = torch.Tensor(1, 3, args.image_height, args.image_width) # image should be RGB, normalized to [-1, 1] + dummy_im1 = torch.Tensor(1, 3, args.image_height, args.image_width) + + cache_dir = os.path.expanduser('~/.lpips') + os.makedirs(cache_dir, exist_ok=True) + onnx_fname = os.path.join(cache_dir, '%s_%s_v%s.onnx' % (args.model, args.net, args.version)) + + # export model to onnx format + torch.onnx.export(model.net, (dummy_im0, dummy_im1), onnx_fname, verbose=True) + + # load and change dimensions to be dynamic + model = onnx.load(onnx_fname) + for dim in (0, 2, 3): + model.graph.input[0].type.tensor_type.shape.dim[dim].dim_param = '?' + model.graph.input[1].type.tensor_type.shape.dim[dim].dim_param = '?' + + # needs to be imported after all the pytorch stuff, otherwise this causes a segfault + from onnx_tf.backend import prepare + tf_rep = prepare(model, device='CPU') + producer_version = tf_rep.graph.graph_def_versions.producer + pb_fname = os.path.join(cache_dir, '%s_%s_v%s_%d.pb' % (args.model, args.net, args.version, producer_version)) + tf_rep.export_graph(pb_fname) + input0_name, input1_name = [tf_rep.tensor_dict[input_name].name for input_name in tf_rep.inputs] + (output_name,) = [tf_rep.tensor_dict[output_name].name for output_name in tf_rep.outputs] + + # ensure these are the names of the 2 inputs, since that will be assumed when loading the pb file + assert input0_name == '0:0' + assert input1_name == '1:0' + # ensure that the only output is the output of the last op in the graph, since that will be assumed later + (last_output_name,) = [output.name for output in tf_rep.graph.get_operations()[-1].outputs] + assert output_name == last_output_name + + +if __name__ == '__main__': + main() diff --git a/lpips-tensorflow/lpips_tf.py b/lpips-tensorflow/lpips_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..5c47f90b68f3fa05b8b2f29cf6100b6d63e584aa --- /dev/null +++ b/lpips-tensorflow/lpips_tf.py @@ -0,0 +1,90 @@ +import os +import sys + +import tensorflow as tf +from six.moves import urllib + +_URL = 'http://rail.eecs.berkeley.edu/models/lpips' + + +def _download(url, output_dir): + """Downloads the `url` file into `output_dir`. 
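+    A simple textual progress indicator is written to stdout while the file is being fetched.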
+ + Modified from https://github.com/tensorflow/models/blob/master/research/slim/datasets/dataset_utils.py + """ + filename = url.split('/')[-1] + filepath = os.path.join(output_dir, filename) + + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % ( + filename, float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + + filepath, _ = urllib.request.urlretrieve(url, filepath, _progress) + print() + statinfo = os.stat(filepath) + print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') + + +def lpips(input0, input1, model='net-lin', net='alex', version=0.1): + """ + Learned Perceptual Image Patch Similarity (LPIPS) metric. + + Args: + input0: An image tensor of shape `[..., height, width, channels]`, + with values in [0, 1]. + input1: An image tensor of shape `[..., height, width, channels]`, + with values in [0, 1]. + + Returns: + The Learned Perceptual Image Patch Similarity (LPIPS) distance. + + Reference: + Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang. + The Unreasonable Effectiveness of Deep Features as a Perceptual Metric. + In CVPR, 2018. + """ + # flatten the leading dimensions + batch_shape = tf.shape(input0)[:-3] + input0 = tf.reshape(input0, tf.concat([[-1], tf.shape(input0)[-3:]], axis=0)) + input1 = tf.reshape(input1, tf.concat([[-1], tf.shape(input1)[-3:]], axis=0)) + # NHWC to NCHW + input0 = tf.transpose(input0, [0, 3, 1, 2]) + input1 = tf.transpose(input1, [0, 3, 1, 2]) + # normalize to [-1, 1] + input0 = input0 * 2.0 - 1.0 + input1 = input1 * 2.0 - 1.0 + + input0_name, input1_name = '0:0', '1:0' + + default_graph = tf.get_default_graph() + producer_version = default_graph.graph_def_versions.producer + + cache_dir = os.path.expanduser('~/.lpips') + os.makedirs(cache_dir, exist_ok=True) + # files to try. try a specific producer version, but fallback to the version-less version (latest). 
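+    # the frozen *.pb graphs are downloaded from _URL into the ~/.lpips cache directory on first use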
+ pb_fnames = [ + '%s_%s_v%s_%d.pb' % (model, net, version, producer_version), + '%s_%s_v%s.pb' % (model, net, version), + ] + for pb_fname in pb_fnames: + if not os.path.isfile(os.path.join(cache_dir, pb_fname)): + try: + _download(os.path.join(_URL, pb_fname), cache_dir) + except urllib.error.HTTPError: + pass + if os.path.isfile(os.path.join(cache_dir, pb_fname)): + break + + with open(os.path.join(cache_dir, pb_fname), 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + _ = tf.import_graph_def(graph_def, + input_map={input0_name: input0, input1_name: input1}) + distance, = default_graph.get_operations()[-1].outputs + + if distance.shape.ndims == 4: + distance = tf.squeeze(distance, axis=[-3, -2, -1]) + # reshape the leading dimensions + distance = tf.reshape(distance, batch_shape) + return distance diff --git a/lpips-tensorflow/requirements-dev.txt b/lpips-tensorflow/requirements-dev.txt new file mode 100644 index 0000000000000000000000000000000000000000..df36766f1d4cc3b378d7aba0164f3fdfab3be1d2 --- /dev/null +++ b/lpips-tensorflow/requirements-dev.txt @@ -0,0 +1,4 @@ +torch>=0.4.0 +torchvision>=0.2.1 +onnx +onnx-tf diff --git a/lpips-tensorflow/requirements.txt b/lpips-tensorflow/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..bc2cbbd1ca22aca38c3fd05d94d07fbf60ed4f6d --- /dev/null +++ b/lpips-tensorflow/requirements.txt @@ -0,0 +1,2 @@ +numpy +six diff --git a/lpips-tensorflow/setup.py b/lpips-tensorflow/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..9fc8d1d35b0b9249c7c0e24dd14efc707b873d92 --- /dev/null +++ b/lpips-tensorflow/setup.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python + +from distutils.core import setup + +setup( + name='lpips-tf', + description='Tensorflow port for the Learned Perceptual Image Patch Similarity (LPIPS) metric', + author='Alex Lee', + url='https://github.com/alexlee-gk/lpips-tensorflow/', + py_modules=['lpips_tf'] +) diff --git a/lpips-tensorflow/test_network.py b/lpips-tensorflow/test_network.py new file mode 100644 index 0000000000000000000000000000000000000000..c222ab931807b207a5ea26d2a7297429dcb7a02b --- /dev/null +++ b/lpips-tensorflow/test_network.py @@ -0,0 +1,42 @@ +import argparse + +import cv2 +import numpy as np +import tensorflow as tf + +import lpips_tf + + +def load_image(fname): + image = cv2.imread(fname) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image.astype(np.float32) / 255.0 + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--model', choices=['net-lin', 'net'], default='net-lin', help='net-lin or net') + parser.add_argument('--net', choices=['squeeze', 'alex', 'vgg'], default='alex', help='squeeze, alex, or vgg') + parser.add_argument('--version', type=str, default='0.1') + args = parser.parse_args() + + ex_ref = load_image('./PerceptualSimilarity/imgs/ex_ref.png') + ex_p0 = load_image('./PerceptualSimilarity/imgs/ex_p0.png') + ex_p1 = load_image('./PerceptualSimilarity/imgs/ex_p1.png') + + session = tf.Session() + + image0_ph = tf.placeholder(tf.float32) + image1_ph = tf.placeholder(tf.float32) + lpips_fn = session.make_callable( + lpips_tf.lpips(image0_ph, image1_ph, model=args.model, net=args.net, version=args.version), + [image0_ph, image1_ph]) + + ex_d0 = lpips_fn(ex_ref, ex_p0) + ex_d1 = lpips_fn(ex_ref, ex_p1) + + print('Distances: (%.3f, %.3f)' % (ex_d0, ex_d1)) + + +if __name__ == '__main__': + main() diff --git a/metadata.py b/metadata.py new file mode 100644 index 
0000000000000000000000000000000000000000..b65db76d71a549b3405f9f10b26dcccb09b30230 --- /dev/null +++ b/metadata.py @@ -0,0 +1,330 @@ +""" +Classes and routines to retrieve and handle meta-data +""" + +import os +import sys +import numpy as np +import json +from netCDF4 import Dataset + +class MetaData: + """ + Class for handling, storing and retrieving meta-data + """ + + def __init__(self,json_file=None,suffix_indir=None,data_filename=None,slices=None,variables=None): + + """ + Initailizes MetaData instance by reading a corresponding json-file or by handling arguments of the Preprocessing step + (i.e. exemplary input file, slices defining region of interest, input variables) + """ + + method_name = MetaData.__init__.__name__+" of Class "+MetaData.__name__ + + if not json_file is None: + MetaData.get_metadata_from_file(json_file) + + else: + # No dictionary from json-file available, all other arguments have to set + if not suffix_indir: + raise TypeError(method_name+": 'suffix_indir'-argument is required if 'json_file' is not passed.") + else: + if not isinstance(suffix_indir,str): + raise TypeError(method_name+": 'suffix_indir'-argument must be a string.") + + if not data_filename: + raise TypeError(method_name+": 'data_filename'-argument is required if 'json_file' is not passed.") + else: + if not isinstance(data_filename,str): + raise TypeError(method_name+": 'data_filename'-argument must be a string.") + + if not slices: + raise TypeError(method_name+": 'slices'-argument is required if 'json_file' is not passed.") + else: + if not isinstance(slices,dict): + raise TypeError(method_name+": 'slices'-argument must be a dictionary.") + + if not variables: + raise TypeError(method_name+": 'variables'-argument is required if 'json_file' is not passed.") + else: + if not isinstance(variables,list): + raise TypeError(method_name+": 'variables'-argument must be a list.") + + MetaData.get_and_set_metadata_from_file(self,suffix_indir,data_filename,slices,variables) + + MetaData.write_metadata_to_file(self) + + + def get_and_set_metadata_from_file(self,suffix_indir,datafile_name,slices,variables): + """ + Retrieves several meta data from netCDF-file and sets corresponding class instance attributes. + Besides, the name of the experiment directory is constructed following the naming convention (see below) + + Naming convention: + [model_base]_Y[yyyy]to[yyyy]M[mm]to[mm]-[nx]x[ny]-[nnnn]N[eeee]E-[var1]_[var2]_(...)_[varN] + ---------------- Given ----------------|---------------- Created dynamically -------------- + + Note that the model-base as well as the date-identifiers must already be included in target_dir_in. 
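+        An experiment directory name of this general form used elsewhere in this repository is, e.g., era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T (cf. the results paths in scripts/Analysis_all.py).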
+ """ + + method_name = MetaData.get_and_set_metadata_from_file.__name__+" of Class "+MetaData.__name__ + + if not suffix_indir: raise ValueError(method_name+": suffix_indir must be a non-empty path.") + + # retrieve required information from file + flag_coords = ["N", "E"] + + print("Retrieve metadata based on file: '"+datafile_name+"'") + try: + datafile = Dataset(datafile_name,'r') + except: + print(method_name + ": Error when handling data file: '"+datafile_name+"'.") + exit() + + # Check if all requested variables can be obtained from datafile + MetaData.check_datafile(datafile,variables) + self.varnames = variables + + + self.nx, self.ny = np.abs(slices['lon_e'] - slices['lon_s']), np.abs(slices['lat_e'] - slices['lat_s']) + sw_c = [float(datafile.variables['lat'][slices['lat_e']-1]),float(datafile.variables['lon'][slices['lon_s']])] # meridional axis lat is oriented from north to south (i.e. monotonically decreasing) + self.sw_c = sw_c + + # Now start constructing exp_dir-string + # switch sign and coordinate-flags to avoid negative values appearing in exp_dir-name + if sw_c[0] < 0.: + sw_c[0] = np.abs(sw_c[0]) + flag_coords[0] = "S" + if sw_c[1] < 0.: + sw_c[1] = np.abs(sw_c[1]) + flag_coords[1] = "W" + nvar = len(variables) + + # splitting has to be done in order to retrieve the expname-suffix (and the year if required) + path_parts = os.path.split(suffix_indir.rstrip("/")) + + if (is_integer(path_parts[1])): + year = path_parts[1] + path_parts = os.path.split(path_parts[0].rstrip("/")) + else: + year = "" + + expdir, expname = path_parts[0], path_parts[1] + + # extend exp_dir_in successively (splitted up for better readability) + expname += "-"+str(self.nx) + "x" + str(self.ny) + expname += "-"+(("{0: 05.2f}"+flag_coords[0]+"{1:05.2f}"+flag_coords[1]).format(*sw_c)).strip().replace(".","")+"-" + + # reduced for-loop length as last variable-name is not followed by an underscore (see above) + for i in range(nvar-1): + expname += variables[i]+"_" + expname += variables[nvar-1] + + self.expname = expname + self.expdir = expdir + self.status = "" # uninitialized (is set when metadata is written/compared to/with json-file, see write_metadata_to_file-method) + + # ML 2020/04/24 E + + def write_metadata_to_file(self,dest_dir = None): + + """ + Write meta data attributes of class instance to json-file. 
+ """ + + method_name = MetaData.write_metadata_to_file.__name__+" of Class "+MetaData.__name__ + # actual work: + meta_dict = {"expname": self.expname, + "expdir" : self.expdir} + + meta_dict["sw_corner_frame"] = { + "lat" : self.sw_c[0], + "lon" : self.sw_c[1] + } + + meta_dict["frame_size"] = { + "nx" : int(self.nx), + "ny" : int(self.ny) + } + + meta_dict["variables"] = [] + for i in range(len(self.varnames)): + print(self.varnames[i]) + meta_dict["variables"].append( + {"var"+str(i+1) : self.varnames[i]}) + + # create directory if required + if dest_dir is None: + dest_dir = os.path.join(self.expdir,self.expname) + if not os.path.exists(dest_dir): + print("Created experiment directory: '"+self.expdir+"'") + os.makedirs(dest_dir,exist_ok=True) + + meta_fname = os.path.join(dest_dir,"metadata.json") + + if os.path.exists(meta_fname): # check if a metadata-file already exists and check its content + self.status = "old" # set status to old in order to prevent repeated modification of shell-/Batch-scripts + with open(meta_fname,'r') as js_file: + dict_dupl = json.load(js_file) + + if dict_dupl != meta_dict: + print(method_name+": Already existing metadata (see '"+meta_fname+") do not fit data being processed right now. Ensure a common data base.") + sys.exit(1) + else: #do not need to do anything + pass + else: + # write dictionary to file + print(method_name+": Write dictionary to json-file: '"+meta_fname+"'") + with open(meta_fname,'w') as js_file: + json.dump(meta_dict,js_file) + self.status = "new" # set status to new in order to trigger modification of shell-/Batch-scripts + + def get_metadata_from_file(self,js_file): + + """ + Retrieves meta data attributes from json-file + """ + + with open(js_file) as js_file: + dict_in = json.load(js_file) + + self.exp_dir = dict_in["exp_dir"] + + self.sw_c = [dict_in["sw_corner_frame"]["lat"],dict_in["sw_corner_frame"]["lon"] ] + + self.nx = dict_in["frame_size"]["nx"] + self.ny = dict_in["frame_size"]["ny"] + + self.variables = [dict_in["variables"][ivar] for ivar in dict_in["variables"].keys()] + + + def write_dirs_to_batch_scripts(self,batch_script): + + """ + Expands ('known') directory-variables in batch_script by exp_dir-attribute of class instance + """ + + paths_to_mod = ["source_dir=","destination_dir=","checkpoint_dir=","results_dir="] # known directory-variables in batch-scripts + + with open(batch_script,'r') as file: + data = file.readlines() + + nlines = len(data) + matched_lines = [iline for iline in range(nlines) if any(str_id in data[iline] for str_id in paths_to_mod)] # list of line-number indices to be modified + + for i in matched_lines: + data[i] = add_str_to_path(data[i],self.expname) + + + with open(batch_script,'w') as file: + file.writelines(data) + + @staticmethod + def write_destdir_jsontmp(dest_dir, tmp_dir = None): + """ + Writes dest_dir to temporary json-file (temp.json) stored in the current working directory. + """ + + if not tmp_dir: tmp_dir = os.getcwd() + + file_tmp = os.path.join(tmp_dir,"temp.json") + dict_tmp = {"destination_dir": dest_dir} + + with open(file_tmp,"w") as js_file: + print("Save destination_dir-variable in temporary json-file: '"+file_tmp+"'") + json.dump(dict_tmp,js_file) + + @staticmethod + def get_destdir_jsontmp(tmp_dir = None): + """ + Retrieves dest_dir from temporary json-file which is expected to exist in the current working directory and returns it. 
+ """ + + method_name = MetaData.get_destdir_jsontmp.__name__+" of Class "+MetaData.__name__ + + if not tmp_dir: tmp_dir = os.getcwd() + + file_tmp = os.path.join(tmp_dir,"temp.json") + + try: + with open(file_tmp,"r") as js_file: + dict_tmp = json.load(js_file) + except: + print(method_name+": Could not open requested json-file '"+file_tmp+"'") + sys.exit(1) + + if not "destination_dir" in dict_tmp.keys(): + raise Exception(method_name+": Could not find 'destination_dir' in dictionary obtained from "+file_tmp) + else: + return(dict_tmp.get("destination_dir")) + + + @staticmethod + def issubset(a,b): + """ + Checks if all elements of a exist in b or vice versa (depends on the length of the corresponding lists/sets) + """ + + if len(a) > len(b): + return(set(b).issubset(set(a))) + elif len(b) >= len(a): + return(set(a).issubset(set(b))) + + @staticmethod + def check_datafile(datafile,varnames): + """ + Checks if all varnames can be found in datafile + """ + + if not MetaData.issubset(varnames,datafile.variables.keys()): + for i in range(len(varnames2check)): + if not varnames2check[i] in f0.variables.keys(): + print("Variable '"+varnames2check[i]+"' not found in datafile '"+data_filenames[0]+"'.") + raise ValueError("Could not find the above mentioned variables.") + else: + pass + + + +# ----------------------------------- end of class MetaData ----------------------------------- + +# some auxilary functions which are not bound to MetaData-class + +def add_str_to_path(path_in,add_str): + + """ + Adds add_str to path_in if path_in does not already end with add_str. + Function is also capable to handle carriage returns for handling input-strings obtained by reading a file. + """ + + l_linebreak = path_in.endswith("\n") # flag for carriage return at the end of input string + line_str = path_in.rstrip("\n") + + if (not line_str.endswith(add_str)) or \ + (not line_str.endswith(add_str.rstrip("/"))): + + line_str = line_str + add_str + "/" + else: + print(add_str+" is already part of "+line_str+". No change is performed.") + + if l_linebreak: # re-add carriage return to string if required + return(line_str+"\n") + else: + return(line_str) + + +def is_integer(n): + ''' + Checks if input string is numeric and of type integer. 
+ ''' + try: + float(n) + except ValueError: + return False + else: + return float(n).is_integer() + + + + diff --git a/pretrained_models/download_model.sh b/pretrained_models/download_model.sh new file mode 100644 index 0000000000000000000000000000000000000000..fdffd762b334f709edc9369b1d7c69268b2c43b4 --- /dev/null +++ b/pretrained_models/download_model.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash + +# exit if any command fails +set -e + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 DATASET_NAME MODEL_NAME" >&2 + exit 1 +fi +DATASET_NAME=$1 +MODEL_NAME=$2 + +declare -A model_name_to_fname +if [ ${DATASET_NAME} = "bair_action_free" ]; then + model_name_to_fname=( + [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2 + [ours_gan]=${DATASET_NAME}_ours_gan + [ours_savp]=${DATASET_NAME}_ours_savp + [ours_vae]=${DATASET_NAME}_ours_vae_l1 + [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1 + [ours_vae_l2]=${DATASET_NAME}_ours_vae_l2 + [sv2p_time_invariant]=${DATASET_NAME}_sv2p_time_invariant + ) +elif [ ${DATASET_NAME} = "kth" ]; then + model_name_to_fname=( + [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2 + [ours_gan]=${DATASET_NAME}_ours_gan + [ours_savp]=${DATASET_NAME}_ours_savp + [ours_vae]=${DATASET_NAME}_ours_vae_l1 + [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1 + [sv2p_time_invariant]=${DATASET_NAME}_sv2p_time_invariant + [sv2p_time_variant]=${DATASET_NAME}_sv2p_time_variant + ) +elif [ ${DATASET_NAME} = "bair" ]; then + model_name_to_fname=( + [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2 + [ours_gan]=${DATASET_NAME}_ours_gan + [ours_savp]=${DATASET_NAME}_ours_savp + [ours_vae]=${DATASET_NAME}_ours_vae_l1 + [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1 + [ours_vae_l2]=${DATASET_NAME}_ours_vae_l2 + [sna_l1]=${DATASET_NAME}_sna_l1 + [sna_l2]=${DATASET_NAME}_sna_l2 + [sv2p_time_variant]=${DATASET_NAME}_sv2p_time_variant + ) +else + echo "Invalid dataset name: '${DATASET_NAME}' (choose from 'bair_action_free', 'kth', 'bair)" >&2 + exit 1 +fi + +if ! [[ ${model_name_to_fname[${MODEL_NAME}]} ]]; then + echo "Invalid model name '${MODEL_NAME}' when dataset name is '${DATASET_NAME}'. 
Valid mode names are:" >&2 + for model_name in "${!model_name_to_fname[@]}"; do + echo "'${model_name}'" >&2 + done + exit 1 +fi +TARGET_DIR=./pretrained_models/${DATASET_NAME}/${MODEL_NAME} +mkdir -p ${TARGET_DIR} +TAR_FNAME=${model_name_to_fname[${MODEL_NAME}]}.tar.gz +URL=http://rail.eecs.berkeley.edu/models/savp/pretrained_models/${TAR_FNAME} +echo "Downloading '${TAR_FNAME}'" +wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME} +tar -xvf ${TARGET_DIR}/${TAR_FNAME} -C ${TARGET_DIR} +rm ${TARGET_DIR}/${TAR_FNAME} + +echo "Succesfully finished downloading pretrained model '${MODEL_NAME}' on dataset '${DATASET_NAME}' into directory ${TARGET_DIR}" diff --git a/scripts/Analysis_all.py b/scripts/Analysis_all.py new file mode 100644 index 0000000000000000000000000000000000000000..7ed4b6b666634abafbd18af8517cf5b81b75bd61 --- /dev/null +++ b/scripts/Analysis_all.py @@ -0,0 +1,105 @@ +import pickle +import os +from matplotlib.pylab import plt + +# results_path = ["results_test_samples/era5_size_64_64_3_norm_dup_pretrained/ours_savp","results_test_samples/era5_size_64_64_3_norm_dup_pretrained_finetune/ours_savp", +# "results_test_samples/era5_size_64_64_3_norm_dup_pretrained_gan/kth_ours_gan","results_test_samples/era5_size_64_64_3_norm_dup_pretrained_vae_l1/kth_ours_vae_l1"] +# +# model_names = ["SAVP","SAVP_Finetune","GAN","VAE"] + + +# results_path = ["results_test_samples/era5_size_64_64_3_norm_dup_pretrained/ours_savp","results_test_samples/era5_size_64_64_3_norm_msl_gph_pretrained_savp/ours_savp", +# "results_test_samples/era5_size_64_64_3_norm_dup_pretrained_gan/kth_ours_gan","results_test_samples/era5_size_64_64_3_norm_msl_gph_pretrained_gan/kth_ours_gan"] +# +# model_names = ["SAVP_3T","SAVP_T-MSL-GPH","GAN_3T","GAN_T-MSL_GPH"] +# +# results_path = ["results_test_samples/era5_size_64_64_3_norm_dup_pretrained/ours_savp","results_test_samples/era5_size_64_64_3_norm_dup/ours_savp", +# "results_test_samples/era5_size_64_64_3_norm_dup_pretrained/kth_ours_gan","results_test_samples/era5_size_64_64_3_norm_dup/ours_gan", +# "results_test_samples/era5_size_64_64_3_norm_dup_pretrained/kth_ours_vae_l1","results_test_samples/era5_size_64_64_3_norm_dup/ours_vae_l1"] +# model_names = ["TF-SAVP(KTH)","SAVP (3T)","TF-GAN(KTH)","GAN (3T)","TF-VAE (KTH)","VAE (3T)"] + +## +##results_path = ["results_test_samples/era5_size_64_64_3_norm_t_msl_gph/ours_savp", "results_test_samples/era5_size_64_64_3_norm_dup/ours_savp", +## "results_test_samples/era5_size_64_64_3_norm_t_msl_gph/ours_gan","results_test_samples/era5_size_64_64_3_norm_dup/ours_gan"] +##model_names = ["SAVP(T-MSL-GPH)", "SAVP (3T)", "GAN (T-MSL-GPH)","GAN (3T)"] + +##results_path = ["results_test_samples/era5_size_64_64_3_norm_t_msl_gph/ours_savp", "results_test_samples/era5_size_64_64_3_norm_dup/ours_savp", +## "results_test_samples/era5_size_64_64_3_norm_t_msl_gph/ours_gan","results_test_samples/era5_size_64_64_3_norm_dup/ours_gan"] +##model_names = ["SAVP(T-MSL-GPH)", "SAVP (3T)", "GAN (T-MSL-GPH)","GAN (3T)"] +## +##mse_all = [] +##psnr_all = [] +##ssim_all = [] +##for path in results_path: +## p = os.path.join(path,"results.pkl") +## result = pickle.load(open(p,"rb")) +## mse = result["mse"] +## psnr = result["psnr"] +## ssim = result["ssim"] +## mse_all.append(mse) +## psnr_all.append(psnr) +## ssim_all.append(ssim) +## +## +##def get_metric(metrtic): +## if metric == "mse": +## return mse_all +## elif metric == "psnr": +## return psnr_all +## elif metric == "ssim": +## return ssim_all +## else: +## raise("Metric error") +## +##for metric in 
["mse","psnr","ssim"]: +## evals = get_metric(metric) +## timestamp = list(range(1,11)) +## fig = plt.figure() +## plt.plot(timestamp, evals[0],'-.',label=model_names[0]) +## plt.plot(timestamp, evals[1],'--',label=model_names[1]) +## plt.plot(timestamp, evals[2],'-',label=model_names[2]) +## plt.plot(timestamp, evals[3],'--.',label=model_names[3]) +## # plt.plot(timestamp, evals[4],'*-.',label=model_names[4]) +## # plt.plot(timestamp, evals[5],'--*',label=model_names[5]) +## if metric == "mse": +## plt.legend(loc="upper left") +## else: +## plt.legend(loc = "upper right") +## plt.xlabel("Timestamps") +## plt.ylabel(metric) +## plt.title(metric,fontsize=15) +## plt.savefig(metric + "2.png") +## plt.clf() + + + +#persistent analysis +persistent_mse_all = [] +persistent_psnr_all = [] +persistent_ssim_all = [] +mse_all = [] +psnr_all = [] +ssim_all = [] +results_root_path = "/p/scratch/deepacf/video_prediction_shared_folder/results/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/ours_gan" +p1 = os.path.join(results_root_path,"results.pkl") +result1 = pickle.load(open(p1,"rb")) +p2 = os.path.join(results_root_path,"persistent_results.pkl") +result2 = pickle.load(open(p2,"rb")) +mse = result1["mse"] +psnr = result1["psnr"] +ssim = result1["ssim"] +mse_all.append(mse) +psnr_all.append(psnr) +ssim_all.append(ssim) + +persistent_mse = result2["mse"] +persistent_psnr = result2["psnr"] +persistent_ssim = result2["ssim"] +persistent_mse_all.append(persistent_mse) +persistent_psnr_all.append(persistent_psnr) +persistent_ssim_all.append(persistent_ssim) + + + +print("persistent_mse",persistent_mse_all) +print("mse",mse_all) diff --git a/scripts/combine_results.py b/scripts/combine_results.py new file mode 100644 index 0000000000000000000000000000000000000000..2b684ac1212350c8d0c630884baeb34154dd1fa3 --- /dev/null +++ b/scripts/combine_results.py @@ -0,0 +1,258 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import glob +import itertools +import os + +import cv2 +import numpy as np + +from video_prediction.utils import html +from video_prediction.utils.ffmpeg_gif import save_gif as ffmpeg_save_gif + + +def load_metrics(prefix_fname): + import csv + with open('%s.csv' % prefix_fname, newline='') as csvfile: + reader = csv.reader(csvfile, delimiter='\t', quotechar='|') + rows = list(reader) + # skip header (first row), indices (first column), and means (last column) + metrics = np.array(rows)[1:, 1:-1].astype(np.float32) + return metrics + + +def load_images(image_fnames): + images = [] + for image_fname in image_fnames: + image = cv2.imread(image_fname) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + images.append(image) + return images + + +def save_images(image_fnames, images): + head, tail = os.path.split(image_fnames[0]) + if head and not os.path.exists(head): + os.makedirs(head) + for image_fname, image in zip(image_fnames, images): + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + cv2.imwrite(image_fname, image) + + +def save_gif(gif_fname, images, fps=4): + import moviepy.editor as mpy + head, tail = os.path.split(gif_fname) + if head and not os.path.exists(head): + os.makedirs(head) + clip = mpy.ImageSequenceClip(list(images), fps=fps) + clip.write_gif(gif_fname) + + +def concat_images(all_images): + """ + all_images is a list of lists of images + """ + min_height, min_width = None, None + for all_image in all_images: + for image in all_image: + if min_height is None or min_width is None: + min_height, min_width 
= image.shape[:2] + else: + min_height = min(min_height, image.shape[0]) + min_width = min(min_width, image.shape[1]) + + def maybe_resize(image): + if image.shape[:2] != (min_height, min_width): + image = cv2.resize(image, (min_height, min_width)) + return image + + resized_all_images = [] + for all_image in all_images: + resized_all_image = [maybe_resize(image) for image in all_image] + resized_all_images.append(resized_all_image) + all_images = resized_all_images + all_images = [np.concatenate(all_image, axis=1) for all_image in zip(*all_images)] + return all_images + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("results_dir", type=str) + parser.add_argument("--method_dirs", type=str, nargs='+', help='directories in results_dir (all of them by default)') + parser.add_argument("--method_names", type=str, nargs='+', help='method names for the header') + parser.add_argument("--web_dir", type=str, help='default is results_dir/web') + parser.add_argument("--sort_by", type=str, nargs=2, help='task and metric name to sort by, e.g. prediction mse') + parser.add_argument("--no_ffmpeg", action='store_true') + parser.add_argument("--batch_size", type=int, default=1, help="number of samples in batch") + parser.add_argument("--num_samples", type=int, help="number of samples for the table of sequence (all of them by default)") + parser.add_argument("--show_se", action='store_true', help="show standard error in the table metrics") + parser.add_argument("--only_metrics", action='store_true') + args = parser.parse_args() + + if args.web_dir is None: + args.web_dir = os.path.join(args.results_dir, 'web') + webpage = html.HTML(args.web_dir, 'Experiment name = %s' % os.path.normpath(args.results_dir), reflesh=1) + webpage.add_header1(os.path.normpath(args.results_dir)) + + if args.method_dirs is None: + unsorted_method_dirs = os.listdir(args.results_dir) + # exclude web_dir and all directories that starts with web + if args.web_dir in unsorted_method_dirs: + unsorted_method_dirs.remove(args.web_dir) + unsorted_method_dirs = [method_dir for method_dir in unsorted_method_dirs if not os.path.basename(method_dir).startswith('web')] + # put ground_truth and repeat in the front (if any) + method_dirs = [] + for first_method_dir in ['ground_truth', 'repeat']: + if first_method_dir in unsorted_method_dirs: + unsorted_method_dirs.remove(first_method_dir) + method_dirs.append(first_method_dir) + method_dirs.extend(sorted(unsorted_method_dirs)) + else: + method_dirs = list(args.method_dirs) + if args.method_names is None: + method_names = list(method_dirs) + else: + method_names = list(args.method_names) + method_dirs = [os.path.join(args.results_dir, method_dir) for method_dir in method_dirs] + + if args.sort_by: + task_name, metric_name = args.sort_by + sort_criterion = [] + for method_id, (method_name, method_dir) in enumerate(zip(method_names, method_dirs)): + metric = load_metrics(os.path.join(method_dir, task_name, 'metrics', metric_name)) + sort_criterion.append(np.mean(metric)) + sort_criterion, method_ids, method_names, method_dirs = \ + zip(*sorted(zip(sort_criterion, range(len(method_names)), method_names, method_dirs))) + webpage.add_header3('sorted by %s, %s' % tuple(args.sort_by)) + else: + method_ids = range(len(method_names)) + + # infer task and metric names from first method + metric_fnames = sorted(glob.glob('%s/*/metrics/*.csv' % glob.escape(method_dirs[0]))) + task_names = [] + metric_names = [] + for metric_fname in metric_fnames: + head, tail = 
os.path.split(metric_fname) + task_name = head.split('/')[-2] + metric_name, _ = os.path.splitext(tail) + task_names.append(task_name) + metric_names.append(metric_name) + + # save metrics + webpage.add_table() + header_txts = [''] + header_colspans = [2] + for task_name in task_names: + if task_name != header_txts[-1]: + header_txts.append(task_name) + header_colspans.append(2 if args.show_se else 1) # mean and standard error for each task + else: + # group consecutive task names that are the same + header_colspans[-1] += 2 if args.show_se else 1 + webpage.add_row(header_txts, header_colspans) + subheader_txts = ['id', 'method'] + for task_name, metric_name in zip(task_names, metric_names): + subheader_txts.append('%s (mean)' % metric_name) + if args.show_se: + subheader_txts.append('%s (se)' % metric_name) + webpage.add_row(subheader_txts) + all_metric_means = [] + for method_id, method_name, method_dir in zip(method_ids, method_names, method_dirs): + metric_txts = [method_id, method_name] + metric_means = [] + for task_name, metric_name in zip(task_names, metric_names): + metric = load_metrics(os.path.join(method_dir, task_name, 'metrics', metric_name)) + metric_mean = np.mean(metric) + num_samples = len(metric) + metric_se = np.std(metric) / np.sqrt(num_samples) + metric_txts.append('%.4f' % metric_mean) + if args.show_se: + metric_txts.append('%.4f' % metric_se) + metric_means.append(metric_mean) + webpage.add_row(metric_txts) + all_metric_means.append(metric_means) + webpage.save() + + if args.only_metrics: + return + + # infer task names from first method + outputs_dirs = sorted(glob.glob('%s/*/outputs' % glob.escape(method_dirs[0]))) + task_names = [outputs_dir.split('/')[-2] for outputs_dir in outputs_dirs] + + # save image sequences + image_dir = os.path.join(args.web_dir, 'images') + webpage.add_table() + header_txts = [''] + subheader_txts = ['id'] + methods_subheader_txts = [''] + header_colspans = [1] + subheader_colspans = [1] + methods_subheader_colspans = [1] + num_samples = args.num_samples or num_samples + for sample_ind in range(num_samples): + if sample_ind % args.batch_size == 0: + print("saving samples from %d to %d" % (sample_ind, sample_ind + args.batch_size)) + ims = [None] + txts = [sample_ind] + links = [None] + colspans = [1] + for task_name in task_names: + # load input images from first method + input_fnames = sorted(glob.glob('%s/inputs/*_%05d_??.png' % + (glob.escape(os.path.join(method_dirs[0], task_name)), sample_ind))) + input_images = load_images(input_fnames) + # save input images as image sequence + input_fnames = [os.path.join(task_name, 'inputs', os.path.basename(input_fname)) for input_fname in input_fnames] + save_images([os.path.join(image_dir, input_fname) for input_fname in input_fnames], input_images) + # infer output names from first method + output_fnames = sorted(glob.glob('%s/outputs/*_%05d_??.png' % + (glob.escape(os.path.join(method_dirs[0], task_name)), sample_ind))) + output_names = sorted(set(os.path.splitext(os.path.basename(output_fname))[0][:-9] + for output_fname in output_fnames)) # remove _?????_??.png + # load output images + all_output_images = [] + for output_name in output_names: + for method_name, method_dir in zip(method_names, method_dirs): + output_fnames = sorted(glob.glob('%s/outputs/%s_%05d_??.png' % + (glob.escape(os.path.join(method_dir, task_name)), + output_name, sample_ind))) + output_images = load_images(output_fnames) + all_output_images.append(output_images) + # concatenate output images of all the methods + 
all_output_images = concat_images(all_output_images) + # save output images as image sequence or as gif clip + output_fname = os.path.join(task_name, 'outputs', '%s_%05d.gif' % ('_'.join(output_names), sample_ind)) + if args.no_ffmpeg: + save_gif(os.path.join(image_dir, output_fname), all_output_images, fps=4) + else: + ffmpeg_save_gif(os.path.join(image_dir, output_fname), all_output_images, fps=4) + + if sample_ind == 0: + header_txts.append(task_name) + subheader_txts.extend(['inputs', 'outputs']) + header_colspans.append(len(input_fnames) + len(method_ids) * len(output_names)) + subheader_colspans.extend([len(input_fnames), len(method_ids) * len(output_names)]) + method_id_strs = ['%02d' % method_id for method_id in method_ids] + methods_subheader_txts.extend([''] + list(itertools.chain(*[method_id_strs] * len(output_names)))) + methods_subheader_colspans.extend([len(input_fnames)] + [1] * (len(method_ids) * len(output_names))) + ims.extend(input_fnames + [output_fname]) + txts.extend([None] * (len(input_fnames) + 1)) + links.extend(input_fnames + [output_fname]) + colspans.extend([1] * len(input_fnames) + [len(method_ids) * len(output_names)]) + + if sample_ind == 0: + webpage.add_row(header_txts, header_colspans) + webpage.add_row(subheader_txts, subheader_colspans) + webpage.add_row(methods_subheader_txts, methods_subheader_colspans) + webpage.add_images(ims, txts, links, colspans, height=64, width=None) + if (sample_ind + 1) % args.batch_size == 0: + webpage.save() + webpage.save() + + +if __name__ == '__main__': + main() diff --git a/scripts/evaluate.py b/scripts/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..1f4062f7e3abe4574fa434cfbb2bc66a1b63eed3 --- /dev/null +++ b/scripts/evaluate.py @@ -0,0 +1,318 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import argparse +import csv +import errno +import json +import os +import random + +import numpy as np +import tensorflow as tf + +from video_prediction import datasets, models + + +def save_image_sequence(prefix_fname, images, time_start_ind=0): + import cv2 + head, tail = os.path.split(prefix_fname) + if head and not os.path.exists(head): + os.makedirs(head) + for t, image in enumerate(images): + image_fname = '%s_%02d.png' % (prefix_fname, time_start_ind + t) + image = (image * 255.0).astype(np.uint8) + if image.shape[-1] == 1: + image = np.tile(image, (1, 1, 3)) + else: + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + cv2.imwrite(image_fname, image) + + +def save_image_sequences(prefix_fname, images, sample_start_ind=0, time_start_ind=0): + head, tail = os.path.split(prefix_fname) + if head and not os.path.exists(head): + os.makedirs(head) + for i, images_ in enumerate(images): + images_fname = '%s_%05d' % (prefix_fname, sample_start_ind + i) + save_image_sequence(images_fname, images_, time_start_ind=time_start_ind) + + +def save_metrics(prefix_fname, metrics, sample_start_ind=0): + head, tail = os.path.split(prefix_fname) + if head and not os.path.exists(head): + os.makedirs(head) + assert metrics.ndim == 2 + file_mode = 'w' if sample_start_ind == 0 else 'a' + with open('%s.csv' % prefix_fname, file_mode, newline='') as csvfile: + writer = csv.writer(csvfile, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL) + if sample_start_ind == 0: + writer.writerow(map(str, ['sample_ind'] + list(range(metrics.shape[1])) + ['mean'])) + for i, metrics_row in enumerate(metrics): + writer.writerow(map(str, 
[sample_start_ind + i] + list(metrics_row) + [np.mean(metrics_row)])) + + +def load_metrics(prefix_fname): + with open('%s.csv' % prefix_fname, newline='') as csvfile: + reader = csv.reader(csvfile, delimiter='\t', quotechar='|') + rows = list(reader) + # skip header (first row), indices (first column), and means (last column) + metrics = np.array(rows)[1:, 1:-1].astype(np.float32) + return metrics + + +def merge_hparams(hparams0, hparams1): + hparams0 = hparams0 or [] + hparams1 = hparams1 or [] + if not isinstance(hparams0, (list, tuple)): + hparams0 = [hparams0] + if not isinstance(hparams1, (list, tuple)): + hparams1 = [hparams1] + hparams = list(hparams0) + list(hparams1) + # simplify into the content if possible + if len(hparams) == 1: + hparams, = hparams + return hparams + + +def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False, subtasks=None): + sequence_length = model_hparams.sequence_length + context_frames = model_hparams.context_frames + future_length = sequence_length - context_frames + + context_images = results['images'][:, :context_frames] + + if 'eval_diversity' in results: + metric = results['eval_diversity'] + metric_name = 'diversity' + subtask_dir = task_dir + '_%s' % metric_name + save_metrics(os.path.join(subtask_dir, 'metrics', metric_name), + metric, sample_start_ind=sample_start_ind) + + subtasks = subtasks or ['max'] + for subtask in subtasks: + metric_names = [] + for k in results.keys(): + if re.match('eval_(\w+)/%s' % subtask, k) and not re.match('eval_gen_images_(\w+)/%s' % subtask, k): + m = re.match('eval_(\w+)/%s' % subtask, k) + metric_names.append(m.group(1)) + for metric_name in metric_names: + subtask_dir = task_dir + '_%s_%s' % (metric_name, subtask) + gen_images = results.get('eval_gen_images_%s/%s' % (metric_name, subtask), results.get('eval_gen_images')) + # only keep the future frames + gen_images = gen_images[:, -future_length:] + metric = results['eval_%s/%s' % (metric_name, subtask)] + save_metrics(os.path.join(subtask_dir, 'metrics', metric_name), + metric, sample_start_ind=sample_start_ind) + if only_metrics: + continue + + save_image_sequences(os.path.join(subtask_dir, 'inputs', 'context_image'), + context_images, sample_start_ind=sample_start_ind) + save_image_sequences(os.path.join(subtask_dir, 'outputs', 'gen_image'), + gen_images, sample_start_ind=sample_start_ind) + + +def main(): + """ + results_dir + ├── output_dir # condition / method + │ ├── prediction_eval_lpips_max # task: best sample in terms of LPIPS similarity + │ │ ├── inputs + │ │ │ ├── context_image_00000_00.png # indexed by sample index and time step + │ │ │ └── ... + │ │ ├── outputs + │ │ │ ├── gen_image_00000_00.png # predicted images (only the future ones) + │ │ │ └── ... + │ │ └── metrics + │ │ └── lpips.csv + │ ├── prediction_eval_ssim_max # task: best sample in terms of SSIM + │ │ ├── inputs + │ │ │ ├── context_image_00000_00.png # indexed by sample index and time step + │ │ │ └── ... + │ │ ├── outputs + │ │ │ ├── gen_image_00000_00.png # predicted images (only the future ones) + │ │ │ └── ... + │ │ └── metrics + │ │ └── ssim.csv + │ └── ... + └── ... 
+ """ + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " + "train, val, test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_dir is specified") + parser.add_argument("--output_dir", help="output directory where results are saved. default is results_dir/model_fname, " + "where model_fname is the directory name of checkpoint") + parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)") + + parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.') + + parser.add_argument("--dataset", type=str, help="dataset class name") + parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters") + parser.add_argument("--model", type=str, help="model class name") + parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") + + parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch") + parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)") + parser.add_argument("--num_epochs", type=int, default=1) + + parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'avg', 'min'], help='subtasks to evaluate (e.g. max, avg, min)') + parser.add_argument("--only_metrics", action='store_true') + parser.add_argument("--num_stochastic_samples", type=int, default=100) + + parser.add_argument("--gt_inputs_dir", type=str, help="directory containing input ground truth images for ismple dataset") + parser.add_argument("--gt_outputs_dir", type=str, help="directory containing output ground truth images for ismple dataset") + + parser.add_argument("--eval_parallel_iterations", type=int, default=10) + parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") + parser.add_argument("--seed", type=int, default=7) + + args = parser.parse_args() + + if args.seed is not None: + tf.set_random_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + dataset_hparams_dict = {} + model_hparams_dict = {} + if args.checkpoint: + checkpoint_dir = os.path.normpath(args.checkpoint) + if not os.path.isdir(args.checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % args.checkpoint) + options = json.loads(f.read()) + args.dataset = args.dataset or options['dataset'] + args.model = args.model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: + dataset_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("dataset_hparams.json was not loaded because it does not exist") + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + args.output_dir = args.output_dir or os.path.join(args.results_dir, os.path.split(checkpoint_dir)[1]) + else: + if not args.dataset: + raise ValueError('dataset is 
required when checkpoint is not specified') + if not args.model: + raise ValueError('model is required when checkpoint is not specified') + args.output_dir = args.output_dir or os.path.join(args.results_dir, 'model.%s' % args.model) + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + VideoDataset = datasets.get_dataset_class(args.dataset) + dataset = VideoDataset( + args.input_dir, + mode=args.mode, + num_epochs=args.num_epochs, + seed=args.seed, + hparams_dict=dataset_hparams_dict, + hparams=args.dataset_hparams) + + VideoPredictionModel = models.get_model_class(args.model) + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': dataset.hparams.context_frames, + 'sequence_length': dataset.hparams.sequence_length, + 'repeat': dataset.hparams.time_shift, + }) + model = VideoPredictionModel( + mode=args.mode, + hparams_dict=hparams_dict, + hparams=args.model_hparams, + eval_num_samples=args.num_stochastic_samples, + eval_parallel_iterations=args.eval_parallel_iterations) + + if args.num_samples: + if args.num_samples > dataset.num_examples_per_epoch(): + raise ValueError('num_samples cannot be larger than the dataset') + num_examples_per_epoch = args.num_samples + else: + num_examples_per_epoch = dataset.num_examples_per_epoch() + if num_examples_per_epoch % args.batch_size != 0: + #bing0 + #raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch) + pass + #Bing if it is era 5 data we used dataset.make_batch_v2 + #inputs = dataset.make_batch(args.batch_size) + inputs = dataset.make_batch_v2(args.batch_size) + input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} + with tf.variable_scope(''): + model.build_graph(input_phs) + + output_dir = args.output_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + with open(os.path.join(output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys=True, indent=4)) + with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4)) + with open(os.path.join(output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) + sess = tf.Session(config=config) + sess.graph.as_default() + + model.restore(sess, args.checkpoint) + + sample_ind = 0 + while True: + if args.num_samples and sample_ind >= args.num_samples: + break + try: + input_results = sess.run(inputs) + except tf.errors.OutOfRangeError: + break + print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size)) + + feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} + # compute "best" metrics using the computation graph + fetches = {'images': model.inputs['images']} + fetches.update(model.eval_outputs.items()) + fetches.update(model.eval_metrics.items()) + results = sess.run(fetches, feed_dict=feed_dict) + save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'), + results, model.hparams, sample_ind, args.only_metrics, args.eval_substasks) + sample_ind += args.batch_size + + metric_fnames = [] + metric_names = 
['psnr', 'ssim', 'lpips'] + subtasks = ['max'] + for metric_name in metric_names: + for subtask in subtasks: + metric_fnames.append( + os.path.join(output_dir, 'prediction_eval_%s_%s' % (metric_name, subtask), 'metrics', metric_name)) + + for metric_fname in metric_fnames: + task_name, _, metric_name = metric_fname.split('/')[-3:] + metric = load_metrics(metric_fname) + print('=' * 31) + print(task_name, metric_name) + print('-' * 31) + metric_header_format = '{:>10} {:>20}' + metric_row_format = '{:>10} {:>10.4f} ({:>7.4f})' + print(metric_header_format.format('time step', os.path.split(metric_fname)[1])) + for t, (metric_mean, metric_std) in enumerate(zip(metric.mean(axis=0), metric.std(axis=0))): + print(metric_row_format.format(t, metric_mean, metric_std)) + print(metric_row_format.format('mean (std)', metric.mean(), metric.std())) + print('=' * 31) + + +if __name__ == '__main__': + main() diff --git a/scripts/evaluate_all.sh b/scripts/evaluate_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..c57f5c895da22f167a0a1bb4204a225965c8a48c --- /dev/null +++ b/scripts/evaluate_all.sh @@ -0,0 +1,44 @@ +# BAIR action-free robot pushing dataset +dataset=bair_action_free +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_vae_l2 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sv2p_time_invariant \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/bair --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 8 +done + +# KTH human actions dataset +# use batch_size=1 to ensure reproducibility when sampling subclips within a sequence +dataset=kth +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sv2p_time_variant \ + sv2p_time_invariant \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/kth --dataset_hparams sequence_length=40 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 1 +done + +# BAIR action-conditioned robot pushing dataset +dataset=bair +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_vae_l2 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sna_l1 \ + sna_l2 \ + sv2p_time_variant \ +; do + CUDA_VISIBLE_DEVICES=1 python scripts/evaluate.py --input_dir data/bair --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 8 +done diff --git a/scripts/evaluate_svg.sh b/scripts/evaluate_svg.sh new file mode 100644 index 0000000000000000000000000000000000000000..212c4ba239ecdbd15c70c05b9336c32175dc8c5c --- /dev/null +++ b/scripts/evaluate_svg.sh @@ -0,0 +1 @@ +#!/usr/bin/env bash \ No newline at end of file diff --git a/scripts/generate.py b/scripts/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..e1893194fe600981350a52e9880df9f6a034701d --- /dev/null +++ b/scripts/generate.py @@ -0,0 +1,537 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import errno +import json +import os +import math +import random +import cv2 +import numpy as np +import tensorflow as tf +import seaborn as sns +import pickle +from random import seed +import random +import json +import numpy as np +#from six.moves import cPickle +import matplotlib +matplotlib.use('Agg') 
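+# NOTE: the non-interactive Agg backend has to be selected before pyplot is
+# imported, so that the figures and animations produced below can be rendered
+# on display-less compute nodes (e.g. batch jobs on the HPC system).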
+import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import matplotlib.animation as animation +import seaborn as sns +import pandas as pd +from video_prediction import datasets, models +from matplotlib.colors import LinearSegmentedColormap +from matplotlib.ticker import MaxNLocator +from video_prediction.utils.ffmpeg_gif import save_gif + +with open("./splits_size_64_64_1/geo_info.json","r") as json_file: + geo = json.load(json_file) + lat = [round(i,2) for i in geo["lat"]] + lon = [round(i,2) for i in geo["lon"]] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " + "train, val, test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified") + parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified") + parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified") + parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is " + "results_gif_dir/model_fname") + parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is " + "results_png_dir/model_fname") + parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)") + + parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.') + + parser.add_argument("--dataset", type=str, help="dataset class name") + parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters") + parser.add_argument("--model", type=str, help="model class name") + parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") + + parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch") + parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)") + parser.add_argument("--num_epochs", type=int, default=1) + + parser.add_argument("--num_stochastic_samples", type=int, default=1) #Bing original is 5, change to 1 + parser.add_argument("--gif_length", type=int, help="default is sequence_length") + parser.add_argument("--fps", type=int, default=4) + + parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") + parser.add_argument("--seed", type=int, default=7) + + args = parser.parse_args() + + if args.seed is not None: + tf.set_random_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + args.results_gif_dir = args.results_gif_dir or args.results_dir + args.results_png_dir = args.results_png_dir or args.results_dir + dataset_hparams_dict = {} + model_hparams_dict = {} + if args.checkpoint: + checkpoint_dir = os.path.normpath(args.checkpoint) + if not os.path.isdir(args.checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % args.checkpoint) + options = json.loads(f.read()) + args.dataset = args.dataset or 
options['dataset'] + args.model = args.model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: + dataset_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("dataset_hparams.json was not loaded because it does not exist") + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1]) + args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1]) + else: + if not args.dataset: + raise ValueError('dataset is required when checkpoint is not specified') + if not args.model: + raise ValueError('model is required when checkpoint is not specified') + args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model) + args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model) + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + VideoDataset = datasets.get_dataset_class(args.dataset) + dataset = VideoDataset( + args.input_dir, + mode=args.mode, + num_epochs=args.num_epochs, + seed=args.seed, + hparams_dict=dataset_hparams_dict, + hparams=args.dataset_hparams) + VideoPredictionModel = models.get_model_class(args.model) + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': dataset.hparams.context_frames, + 'sequence_length': dataset.hparams.sequence_length, + 'repeat': dataset.hparams.time_shift, + }) + model = VideoPredictionModel( + mode=args.mode, + hparams_dict=hparams_dict, + hparams=args.model_hparams) + + sequence_length = model.hparams.sequence_length + context_frames = model.hparams.context_frames + future_length = sequence_length - context_frames + + if args.num_samples: + if args.num_samples > dataset.num_examples_per_epoch(): + raise ValueError('num_samples cannot be larger than the dataset') + num_examples_per_epoch = args.num_samples + else: + #Bing: error occurs here, cheats a little bit here + #num_examples_per_epoch = dataset.num_examples_per_epoch() + num_examples_per_epoch = args.batch_size * 8 + if num_examples_per_epoch % args.batch_size != 0: + #bing + #raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch) + pass + #Bing if it is era 5 data we used dataset.make_batch_v2 + #inputs = dataset.make_batch(args.batch_size) + inputs = dataset.make_batch(args.batch_size) + input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} + with tf.variable_scope(''): + model.build_graph(input_phs) + + for output_dir in (args.output_gif_dir, args.output_png_dir): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + with open(os.path.join(output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys=True, indent=4)) + with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4)) + with open(os.path.join(output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) + 
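+    # The three JSON dumps above record the exact options and hyperparameters of
+    # this run; they mirror the files that are read back from a checkpoint
+    # directory when --checkpoint is given. A minimal, illustrative round trip
+    # (not executed here; "some_output_dir" is a placeholder path) might look like:
+    #
+    #     with open(os.path.join(some_output_dir, "model_hparams.json")) as f:
+    #         restored = json.loads(f.read())
+    #     # plain dict, e.g. restored.get("sequence_length"), assuming the model
+    #     # hparams expose that key as they do in the hparams_dict update above.
+    #
+    # What follows is the usual TF1 session setup: cap the per-process GPU memory
+    # fraction, allow soft placement, and restore the weights from the checkpoint.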
+ gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) + sess = tf.Session(config=config) + sess.graph.as_default() + model.restore(sess, args.checkpoint) + sample_ind = 0 + gen_images_all = [] + input_images_all = [] + + while True: + if args.num_samples and sample_ind >= args.num_samples: + break + try: + input_results = sess.run(inputs) + except tf.errors.OutOfRangeError: + break + print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size)) + feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} + for stochastic_sample_ind in range(args.num_stochastic_samples): #Todo: why use here + print("Stochastic sample id", stochastic_sample_ind) + gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict) + #input_images = sess.run(inputs["images"]) + #Bing: Add evaluation metrics + # fetches = {'images': model.inputs['images']} + # fetches.update(model.eval_outputs.items()) + # fetches.update(model.eval_metrics.items()) + # results = sess.run(fetches, feed_dict = feed_dict) + # input_images = results["images"] #shape (batch_size,future_frames,height,width,channel) + # only keep the future frames + #gen_images = gen_images[:, -future_length:] #(8,10,64,64,1) (batch_size, sequences, height, width, channel) + #input_images = input_results["images"][:,-future_length:,:,:] + input_images = input_results["images"][:,1:,:,:,:] + #gen_mse_avg = results["eval_mse/avg"] #shape (batch_size,future_frames) + print("Finish sample ind",stochastic_sample_ind) + input_gen_diff_ = input_images - gen_images + #diff_image_range = pd.cut(input_gen_diff_.flatten(), bins = 4, labels = [-10, -5, 0, 5], right = False) + #diff_image_range = np.reshape(np.array(diff_image_range),input_gen_diff_.shape) + gen_images_all.extend(gen_images) + input_images_all.extend(input_images) + + colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] + cmap_name = 'my_list' + if sample_ind < 100: + for i in range(len(gen_images)): + name = 'Batch_id_' + str(sample_ind) + " + Sample_" + str(i) + gen_images_ = gen_images[i, :] + gen_mse_avg_ = [np.mean(input_gen_diff_[i, frame, :, :, :]**2) for frame in + range(19)] # return the list with 10 (sequence) mse + input_gen_diff = input_gen_diff_ [i,:,:,:,:] + input_images_ = input_images[i, :] + #gen_mse_avg_ = gen_mse_avg[i, :] + + # Bing: This is to check the difference between the images and next images for debugging the freezon issues + # gen_images_diff = [] + # for gen_idx in range(len(gen_images_) - 1): + # img_1 = gen_images_[gen_idx, :, :, :] + # img_2 = gen_images_[gen_idx + 1, :, :, :] + # img_diff = img_2 - img_1 + # img_diff_nonzero = [e for img_idx, e in enumerate(img_diff.flatten()) if round(e,3) != 0.000] + # gen_images_diff.append(img_diff_nonzero) + + fig = plt.figure() + gs = gridspec.GridSpec(4,6) + gs.update(wspace = 0.7,hspace=0.8) + ax1 = plt.subplot(gs[0:2,0:3]) + ax2 = plt.subplot(gs[0:2,3:],sharey=ax1) + ax3 = plt.subplot(gs[2:4,0:3]) + ax4 = plt.subplot(gs[2:4,3:]) + xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] + ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] + plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels) + ax1.title.set_text("(a) Ground Truth") + ax2.title.set_text("(b) SAVP") + ax3.title.set_text("(c) Diff.") + ax4.title.set_text("(d) MSE") + + 
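+                    # Panel layout of the animated figure: (a) ground-truth frames,
+                    # (b) the corresponding SAVP prediction, (c) their pixel-wise
+                    # difference on a diverging colour map, and (d) the domain-averaged
+                    # MSE accumulating frame by frame; the x/y ticks are labelled with
+                    # the lon/lat values of the 64x64 patch loaded from geo_info.json.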
ax1.xaxis.set_tick_params(labelsize=7) + ax1.yaxis.set_tick_params(labelsize = 7) + ax2.xaxis.set_tick_params(labelsize=7) + ax2.yaxis.set_tick_params(labelsize = 7) + ax3.xaxis.set_tick_params(labelsize=7) + ax3.yaxis.set_tick_params(labelsize = 7) + + init_images = np.zeros((input_images_.shape[1], input_images_.shape[2])) + print("inti images shape", init_images.shape) + xdata, ydata = [], [] + plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) + plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) + #x = np.linspace(0, 64, 64) + #y = np.linspace(0, 64, 64) + #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) + #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) + fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7) + fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7) + + cm = LinearSegmentedColormap.from_list( + cmap_name, "bwr", N = 5) + + plot3 = ax3.imshow(init_images, vmin=-10, vmax=10, cmap=cm)#cmap = 'PuBu_r', + plot4, = ax4.plot([], [], color = "r") + ax4.set_xlim(0, len(gen_mse_avg_)-1) + ax4.set_ylim(0, 10) + ax4.set_xlabel("Frames", fontsize=10) + #ax4.set_ylabel("MSE", fontsize=10) + ax4.xaxis.set_tick_params(labelsize=7) + ax4.yaxis.set_tick_params(labelsize=7) + + + plots = [plot1, plot2, plot3, plot4] + + #fig.colorbar(plots[1], ax = [ax1, ax2]) + + fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7) + #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7) + #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7) + + def animation_sample(t): + input_image = input_images_[t, :, :, 0] + gen_image = gen_images_[t, :, :, 0] + diff_image = input_gen_diff[t,:,:,0] + + data = gen_mse_avg_[:t + 1] + # x = list(range(len(gen_mse_avg_)))[:t+1] + xdata.append(t) + print("xdata", xdata) + ydata.append(gen_mse_avg_[t]) + + print("ydata", ydata) + # p = sns.lineplot(x=x,y=data,color="b") + # p.tick_params(labelsize=17) + # plt.setp(p.lines, linewidth=6) + plots[0].set_data(input_image) + plots[1].set_data(gen_image) + #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) + #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) + plots[2].set_data(diff_image) + plots[3].set_data(xdata, ydata) + fig.suptitle("Frame " + str(t+1)) + + return plots + + ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000, + repeat_delay = 2000) + ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4")) + + else: + pass + + + + + + + + + + + + + + + + + + # # for i, gen_mse_avg_ in enumerate(gen_mse_avg): + # # ims = [] + # # fig = plt.figure() + # # plt.xlim(0,len(gen_mse_avg_)) + # # plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg)) + # # plt.xlabel("Frames") + # # plt.ylabel("MSE_AVG") + # # #X = list(range(len(gen_mse_avg_))) + # # #for t, gen_mse_avg_ in enumerate(gen_mse_avg): + # # def animate_metric(j): + # # data = gen_mse_avg_[:(j+1)] + # # x = list(range(len(gen_mse_avg_)))[:(j+1)] + # # p = sns.lineplot(x=x,y=data,color="b") + # # p.tick_params(labelsize=17) + # # plt.setp(p.lines, linewidth=6) + # # ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000) + # # ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif")) + # # + # # + # # for i, input_images_ in 
enumerate(input_images): + # # #context_images_ = (input_results['images'][i]) + # # #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + # # ims = [] + # # fig = plt.figure() + # # for t, input_image in enumerate(input_images_): + # # im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none') + # # ttl = plt.text(1.5, 2,"Frame_" + str(t)) + # # ims.append([im,ttl]) + # # ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000) + # # ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif")) + # # #plt.show() + # # + # # for i,gen_images_ in enumerate(gen_images): + # # ims = [] + # # fig = plt.figure() + # # for t, gen_image in enumerate(gen_images_): + # # im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none') + # # ttl = plt.text(1.5, 2, "Frame_" + str(t)) + # # ims.append([im, ttl]) + # # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + # # ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif")) + # + # + # # for i, gen_images_ in enumerate(gen_images): + # # #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8) + # # #gen_images_ = (gen_images_ * 255.0).astype(np.uint8) + # # #bing + # # context_images_ = (input_results['images'][i]) + # # gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + # # context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_) + # # plt.figure(figsize = (10,2)) + # # gs = gridspec.GridSpec(2,10) + # # gs.update(wspace=0.,hspace=0.) + # # for t, gen_image in enumerate(gen_images_): + # # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1))) + # # gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) + # # plt.subplot(gs[t]) + # # plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none') # the last index sets the channel. 
0 = t2 + # # # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet) + # # plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False, + # # right = False, labelbottom = False, labelleft = False) + # # if t == 0: plt.ylabel('Actual', fontsize = 10) + # # + # # plt.subplot(gs[t + 10]) + # # plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none') + # # # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet) + # # plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False, + # # right = False, labelbottom = False, labelleft = False) + # # if t == 0: plt.ylabel('Predicted', fontsize = 10) + # # plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png') + # # plt.clf() + # + # # if args.gif_length: + # # context_and_gen_images = context_and_gen_images[:args.gif_length] + # # save_gif(os.path.join(args.output_gif_dir, gen_images_fname), + # # context_and_gen_images, fps=args.fps) + # # + # # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1))) + # # for t, gen_image in enumerate(gen_images_): + # # gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) + # # if gen_image.shape[-1] == 1: + # # gen_image = np.tile(gen_image, (1, 1, 3)) + # # else: + # # gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR) + # # cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image) + + sample_ind += args.batch_size + + + with open(os.path.join(args.output_png_dir, "input_images_all"),"wb") as input_files: + pickle.dump(input_images_all,input_files) + + with open(os.path.join(args.output_png_dir, "gen_images_all"),"wb") as gen_files: + pickle.dump(gen_images_all,gen_files) + + with open(os.path.join(args.output_png_dir, "input_images_all"),"rb") as input_files: + input_images_all = pickle.load(input_files) + + with open(os.path.join(args.output_png_dir, "gen_images_all"),"rb") as gen_files: + gen_images_all=pickle.load(gen_files) + ims = [] + fig = plt.figure() + for frame in range(19): + input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0] # Get the first prediction frame (batch,height, width, channel) + #pix_mean = np.mean(input_gen_diff, axis = 0) + #pix_std = np.std(input_gen_diff, axis=0) + im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu') + if frame == 0: + fig.colorbar(im) + ttl = plt.text(1.5, 2, "Frame_" + str(frame +1)) + ims.append([im, ttl]) + ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4")) + plt.close("all") + + ims = [] + fig = plt.figure() + for frame in range(19): + pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0] # Get the first prediction frame (batch,height, width, channel) + #pix_mean = np.mean(input_gen_diff, axis = 0) + #pix_std = np.std(input_gen_diff, axis=0) + im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu') + if frame == 0: + fig.colorbar(im) + ttl = plt.text(1.5, 2, "Frame_" + str(frame+1)) + ims.append([im, ttl]) + ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4")) + + gen_images_all = np.array(gen_images_all) + input_images_all = np.array(input_images_all) + # mse_model = np.mean((input_images_all[:, 
1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2) # look at all timesteps except the first + # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2) + # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 ) + + mse_model = np.mean((input_images_all[:, :10,:,:,0] - gen_images_all[:, :10,:,:,0])**2) # look at all timesteps except the first + mse_model_last = np.mean((input_images_all[:,10,:,:,0] - gen_images_all[:, 10,:,:,0])**2) + mse_prev = np.mean((input_images_all[:, :9, :, :, 0] - input_images_all[:, 1:10, :, :, 0])**2 ) + + def psnr(img1, img2): + mse = np.mean((img1 - img2) ** 2) + if mse == 0: return 100 + PIXEL_MAX = 1 + return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) + + psnr_model = psnr(input_images_all[:, :10, :, :, 0], gen_images_all[:, :10, :, :, 0]) + psnr_model_last = psnr(input_images_all[:, 10, :, :, 0], gen_images_all[:,10, :, :, 0]) + psnr_prev = psnr(input_images_all[:, :9, :, :, 0], input_images_all[:, 1:10, :, :, 0]) + f = open(os.path.join(args.output_png_dir,'prediction_scores_4prediction.txt'), 'w') + f.write("Model MSE: %f\n" % mse_model) + f.write("Model MSE from only last prediction in sequence: %f\n" % mse_model_last) + f.write("Previous Frame MSE: %f\n" % mse_prev) + f.write("Model PSNR: %f\n" % psnr_model) + f.write("Model PSNR from only last prediction in sequence: %f\n" % psnr_model_last) + f.write("Previous frame PSNR: %f\n" % psnr_prev) + f.write("Shape of X_test: " + str(input_images_all.shape)) + f.write("") + f.write("Shape of X_hat: " + str(gen_images_all.shape)) + f.close() + + seed(1) + s = random.sample(range(len(gen_images_all)), 100) + print("******KDP******") + #kernel density plot for checking the model collapse + fig = plt.figure() + kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images") + kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True") + kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability') + plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400) + plt.clf() + + #line plot for evaluating the prediction and groud-truth + for i in [0,3,6,9,12,15,18]: + fig = plt.figure() + plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3) + #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3) + plt.xlabel("Prediction") + plt.ylabel("Real values") + plt.title("Frame_{}".format(i+1)) + plt.plot([250,300], [250,300],color="black") + plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i)))) + plt.clf() + + + mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence) + x = [str(i+1) for i in list(range(19))] + fig,axis = plt.subplots() + mean_f = np.mean(mse_model_by_frames, axis = 0) + median = np.median(mse_model_by_frames, axis=0) + q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0) + q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0) + d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0) + d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0) + plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range") + plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range") + plt.plot(x, median, color="grey", linewidth=0.6, label="Median") + plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, 
label="Mean") + plt.title(f'MSE percentile') + plt.xlabel("Frames") + plt.legend(loc=2, fontsize=8) + plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png")) + +if __name__ == '__main__': + main() diff --git a/scripts/generate_all.sh b/scripts/generate_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..c3736b36df1840cbc46b89526cef0c908b500760 --- /dev/null +++ b/scripts/generate_all.sh @@ -0,0 +1,55 @@ +# BAIR action-free robot pushing dataset +dataset=bair_action_free +CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair --dataset bair \ + --dataset_hparams sequence_length=30 --model ground_truth --mode test \ + --output_gif_dir results_test_2afc/${dataset}/ground_truth \ + --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10 +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + sv2p_time_invariant \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \ + --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test \ + --results_gif_dir results_test_2afc/${dataset} \ + --results_png_dir results_test_samples/${dataset} --gif_length 10 +done + +# KTH human actions dataset +# use batch_size=1 to ensure reproducibility when sampling subclips within a sequence +dataset=kth +CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/kth --dataset kth \ + --dataset_hparams sequence_length=40 --model ground_truth --mode test \ + --output_gif_dir results_test_2afc/${dataset}/ground_truth \ + --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10 --batch_size 1 +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + sv2p_time_invariant \ + sv2p_time_variant \ +; do + CUDA_VISIBLE_DEVICES=1 python scripts/generate.py --input_dir data/kth \ + --dataset_hparams sequence_length=40 --checkpoint models/${dataset}/${method_dir} --mode test \ + --results_gif_dir results_test_2afc/${dataset} \ + --results_png_dir results_test_samples/${dataset} --gif_length 10 --batch_size 1 +done + +# BAIR action-conditioned robot pushing dataset +dataset=bair +CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair --dataset bair \ + --dataset_hparams sequence_length=30 --model ground_truth --mode test \ + --output_gif_dir results_test_2afc/${dataset}/ground_truth \ + --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10 +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + sv2p_time_variant \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \ + --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test \ + --results_gif_dir results_test_2afc/${dataset} \ + --results_png_dir results_test_samples/${dataset} --gif_length 10 +done diff --git a/scripts/generate_anomaly.py b/scripts/generate_anomaly.py new file mode 100644 index 0000000000000000000000000000000000000000..9c8e555c2e223c51bfb555c7dd0fa0eb089610d8 --- /dev/null +++ b/scripts/generate_anomaly.py @@ -0,0 +1,518 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import errno +import json +import os +import math +import random +import cv2 +import numpy as np +import tensorflow as tf +import seaborn as sns +import pickle +from random import seed +import random +import json +import numpy as np +#from six.moves import cPickle +import matplotlib +matplotlib.use('Agg') 
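+# NOTE: unlike generate.py, this script operates on anomalies. dataset.make_batch_v2
+# is expected to return the anomaly tensors together with the corresponding mean
+# field; the model predicts anomalies, and the mean is added back to both the
+# prediction and the ground truth before the plots and scores are computed.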
+import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import matplotlib.animation as animation +import seaborn as sns +import pandas as pd +from video_prediction import datasets, models +from matplotlib.colors import LinearSegmentedColormap +from matplotlib.ticker import MaxNLocator +from video_prediction.utils.ffmpeg_gif import save_gif + +with open("./splits_size_64_64_1/geo_info.json","r") as json_file: + geo = json.load(json_file) + lat = [round(i,2) for i in geo["lat"]] + lon = [round(i,2) for i in geo["lon"]] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " + "train, val, test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified") + parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified") + parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified") + parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is " + "results_gif_dir/model_fname") + parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is " + "results_png_dir/model_fname") + parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)") + + parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.') + + parser.add_argument("--dataset", type=str, help="dataset class name") + parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters") + parser.add_argument("--model", type=str, help="model class name") + parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") + + parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch") + parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)") + parser.add_argument("--num_epochs", type=int, default=1) + + parser.add_argument("--num_stochastic_samples", type=int, default=1) #Bing original is 5, change to 1 + parser.add_argument("--gif_length", type=int, help="default is sequence_length") + parser.add_argument("--fps", type=int, default=4) + + parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") + parser.add_argument("--seed", type=int, default=7) + + args = parser.parse_args() + + if args.seed is not None: + tf.set_random_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + args.results_gif_dir = args.results_gif_dir or args.results_dir + args.results_png_dir = args.results_png_dir or args.results_dir + dataset_hparams_dict = {} + model_hparams_dict = {} + if args.checkpoint: + checkpoint_dir = os.path.normpath(args.checkpoint) + if not os.path.isdir(args.checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % args.checkpoint) + options = json.loads(f.read()) + args.dataset = args.dataset or 
options['dataset'] + args.model = args.model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: + dataset_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("dataset_hparams.json was not loaded because it does not exist") + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1]) + args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1]) + else: + if not args.dataset: + raise ValueError('dataset is required when checkpoint is not specified') + if not args.model: + raise ValueError('model is required when checkpoint is not specified') + args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model) + args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model) + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + VideoDataset = datasets.get_dataset_class(args.dataset) + dataset = VideoDataset( + args.input_dir, + mode=args.mode, + num_epochs=args.num_epochs, + seed=args.seed, + hparams_dict=dataset_hparams_dict, + hparams=args.dataset_hparams) + VideoPredictionModel = models.get_model_class(args.model) + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': dataset.hparams.context_frames, + 'sequence_length': dataset.hparams.sequence_length, + 'repeat': dataset.hparams.time_shift, + }) + model = VideoPredictionModel( + mode=args.mode, + hparams_dict=hparams_dict, + hparams=args.model_hparams) + + sequence_length = model.hparams.sequence_length + context_frames = model.hparams.context_frames + future_length = sequence_length - context_frames + + if args.num_samples: + if args.num_samples > dataset.num_examples_per_epoch(): + raise ValueError('num_samples cannot be larger than the dataset') + num_examples_per_epoch = args.num_samples + else: + #Bing: error occurs here, cheats a little bit here + #num_examples_per_epoch = dataset.num_examples_per_epoch() + num_examples_per_epoch = args.batch_size * 8 + if num_examples_per_epoch % args.batch_size != 0: + #bing + #raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch) + pass + #Bing if it is era 5 data we used dataset.make_batch_v2 + #inputs = dataset.make_batch(args.batch_size) + inputs, inputs_mean = dataset.make_batch_v2(args.batch_size) + input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} + with tf.variable_scope(''): + model.build_graph(input_phs) + + for output_dir in (args.output_gif_dir, args.output_png_dir): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + with open(os.path.join(output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys=True, indent=4)) + with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4)) + with open(os.path.join(output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), 
sort_keys=True, indent=4)) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) + sess = tf.Session(config=config) + sess.graph.as_default() + model.restore(sess, args.checkpoint) + sample_ind = 0 + gen_images_all = [] + input_images_all = [] + + while True: + if args.num_samples and sample_ind >= args.num_samples: + break + try: + input_results = sess.run(inputs) + input_mean_results = sess.run(inputs_mean) + input_final = input_results["images"] + input_mean_results["images"] + + except tf.errors.OutOfRangeError: + break + print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size)) + feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} + for stochastic_sample_ind in range(args.num_stochastic_samples): + print("Stochastic sample id", stochastic_sample_ind) + gen_anomaly = sess.run(model.outputs['gen_images'], feed_dict=feed_dict) + gen_images = gen_anomaly + input_mean_results["images"][:,1:,:,:] + #input_images = sess.run(inputs["images"]) + #Bing: Add evaluation metrics + # fetches = {'images': model.inputs['images']} + # fetches.update(model.eval_outputs.items()) + # fetches.update(model.eval_metrics.items()) + # results = sess.run(fetches, feed_dict = feed_dict) + # input_images = results["images"] #shape (batch_size,future_frames,height,width,channel) + # only keep the future frames + #gen_images = gen_images[:, -future_length:] #(8,10,64,64,1) (batch_size, sequences, height, width, channel) + #input_images = input_results["images"][:,-future_length:,:,:] + #input_images = input_results["images"][:,1:,:,:,:] + input_images = input_final [:,1:,:,:,:] + #gen_mse_avg = results["eval_mse/avg"] #shape (batch_size,future_frames) + print("Finish sample ind",stochastic_sample_ind) + input_gen_diff_ = input_images - gen_images + #diff_image_range = pd.cut(input_gen_diff_.flatten(), bins = 4, labels = [-10, -5, 0, 5], right = False) + #diff_image_range = np.reshape(np.array(diff_image_range),input_gen_diff_.shape) + gen_images_all.extend(gen_images) + input_images_all.extend(input_images) + + colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] + cmap_name = 'my_list' + if sample_ind < 100: + for i in range(len(gen_images)): + name = 'Batch_id_' + str(sample_ind) + " + Sample_" + str(i) + gen_images_ = gen_images[i, :] + gen_mse_avg_ = [np.mean(input_gen_diff_[i, frame, :, :, :]**2) for frame in + range(19)] # return the list with 10 (sequence) mse + + input_gen_diff = input_gen_diff_[i,:,:,:,:] + input_images_ = input_images[i, :] + #gen_mse_avg_ = gen_mse_avg[i, :] + fig = plt.figure() + gs = gridspec.GridSpec(4,6) + gs.update(wspace = 0.7,hspace=0.8) + ax1 = plt.subplot(gs[0:2,0:3]) + ax2 = plt.subplot(gs[0:2,3:],sharey=ax1) + ax3 = plt.subplot(gs[2:4,0:3]) + ax4 = plt.subplot(gs[2:4,3:]) + xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] + ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] + plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels) + ax1.title.set_text("(a) Ground Truth") + ax2.title.set_text("(b) SAVP") + ax3.title.set_text("(c) Diff.") + ax4.title.set_text("(d) MSE") + + ax1.xaxis.set_tick_params(labelsize=7) + ax1.yaxis.set_tick_params(labelsize = 7) + ax2.xaxis.set_tick_params(labelsize=7) + ax2.yaxis.set_tick_params(labelsize = 7) + ax3.xaxis.set_tick_params(labelsize=7) + 
ax3.yaxis.set_tick_params(labelsize = 7) + + init_images = np.zeros((input_images_.shape[1], input_images_.shape[2])) + print("inti images shape", init_images.shape) + xdata, ydata = [], [] + plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) + plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) + #x = np.linspace(0, 64, 64) + #y = np.linspace(0, 64, 64) + #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) + #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) + fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7) + fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7) + + cm = LinearSegmentedColormap.from_list( + cmap_name, "bwr", N = 5) + + plot3 = ax3.imshow(init_images, vmin=-10, vmax=10, cmap=cm)#cmap = 'PuBu_r', + + plot4, = ax4.plot([], [], color = "r") + ax4.set_xlim(0, len(gen_mse_avg_)-1) + ax4.set_ylim(0, 10) + ax4.set_xlabel("Frames", fontsize=10) + #ax4.set_ylabel("MSE", fontsize=10) + ax4.xaxis.set_tick_params(labelsize=7) + ax4.yaxis.set_tick_params(labelsize=7) + + + plots = [plot1, plot2, plot3, plot4] + + #fig.colorbar(plots[1], ax = [ax1, ax2]) + + fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7) + #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7) + #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7) + + def animation_sample(t): + input_image = input_images_[t, :, :, 0] + gen_image = gen_images_[t, :, :, 0] + diff_image = input_gen_diff[t,:,:,0] + + data = gen_mse_avg_[:t + 1] + # x = list(range(len(gen_mse_avg_)))[:t+1] + xdata.append(t) + print("xdata", xdata) + ydata.append(gen_mse_avg_[t]) + + print("ydata", ydata) + # p = sns.lineplot(x=x,y=data,color="b") + # p.tick_params(labelsize=17) + # plt.setp(p.lines, linewidth=6) + plots[0].set_data(input_image) + plots[1].set_data(gen_image) + #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) + #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) + plots[2].set_data(diff_image) + plots[3].set_data(xdata, ydata) + fig.suptitle("Frame " + str(t+1)) + + return plots + + ani = animation.FuncAnimation(fig, animation_sample, frames = len(gen_mse_avg_), interval = 1000, + repeat_delay = 2000) + ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4")) + + else: + pass + + # # for i, gen_mse_avg_ in enumerate(gen_mse_avg): + # # ims = [] + # # fig = plt.figure() + # # plt.xlim(0,len(gen_mse_avg_)) + # # plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg)) + # # plt.xlabel("Frames") + # # plt.ylabel("MSE_AVG") + # # #X = list(range(len(gen_mse_avg_))) + # # #for t, gen_mse_avg_ in enumerate(gen_mse_avg): + # # def animate_metric(j): + # # data = gen_mse_avg_[:(j+1)] + # # x = list(range(len(gen_mse_avg_)))[:(j+1)] + # # p = sns.lineplot(x=x,y=data,color="b") + # # p.tick_params(labelsize=17) + # # plt.setp(p.lines, linewidth=6) + # # ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000) + # # ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif")) + # # + # # + # # for i, input_images_ in enumerate(input_images): + # # #context_images_ = (input_results['images'][i]) + # # #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + # # ims = [] + # # fig = plt.figure() + # # for t, input_image in 
enumerate(input_images_): + # # im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none') + # # ttl = plt.text(1.5, 2,"Frame_" + str(t)) + # # ims.append([im,ttl]) + # # ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000) + # # ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif")) + # # #plt.show() + # # + # # for i,gen_images_ in enumerate(gen_images): + # # ims = [] + # # fig = plt.figure() + # # for t, gen_image in enumerate(gen_images_): + # # im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none') + # # ttl = plt.text(1.5, 2, "Frame_" + str(t)) + # # ims.append([im, ttl]) + # # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + # # ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif")) + # + # + # # for i, gen_images_ in enumerate(gen_images): + # # #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8) + # # #gen_images_ = (gen_images_ * 255.0).astype(np.uint8) + # # #bing + # # context_images_ = (input_results['images'][i]) + # # gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + # # context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_) + # # plt.figure(figsize = (10,2)) + # # gs = gridspec.GridSpec(2,10) + # # gs.update(wspace=0.,hspace=0.) + # # for t, gen_image in enumerate(gen_images_): + # # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1))) + # # gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) + # # plt.subplot(gs[t]) + # # plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none') # the last index sets the channel. 
0 = t2 + # # # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet) + # # plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False, + # # right = False, labelbottom = False, labelleft = False) + # # if t == 0: plt.ylabel('Actual', fontsize = 10) + # # + # # plt.subplot(gs[t + 10]) + # # plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none') + # # # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet) + # # plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False, + # # right = False, labelbottom = False, labelleft = False) + # # if t == 0: plt.ylabel('Predicted', fontsize = 10) + # # plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png') + # # plt.clf() + # + # # if args.gif_length: + # # context_and_gen_images = context_and_gen_images[:args.gif_length] + # # save_gif(os.path.join(args.output_gif_dir, gen_images_fname), + # # context_and_gen_images, fps=args.fps) + # # + # # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1))) + # # for t, gen_image in enumerate(gen_images_): + # # gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) + # # if gen_image.shape[-1] == 1: + # # gen_image = np.tile(gen_image, (1, 1, 3)) + # # else: + # # gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR) + # # cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image) + + sample_ind += args.batch_size + + + with open(os.path.join(args.output_png_dir, "input_images_all"),"wb") as input_files: + pickle.dump(input_images_all,input_files) + + with open(os.path.join(args.output_png_dir, "gen_images_all"),"wb") as gen_files: + pickle.dump(gen_images_all,gen_files) + + with open(os.path.join(args.output_png_dir, "input_images_all"),"rb") as input_files: + input_images_all = pickle.load(input_files) + + with open(os.path.join(args.output_png_dir, "gen_images_all"),"rb") as gen_files: + gen_images_all=pickle.load(gen_files) + ims = [] + fig = plt.figure() + for frame in range(19): + input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0] # Get the first prediction frame (batch,height, width, channel) + #pix_mean = np.mean(input_gen_diff, axis = 0) + #pix_std = np.std(input_gen_diff, axis=0) + im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu') + if frame == 0: + fig.colorbar(im) + ttl = plt.text(1.5, 2, "Frame_" + str(frame +1)) + ims.append([im, ttl]) + ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4")) + plt.close("all") + + ims = [] + fig = plt.figure() + for frame in range(19): + pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0] # Get the first prediction frame (batch,height, width, channel) + #pix_mean = np.mean(input_gen_diff, axis = 0) + #pix_std = np.std(input_gen_diff, axis=0) + im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu') + if frame == 0: + fig.colorbar(im) + ttl = plt.text(1.5, 2, "Frame_" + str(frame+1)) + ims.append([im, ttl]) + ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4")) + + gen_images_all = np.array(gen_images_all) + input_images_all = np.array(input_images_all) + # mse_model = np.mean((input_images_all[:, 
1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2) # look at all timesteps except the first + # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2) + # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 ) + + mse_model = np.mean((input_images_all[:, :10,:,:,0] - gen_images_all[:, :10,:,:,0])**2) # look at all timesteps except the first + mse_model_last = np.mean((input_images_all[:,10,:,:,0] - gen_images_all[:, 10,:,:,0])**2) + mse_prev = np.mean((input_images_all[:, :9, :, :, 0] - input_images_all[:, 1:10, :, :, 0])**2 ) + + def psnr(img1, img2): + mse = np.mean((img1 - img2) ** 2) + if mse == 0: return 100 + PIXEL_MAX = 1 + return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) + + psnr_model = psnr(input_images_all[:, :10, :, :, 0], gen_images_all[:, :10, :, :, 0]) + psnr_model_last = psnr(input_images_all[:, 10, :, :, 0], gen_images_all[:,10, :, :, 0]) + psnr_prev = psnr(input_images_all[:, :9, :, :, 0], input_images_all[:, 1:10, :, :, 0]) + f = open(os.path.join(args.output_png_dir,'prediction_scores_4prediction.txt'), 'w') + f.write("Model MSE: %f\n" % mse_model) + f.write("Model MSE from only last prediction in sequence: %f\n" % mse_model_last) + f.write("Previous Frame MSE: %f\n" % mse_prev) + f.write("Model PSNR: %f\n" % psnr_model) + f.write("Model PSNR from only last prediction in sequence: %f\n" % psnr_model_last) + f.write("Previous frame PSNR: %f\n" % psnr_prev) + f.write("Shape of X_test: " + str(input_images_all.shape)) + f.write("") + f.write("Shape of X_hat: " + str(gen_images_all.shape)) + f.close() + + seed(1) + s = random.sample(range(len(gen_images_all)), 100) + print("******KDP******") + #kernel density plot for checking the model collapse + fig = plt.figure() + kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images") + kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True") + kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability') + plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400) + plt.clf() + + #line plot for evaluating the prediction and groud-truth + for i in [0,3,6,9,12,15,18]: + fig = plt.figure() + plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3) + #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3) + plt.xlabel("Prediction") + plt.ylabel("Real values") + plt.title("Frame_{}".format(i+1)) + plt.plot([250,300], [250,300],color="black") + plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i)))) + plt.clf() + + + mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence) + x = [str(i+1) for i in list(range(19))] + fig,axis = plt.subplots() + mean_f = np.mean(mse_model_by_frames, axis = 0) + median = np.median(mse_model_by_frames, axis=0) + q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0) + q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0) + d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0) + d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0) + plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range") + plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range") + plt.plot(x, median, color="grey", linewidth=0.6, label="Median") + plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, 
label="Mean") + plt.title(f'MSE percentile') + plt.xlabel("Frames") + plt.legend(loc=2, fontsize=8) + plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png")) + +if __name__ == '__main__': + main() diff --git a/scripts/generate_orig.py b/scripts/generate_orig.py new file mode 100644 index 0000000000000000000000000000000000000000..0a9e53037d247002743593a9bed36d7f1b9b4249 --- /dev/null +++ b/scripts/generate_orig.py @@ -0,0 +1,193 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import errno +import json +import os +import random + +import cv2 +import numpy as np +import tensorflow as tf + +from video_prediction import datasets, models +from video_prediction.utils.ffmpeg_gif import save_gif + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " + "train, val, test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified") + parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified") + parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified") + parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is " + "results_gif_dir/model_fname") + parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is " + "results_png_dir/model_fname") + parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. 
checkpoint_dir/model-200000)") + + parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.') + + parser.add_argument("--dataset", type=str, help="dataset class name") + parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters") + parser.add_argument("--model", type=str, help="model class name") + parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") + + parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch") + parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)") + parser.add_argument("--num_epochs", type=int, default=1) + + parser.add_argument("--num_stochastic_samples", type=int, default=5) + parser.add_argument("--gif_length", type=int, help="default is sequence_length") + parser.add_argument("--fps", type=int, default=4) + + parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") + parser.add_argument("--seed", type=int, default=7) + + args = parser.parse_args() + + if args.seed is not None: + tf.set_random_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + args.results_gif_dir = args.results_gif_dir or args.results_dir + args.results_png_dir = args.results_png_dir or args.results_dir + dataset_hparams_dict = {} + model_hparams_dict = {} + if args.checkpoint: + checkpoint_dir = os.path.normpath(args.checkpoint) + if not os.path.isdir(args.checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % args.checkpoint) + options = json.loads(f.read()) + args.dataset = args.dataset or options['dataset'] + args.model = args.model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: + dataset_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("dataset_hparams.json was not loaded because it does not exist") + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1]) + args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1]) + else: + if not args.dataset: + raise ValueError('dataset is required when checkpoint is not specified') + if not args.model: + raise ValueError('model is required when checkpoint is not specified') + args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model) + args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model) + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + VideoDataset = datasets.get_dataset_class(args.dataset) + dataset = VideoDataset( + args.input_dir, + mode=args.mode, + num_epochs=args.num_epochs, + seed=args.seed, + 
hparams_dict=dataset_hparams_dict, + hparams=args.dataset_hparams) + + VideoPredictionModel = models.get_model_class(args.model) + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': dataset.hparams.context_frames, + 'sequence_length': dataset.hparams.sequence_length, + 'repeat': dataset.hparams.time_shift, + }) + model = VideoPredictionModel( + mode=args.mode, + hparams_dict=hparams_dict, + hparams=args.model_hparams) + + sequence_length = model.hparams.sequence_length + context_frames = model.hparams.context_frames + future_length = sequence_length - context_frames + + if args.num_samples: + if args.num_samples > dataset.num_examples_per_epoch(): + raise ValueError('num_samples cannot be larger than the dataset') + num_examples_per_epoch = args.num_samples + else: + num_examples_per_epoch = dataset.num_examples_per_epoch() + if num_examples_per_epoch % args.batch_size != 0: + raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch) + + inputs = dataset.make_batch(args.batch_size) + input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} + with tf.variable_scope(''): + model.build_graph(input_phs) + + for output_dir in (args.output_gif_dir, args.output_png_dir): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + with open(os.path.join(output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys=True, indent=4)) + with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4)) + with open(os.path.join(output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) + sess = tf.Session(config=config) + sess.graph.as_default() + model.restore(sess, args.checkpoint) + + sample_ind = 0 + while True: + if args.num_samples and sample_ind >= args.num_samples: + break + try: + input_results = sess.run(inputs) + except tf.errors.OutOfRangeError: + break + print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size)) + + feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} + for stochastic_sample_ind in range(args.num_stochastic_samples): + gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict) + # only keep the future frames + gen_images = gen_images[:, -future_length:] + for i, gen_images_ in enumerate(gen_images): + #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8) + #gen_images_ = (gen_images_ * 255.0).astype(np.uint8) + context_images_ = (input_results['images'][i]) + gen_images_ = (gen_images_) + + gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_) + if args.gif_length: + context_and_gen_images = context_and_gen_images[:args.gif_length] + save_gif(os.path.join(args.output_gif_dir, gen_images_fname), + context_and_gen_images, fps=args.fps) + gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1))) + for t, gen_image in enumerate(gen_images_): + gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) + if gen_image.shape[-1] == 1: + gen_image = np.tile(gen_image, (1, 1, 3)) + 
else: + gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR) + cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image) + + sample_ind += args.batch_size + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..331559f6287a4f24c1c19ee9f7f4b03309a22abf --- /dev/null +++ b/scripts/generate_transfer_learning_finetune.py @@ -0,0 +1,731 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import errno +import json +import os +import math +import random +import cv2 +import numpy as np +import tensorflow as tf +import pickle +import hickle +from random import seed +import random +import json +import numpy as np +#from six.moves import cPickle +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import matplotlib.animation as animation +import pandas as pd +import re +from video_prediction import datasets, models +from matplotlib.colors import LinearSegmentedColormap +#from matplotlib.ticker import MaxNLocator +#from video_prediction.utils.ffmpeg_gif import save_gif +from skimage.metrics import structural_similarity as ssim +import datetime +# Scarlet 2020/05/28: access to statistical values in json file +from os import path +import sys +sys.path.append(path.abspath('../video_prediction/datasets/')) +from era5_dataset_v2 import Norm_data +from os.path import dirname + +with open("../geo_info.json","r") as json_file: + geo = json.load(json_file) + lat = [round(i,2) for i in geo["lat"]] + lon = [round(i,2) for i in geo["lon"]] + + +def psnr(img1, img2): + mse = np.mean((img1 - img2) ** 2) + if mse == 0: return 100 + PIXEL_MAX = 1 + return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type = str, required = True, + help = "either a directory containing subdirectories " + "train, val, test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--results_dir", type = str, default = 'results', + help = "ignored if output_gif_dir is specified") + parser.add_argument("--results_gif_dir", type = str, + help = "default is results_dir. ignored if output_gif_dir is specified") + parser.add_argument("--results_png_dir", type = str, + help = "default is results_dir. ignored if output_png_dir is specified") + parser.add_argument("--output_gif_dir", help = "output directory where samples are saved as gifs. default is " + "results_gif_dir/model_fname") + parser.add_argument("--output_png_dir", help = "output directory where samples are saved as pngs. default is " + "results_png_dir/model_fname") + parser.add_argument("--checkpoint", + help = "directory with checkpoint or checkpoint name (e.g. 
checkpoint_dir/model-200000)") + + parser.add_argument("--mode", type = str, choices = ['train','val', 'test'], default = 'val', + help = 'mode for dataset, val or test.') + + parser.add_argument("--dataset", type = str, help = "dataset class name") + parser.add_argument("--dataset_hparams", type = str, + help = "a string of comma separated list of dataset hyperparameters") + parser.add_argument("--model", type = str, help = "model class name") + parser.add_argument("--model_hparams", type = str, + help = "a string of comma separated list of model hyperparameters") + + parser.add_argument("--batch_size", type = int, default = 8, help = "number of samples in batch") + parser.add_argument("--num_samples", type = int, help = "number of samples in total (all of them by default)") + parser.add_argument("--num_epochs", type = int, default = 1) + + parser.add_argument("--num_stochastic_samples", type = int, default = 1) + parser.add_argument("--gif_length", type = int, help = "default is sequence_length") + parser.add_argument("--fps", type = int, default = 4) + + parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use") + parser.add_argument("--seed", type = int, default = 7) + + args = parser.parse_args() + + if args.seed is not None: + tf.set_random_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + #Bing:20200518 + input_dir = args.input_dir + temporal_dir = os.path.split(input_dir)[0] + "/hickle/splits/" + print ("temporal_dir:",temporal_dir) + args.results_gif_dir = args.results_gif_dir or args.results_dir + args.results_png_dir = args.results_png_dir or args.results_dir + dataset_hparams_dict = {} + model_hparams_dict = {} + if args.checkpoint: + checkpoint_dir = os.path.normpath(args.checkpoint) + if not os.path.isdir(args.checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % args.checkpoint) + options = json.loads(f.read()) + args.dataset = args.dataset or options['dataset'] + args.model = args.model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: + dataset_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("dataset_hparams.json was not loaded because it does not exist") + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, + os.path.split(checkpoint_dir)[1]) + args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, + os.path.split(checkpoint_dir)[1]) + else: + if not args.dataset: + raise ValueError('dataset is required when checkpoint is not specified') + if not args.model: + raise ValueError('model is required when checkpoint is not specified') + args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model) + args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model) + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + 
print('------------------------------------- End --------------------------------------') + + VideoDataset = datasets.get_dataset_class(args.dataset) + dataset = VideoDataset( + args.input_dir, + mode = args.mode, + num_epochs = args.num_epochs, + seed = args.seed, + hparams_dict = dataset_hparams_dict, + hparams = args.dataset_hparams) + + VideoPredictionModel = models.get_model_class(args.model) + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': dataset.hparams.context_frames, + 'sequence_length': dataset.hparams.sequence_length, + 'repeat': dataset.hparams.time_shift, + }) + model = VideoPredictionModel( + mode = args.mode, + hparams_dict = hparams_dict, + hparams = args.model_hparams) + + sequence_length = model.hparams.sequence_length + context_frames = model.hparams.context_frames + future_length = sequence_length - context_frames + + if args.num_samples: + if args.num_samples > dataset.num_examples_per_epoch(): + raise ValueError('num_samples cannot be larger than the dataset') + num_examples_per_epoch = args.num_samples + else: + num_examples_per_epoch = dataset.num_examples_per_epoch() + if num_examples_per_epoch % args.batch_size != 0: + raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch) + + inputs = dataset.make_batch(args.batch_size) + input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} + with tf.variable_scope(''): + model.build_graph(input_phs) + + for output_dir in (args.output_gif_dir, args.output_png_dir): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + with open(os.path.join(output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys = True, indent = 4)) + with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4)) + with open(os.path.join(output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4)) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_mem_frac) + config = tf.ConfigProto(gpu_options = gpu_options, allow_soft_placement = True) + sess = tf.Session(config = config) + sess.graph.as_default() + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + model.restore(sess, args.checkpoint) + + sample_ind = 0 + gen_images_all = [] + #Bing:20200410 + persistent_images_all = [] + input_images_all = [] + #Bing:20201417 + print ("temporal_dir:",temporal_dir) + test_temporal_pkl = pickle.load(open(os.path.join(temporal_dir,"T_test.pkl"),"rb")) + #val_temporal_pkl = pickle.load(open("/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/hickle/splits/T_val.pkl","rb")) + print("test temporal_pkl file looks like folowing", test_temporal_pkl) + + #X_val = hickle.load("/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/hickle/splits/X_val.hkl") + X_test = hickle.load(os.path.join(temporal_dir,"X_test.hkl")) + is_first=True + + #+++Scarlet:20200528 + norm_cls = Norm_data('T2') + norm = 'minmax' + with open(os.path.join(dirname(input_dir),"hickle/splits/statistics.json")) as js_file: + norm_cls.check_and_set_norm(json.load(js_file),norm) + #---Scarlet:20200528 + while True: + print("Sample id", sample_ind) + if sample_ind <= 24: + pass + elif sample_ind >= len(X_test): + break + else: + 
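# collect one prediction array per stochastic sample for this batch; the list is reset for every new batch +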
gen_images_stochastic = [] + if args.num_samples and sample_ind >= args.num_samples: + break + try: + input_results = sess.run(inputs) + input_images = input_results["images"] + + + except tf.errors.OutOfRangeError: + break + + feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} + for stochastic_sample_ind in range(args.num_stochastic_samples): + input_images_all.extend(input_images) + with open(os.path.join(args.output_png_dir, "input_images_all.pkl"), "wb") as input_files: + pickle.dump(list(input_images_all), input_files) + + gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict) + gen_images_stochastic.append(gen_images) + #print("Stochastic_sample,", stochastic_sample_ind) + for i in range(args.batch_size): + #bing:20200417 + t_stampe = test_temporal_pkl[sample_ind+i] + print("timestamp:",type(t_stampe)) + persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1) + print ("persistent ts",persistent_ts) + persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts)) + persistent_X = X_test[persistent_idx:persistent_idx+context_frames + future_length] + print("persistent index in test set:", persistent_idx) + print("persistent_X.shape",persistent_X.shape) + persistent_images_all.append(persistent_X) + + + #print("batch", i) + #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] + + + cmap_name = 'my_list' + if sample_ind < 100: + #name = '_Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str( + # sample_ind) + " + Sample_" + str(i) + name = '_Stochastic_id_' + str(stochastic_sample_ind) + "_Time_"+ t_stampe[0].strftime("%Y%m%d-%H%M%S") + print ("name",name) + gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :])) + #gen_images_ = gen_images[i, :] + input_images_ = input_images[i, :] + #Bing:20200417 + #persistent_images = ? + #+++Scarlet:20200528 + #print('Scarlet1') + input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm) + persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm) + #---Scarlet:20200528 + gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in + range(sequence_length)] # return the list with 10 (sequence) mse + persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in + range(sequence_length)] # return the list with 10 (sequence) mse + + fig = plt.figure(figsize=(18,6)) + gs = gridspec.GridSpec(1, 10) + gs.update(wspace = 0., hspace = 0.) + ts = list(range(10,20)) #[10,11,12,..] 
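+ # frames 10..19 are plotted as the forecast horizon (context_frames is assumed to be 10 here); the tick labels below are derived from the lat/lon grid read from geo_info.json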
+ xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] + ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] + + for t in ts: + + #if t==0 : ax1=plt.subplot(gs[t]) + ax1 = plt.subplot(gs[ts.index(t)]) + #+++Scarlet:20200528 + #print('Scarlet2') + input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm) + #---Scarlet:20200528 + plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300) + ax1.title.set_text("t = " + str(t+1-10)) + plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + if t == 0: + plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels) + plt.ylabel("Ground Truth", fontsize=10) + plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg")) + plt.clf() + + fig = plt.figure(figsize=(12,6)) + gs = gridspec.GridSpec(1, 10) + gs.update(wspace = 0., hspace = 0.) + + for t in ts: + #if t==0 : ax1=plt.subplot(gs[t]) + ax1 = plt.subplot(gs[ts.index(t)]) + #+++Scarlet:20200528 + #print('Scarlet3') + gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm) + #---Scarlet:20200528 + plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300) + ax1.title.set_text("t = " + str(t+1-10)) + plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + + plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg")) + plt.clf() + + + fig = plt.figure(figsize=(12,6)) + gs = gridspec.GridSpec(1, 10) + gs.update(wspace = 0., hspace = 0.) + for t in ts: + #if t==0 : ax1=plt.subplot(gs[t]) + ax1 = plt.subplot(gs[ts.index(t)]) + #persistent_image = persistent_X[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + plt.imshow(persistent_X[t, :, :, 0], cmap = 'jet', vmin=270, vmax=300) + ax1.title.set_text("t = " + str(t+1-10)) + plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + + plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg")) + plt.clf() + + + with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files: + pickle.dump(list(persistent_images_all), input_files) + print ("Save persistent all") + if is_first: + gen_images_all = gen_images_stochastic + is_first = False + else: + gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) + + if args.num_stochastic_samples == 1: + with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files: + pickle.dump(list(gen_images_all[0]), gen_files) + print ("Save generate all") + else: + with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: + pickle.dump(list(gen_images_stochastic), gen_files) + with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: + pickle.dump(list(gen_images_all), gen_files) +## +## +## # fig = plt.figure() +## # gs = gridspec.GridSpec(4,6) +## # gs.update(wspace = 0.7,hspace=0.8) +## # ax1 = plt.subplot(gs[0:2,0:3]) +## # ax2 = plt.subplot(gs[0:2,3:],sharey=ax1) +## # ax3 = plt.subplot(gs[2:4,0:3]) +## # ax4 = plt.subplot(gs[2:4,3:]) +## # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] +## # ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] +## # plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables 
,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels) +## # ax1.title.set_text("(a) Ground Truth") +## # ax2.title.set_text("(b) SAVP") +## # ax3.title.set_text("(c) Diff.") +## # ax4.title.set_text("(d) MSE") +## # +## # ax1.xaxis.set_tick_params(labelsize=7) +## # ax1.yaxis.set_tick_params(labelsize = 7) +## # ax2.xaxis.set_tick_params(labelsize=7) +## # ax2.yaxis.set_tick_params(labelsize = 7) +## # ax3.xaxis.set_tick_params(labelsize=7) +## # ax3.yaxis.set_tick_params(labelsize = 7) +## # +## # init_images = np.zeros((input_images_.shape[1], input_images_.shape[2])) +## # print("inti images shape", init_images.shape) +## # xdata, ydata = [], [] +## # #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1) +## # #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1) +## # plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) +## # plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) +## # #x = np.linspace(0, 64, 64) +## # #y = np.linspace(0, 64, 64) +## # #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) +## # #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) +## # fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7) +## # fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7) +## # +## # cm = LinearSegmentedColormap.from_list( +## # cmap_name, "bwr", N = 5) +## # +## # plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r', +## # #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm) # cmap = 'PuBu_r', +## # plot4, = ax4.plot([], [], color = "r") +## # ax4.set_xlim(0, future_length-1) +## # ax4.set_ylim(0, 20) +## # #ax4.set_ylim(0, 0.5) +## # ax4.set_xlabel("Frames", fontsize=10) +## # #ax4.set_ylabel("MSE", fontsize=10) +## # ax4.xaxis.set_tick_params(labelsize=7) +## # ax4.yaxis.set_tick_params(labelsize=7) +## # +## # +## # plots = [plot1, plot2, plot3, plot4] +## # +## # #fig.colorbar(plots[1], ax = [ax1, ax2]) +## # +## # fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7) +## # #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7) +## # #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7) +## # +## # def animation_sample(t): +## # input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922 +## # gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922 +## # diff_image = input_gen_diff[t,:,:] +## # # p = sns.lineplot(x=x,y=data,color="b") +## # # p.tick_params(labelsize=17) +## # # plt.setp(p.lines, linewidth=6) +## # plots[0].set_data(input_image) +## # plots[1].set_data(gen_image) +## # #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) +## # #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) +## # plots[2].set_data(diff_image) +## # +## # if t >= future_length: +## # #data = gen_mse_avg_[:t + 1] +## # # x = list(range(len(gen_mse_avg_)))[:t+1] +## # xdata.append(t-future_length) +## # print("xdata", xdata) +## # ydata.append(gen_mse_avg_[t]) +## # print("ydata", ydata) +## # plots[3].set_data(xdata, ydata) +## # fig.suptitle("Predicted Frame " + str(t-future_length)) +## # else: +## # #plots[3].set_data(xdata, ydata) +## # fig.suptitle("Context Frame " + str(t)) +## # return plots +## # +## # ani = animation.FuncAnimation(fig, animation_sample, 
frames=len(gen_mse_avg_), interval = 1000, +## # repeat_delay=2000) +## # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4")) +## +#### else: +#### pass +## + sample_ind += args.batch_size + + + # # for i, gen_mse_avg_ in enumerate(gen_mse_avg): + # # ims = [] + # # fig = plt.figure() + # # plt.xlim(0,len(gen_mse_avg_)) + # # plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg)) + # # plt.xlabel("Frames") + # # plt.ylabel("MSE_AVG") + # # #X = list(range(len(gen_mse_avg_))) + # # #for t, gen_mse_avg_ in enumerate(gen_mse_avg): + # # def animate_metric(j): + # # data = gen_mse_avg_[:(j+1)] + # # x = list(range(len(gen_mse_avg_)))[:(j+1)] + # # p = sns.lineplot(x=x,y=data,color="b") + # # p.tick_params(labelsize=17) + # # plt.setp(p.lines, linewidth=6) + # # ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000) + # # ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif")) + # # + # # + # # for i, input_images_ in enumerate(input_images): + # # #context_images_ = (input_results['images'][i]) + # # #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + # # ims = [] + # # fig = plt.figure() + # # for t, input_image in enumerate(input_images_): + # # im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none') + # # ttl = plt.text(1.5, 2,"Frame_" + str(t)) + # # ims.append([im,ttl]) + # # ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000) + # # ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif")) + # # #plt.show() + # # + # # for i,gen_images_ in enumerate(gen_images): + # # ims = [] + # # fig = plt.figure() + # # for t, gen_image in enumerate(gen_images_): + # # im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none') + # # ttl = plt.text(1.5, 2, "Frame_" + str(t)) + # # ims.append([im, ttl]) + # # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + # # ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif")) + # + # + # # for i, gen_images_ in enumerate(gen_images): + # # #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8) + # # #gen_images_ = (gen_images_ * 255.0).astype(np.uint8) + # # #bing + # # context_images_ = (input_results['images'][i]) + # # gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + # # context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_) + # # plt.figure(figsize = (10,2)) + # # gs = gridspec.GridSpec(2,10) + # # gs.update(wspace=0.,hspace=0.) + # # for t, gen_image in enumerate(gen_images_): + # # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1))) + # # gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) + # # plt.subplot(gs[t]) + # # plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none') # the last index sets the channel. 
0 = t2 + # # # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet) + # # plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False, + # # right = False, labelbottom = False, labelleft = False) + # # if t == 0: plt.ylabel('Actual', fontsize = 10) + # # + # # plt.subplot(gs[t + 10]) + # # plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none') + # # # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet) + # # plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False, + # # right = False, labelbottom = False, labelleft = False) + # # if t == 0: plt.ylabel('Predicted', fontsize = 10) + # # plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png') + # # plt.clf() + # + # # if args.gif_length: + # # context_and_gen_images = context_and_gen_images[:args.gif_length] + # # save_gif(os.path.join(args.output_gif_dir, gen_images_fname), + # # context_and_gen_images, fps=args.fps) + # # + # # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1))) + # # for t, gen_image in enumerate(gen_images_): + # # gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) + # # if gen_image.shape[-1] == 1: + # # gen_image = np.tile(gen_image, (1, 1, 3)) + # # else: + # # gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR) + # # cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image) + + + + with open(os.path.join(args.output_png_dir, "input_images_all.pkl"),"rb") as input_files: + input_images_all = pickle.load(input_files) + + with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"),"rb") as gen_files: + gen_images_all = pickle.load(gen_files) + + with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files: + persistent_images_all = pickle.load(gen_files) + + #+++Scarlet:20200528 + #print('Scarlet4') + input_images_all = np.array(input_images_all) + input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm) + #---Scarlet:20200528 + persistent_images_all = np.array(persistent_images_all) + if len(np.array(gen_images_all).shape) == 6: + for i in range(len(gen_images_all)): + #+++Scarlet:20200528 + #print('Scarlet5') + gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:] + gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm) + #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + #---Scarlet:20200528 + mse_all = [] + psnr_all = [] + ssim_all = [] + f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w') + for i in range(future_length): + mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all_stochastic[:, i + 9, :, :, + 0]) ** 2) # look at all timesteps except the first + psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0]) + ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0], + data_range = max(gen_images_all_stochastic[:, i + 9, :, :, 0].flatten()) - min( + input_images_all[:, i + 10, :, :, 0].flatten())) + mse_all.extend([mse_model]) + psnr_all.extend([psnr_model]) + ssim_all.extend([ssim_model]) + results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all} + f.write("##########Predicted Frame {}\n".format(str(i + 1))) + f.write("Model MSE: 
%f\n" % mse_model) + # f.write("Previous Frame MSE: %f\n" % mse_prev) + f.write("Model PSNR: %f\n" % psnr_model) + f.write("Model SSIM: %f\n" % ssim_model) + + + pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb")) + # f.write("Previous frame PSNR: %f\n" % psnr_prev) + f.write("Shape of X_test: " + str(input_images_all.shape)) + f.write("") + f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape)) + + else: + #+++Scarlet:20200528 + #print('Scarlet6') + gen_images_all = np.array(gen_images_all) + gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm) + #---Scarlet:20200528 + + # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2) # look at all timesteps except the first + # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2) + # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 ) + mse_all = [] + psnr_all = [] + ssim_all = [] + persistent_mse_all = [] + persistent_psnr_all = [] + persistent_ssim_all = [] + f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w') + for i in range(future_length): + mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :, + 0]) ** 2) # look at all timesteps except the first + persistent_mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - persistent_images_all[:, i + 9, :, :, + 0]) ** 2) # look at all timesteps except the first + + psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0]) + ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0], + data_range = max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min( + input_images_all[:, i + 10, :, :, 0].flatten())) + persistent_psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0]) + persistent_ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0], + data_range = max(gen_images_all[:1268, i + 9, :, :, 0].flatten()) - min(input_images_all[:1268, i + 10, :, :, 0].flatten())) + mse_all.extend([mse_model]) + psnr_all.extend([psnr_model]) + ssim_all.extend([ssim_model]) + persistent_mse_all.extend([persistent_mse_model]) + persistent_psnr_all.extend([persistent_psnr_model]) + persistent_ssim_all.extend([persistent_ssim_model]) + results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all} + + persistent_results = {"mse": persistent_mse_all, "psnr": persistent_psnr_all, "ssim": persistent_ssim_all} + f.write("##########Predicted Frame {}\n".format(str(i + 1))) + f.write("Model MSE: %f\n" % mse_model) + # f.write("Previous Frame MSE: %f\n" % mse_prev) + f.write("Model PSNR: %f\n" % psnr_model) + f.write("Model SSIM: %f\n" % ssim_model) + + pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb")) + pickle.dump(persistent_results, open(os.path.join(args.output_png_dir, "persistent_results.pkl"), "wb")) + # f.write("Previous frame PSNR: %f\n" % psnr_prev) + f.write("Shape of X_test: " + str(input_images_all.shape)) + f.write("") + f.write("Shape of X_hat: " + str(gen_images_all.shape)) + + + + #psnr_model = psnr(input_images_all[:, :10, :, :, 0], gen_images_all[:, :10, :, :, 0]) + #psnr_model_last = psnr(input_images_all[:, 10, :, :, 0], gen_images_all[:,10, :, :, 0]) + #psnr_prev = psnr(input_images_all[:, :, :, :, 0], input_images_all[:, 
1:10, :, :, 0]) + + # ims = [] + # fig = plt.figure() + # for frame in range(20): + # input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel) + # #pix_mean = np.mean(input_gen_diff, axis = 0) + # #pix_std = np.std(input_gen_diff, axis=0) + # im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu') + # if frame == 0: + # fig.colorbar(im) + # ttl = plt.text(1.5, 2, "Frame_" + str(frame +1)) + # ims.append([im, ttl]) + # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000) + # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4")) + # plt.close("all") + + # ims = [] + # fig = plt.figure() + # for frame in range(19): + # pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0] # Get the first prediction frame (batch,height, width, channel) + # #pix_mean = np.mean(input_gen_diff, axis = 0) + # #pix_std = np.std(input_gen_diff, axis=0) + # im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu') + # if frame == 0: + # fig.colorbar(im) + # ttl = plt.text(1.5, 2, "Frame_" + str(frame+1)) + # ims.append([im, ttl]) + # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4")) + + # seed(1) + # s = random.sample(range(len(gen_images_all)), 100) + # print("******KDP******") + # #kernel density plot for checking the model collapse + # fig = plt.figure() + # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images") + # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True") + # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability') + # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400) + # plt.clf() + + #line plot for evaluating the prediction and groud-truth + # for i in [0,3,6,9,12,15,18]: + # fig = plt.figure() + # plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3) + # #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3) + # plt.xlabel("Prediction") + # plt.ylabel("Real values") + # plt.title("Frame_{}".format(i+1)) + # plt.plot([250,300], [250,300],color="black") + # plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i)))) + # plt.clf() + # + # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence) + # x = [str(i+1) for i in list(range(19))] + # fig,axis = plt.subplots() + # mean_f = np.mean(mse_model_by_frames, axis = 0) + # median = np.median(mse_model_by_frames, axis=0) + # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0) + # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0) + # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0) + # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0) + # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range") + # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range") + # plt.plot(x, median, color="grey", linewidth=0.6, label="Median") + # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean") + # plt.title(f'MSE percentile') + # plt.xlabel("Frames") + # plt.legend(loc=2, fontsize=8) + # 
plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png")) + +if __name__ == '__main__': + main() diff --git a/scripts/plot_results.py b/scripts/plot_results.py new file mode 100644 index 0000000000000000000000000000000000000000..ddf49066db5c4276bfb23167bbcde35295c270d4 --- /dev/null +++ b/scripts/plot_results.py @@ -0,0 +1,254 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import glob +import os + +import numpy as np + + +def load_metrics(prefix_fname): + import csv + with open('%s.csv' % prefix_fname, newline='') as csvfile: + reader = csv.reader(csvfile, delimiter='\t', quotechar='|') + rows = list(reader) + # skip header (first row), indices (first column), and means (last column) + metrics = np.array(rows)[1:, 1:-1].astype(np.float32) + return metrics + + +def plot_metric(metric, start_x=0, color=None, label=None, zorder=None): + import matplotlib.pyplot as plt + metric_mean = np.mean(metric, axis=0) + metric_se = np.std(metric, axis=0) / np.sqrt(len(metric)) + kwargs = {} + if color: + kwargs['color'] = color + if zorder: + kwargs['zorder'] = zorder + plt.errorbar(np.arange(len(metric_mean)) + start_x, + metric_mean, yerr=metric_se, linewidth=2, + label=label, **kwargs) + # metric_std = np.std(metric, axis=0) + # plt.plot(np.arange(len(metric_mean)) + start_x, metric_mean, + # linewidth=2, color=color, label=label) + # plt.fill_between(np.arange(len(metric_mean)) + start_x, + # metric_mean - metric_std, metric_mean + metric_std, + # color=color, alpha=0.5) + + +def get_color(method_name): + import matplotlib.pyplot as plt + color_mapping = { + 'ours_vae_gan': plt.cm.Vega20(0), + 'ours_gan': plt.cm.Vega20(2), + 'ours_vae': plt.cm.Vega20(4), + 'ours_vae_l1': plt.cm.Vega20(4), + 'ours_vae_l2': plt.cm.Vega20(14), + 'ours_deterministic': plt.cm.Vega20(6), + 'ours_deterministic_l1': plt.cm.Vega20(6), + 'ours_deterministic_l2': plt.cm.Vega20(10), + 'sna_l1': plt.cm.Vega20(8), + 'sna_l2': plt.cm.Vega20(9), + 'sv2p_time_variant': plt.cm.Vega20(16), + 'sv2p_time_invariant': plt.cm.Vega20(16), + 'svg_lp': plt.cm.Vega20(18), + 'svg_fp': plt.cm.Vega20(18), + 'svg_fp_resized_data_loader': plt.cm.Vega20(18), + 'mathieu': plt.cm.Vega20(8), + 'mcnet': plt.cm.Vega20(8), + 'repeat': 'k', + } + if method_name in color_mapping: + color = color_mapping[method_name] + else: + color = None + for k, v in color_mapping.items(): + if method_name.startswith(k): + color = v + break + return color + + +def get_method_name(method_name): + method_name_mapping = { + 'ours_vae_gan': 'Ours, SAVP', + 'ours_gan': 'Ours, GAN-only', + 'ours_vae': 'Ours, VAE-only', + 'ours_vae_l1': 'Ours, VAE-only, $\mathcal{L}_1$', + 'ours_vae_l2': 'Ours, VAE-only, $\mathcal{L}_2$', + 'ours_deterministic': 'Ours, deterministic', + 'ours_deterministic_l1': 'Ours, deterministic, $\mathcal{L}_1$', + 'ours_deterministic_l2': 'Ours, deterministic, $\mathcal{L}_2$', + 'sna_l1': 'SNA, $\mathcal{L}_1$ (Ebert et al.)', + 'sna_l2': 'SNA, $\mathcal{L}_2$ (Ebert et al.)', + 'sv2p_time_variant': 'SV2P time-variant (Babaeizadeh et al.)', + 'sv2p_time_invariant': 'SV2P time-invariant (Babaeizadeh et al.)', + 'mathieu': 'Mathieu et al.', + 'mcnet': 'MCnet (Villegas et al.)', + 'repeat': 'Copy last frame', + } + return method_name_mapping.get(method_name, method_name) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("results_dir", type=str) + parser.add_argument("--dataset_name", type=str) + parser.add_argument("--method_dirs", 
type=str, nargs='+', help='directories in results_dir (all of them by default)') + parser.add_argument("--method_names", type=str, nargs='+', help='method names for the header') + parser.add_argument("--web_dir", type=str, help='default is results_dir/web') + parser.add_argument("--plot_fname", type=str, default='metrics.pdf') + parser.add_argument('--usetex', '--use_tex', action='store_true') + parser.add_argument('--save', action='store_true') + parser.add_argument('--mode', choices=['paper', 'rebuttal'], default='paper') + parser.add_argument("--plot_metric_names", type=str, nargs='+') + args = parser.parse_args() + + if args.save: + import matplotlib + matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab! + import matplotlib.pyplot as plt + + if args.usetex: + plt.rc('text', usetex=True) + plt.rc('text.latex', preview=True) + plt.rc('font', family='serif') + + if args.web_dir is None: + args.web_dir = os.path.join(args.results_dir, 'web') + + if args.method_dirs is None: + unsorted_method_dirs = os.listdir(args.results_dir) + # exclude web_dir and all directories that starts with web + if args.web_dir in unsorted_method_dirs: + unsorted_method_dirs.remove(args.web_dir) + unsorted_method_dirs = [method_dir for method_dir in unsorted_method_dirs if not os.path.basename(method_dir).startswith('web')] + # put ground_truth and repeat in the front (if any) + method_dirs = [] + for first_method_dir in ['ground_truth', 'repeat']: + if first_method_dir in unsorted_method_dirs: + unsorted_method_dirs.remove(first_method_dir) + method_dirs.append(first_method_dir) + method_dirs.extend(sorted(unsorted_method_dirs)) + else: + method_dirs = list(args.method_dirs) + if args.method_names is None: + method_names = [get_method_name(method_dir) for method_dir in method_dirs] + else: + method_names = list(args.method_names) + if args.usetex: + method_names = [method_name.replace('kl_weight', r'$\lambda_{\textsc{kl}}$') for method_name in method_names] + method_dirs = [os.path.join(args.results_dir, method_dir) for method_dir in method_dirs] + + # infer task and metric names from first method + metric_fnames = sorted(glob.glob('%s/*_max/metrics/*.csv' % glob.escape(method_dirs[0]))) + task_names = [] + metric_names = [] # all the metric names inferred from file names + for metric_fname in metric_fnames: + head, tail = os.path.split(metric_fname) + task_name = head.split('/')[-2] + metric_name, _ = os.path.splitext(tail) + task_names.append(task_name) + metric_names.append(metric_name) + + # save plots + dataset_name = args.dataset_name or os.path.split(os.path.normpath(args.results_dir))[1] + plots_dir = os.path.join(args.web_dir, 'plots') + if not os.path.exists(plots_dir): + os.makedirs(plots_dir) + + if dataset_name in ('bair', 'bair_action_free'): + context_frames = 2 + training_sequence_length = 12 + plot_metric_names = ('psnr', 'ssim_finn', 'vgg_csim') + elif dataset_name == 'kth': + context_frames = 10 + training_sequence_length = 20 + plot_metric_names = ('psnr', 'ssim_scikit', 'vgg_csim') + elif dataset_name == 'ucf101': + context_frames = 4 + training_sequence_length = 8 + plot_metric_names = ('psnr', 'ssim_mcnet', 'vgg_csim') + else: + raise NotImplementedError + plot_metric_names = args.plot_metric_names or plot_metric_names # metric names to plot + + if args.mode == 'paper': + fig = plt.figure(figsize=(4 * len(plot_metric_names), 5)) + elif args.mode == 'rebuttal': + fig = plt.figure(figsize=(4, 3 * len(plot_metric_names))) + else: + raise ValueError + i_task = 0 + for 
task_name, metric_name in zip(task_names, metric_names): + if not task_name.endswith('max'): + continue + if metric_name not in plot_metric_names: + continue + + if args.mode == 'paper': + plt.subplot(1, len(plot_metric_names), i_task + 1) + elif args.mode == 'rebuttal': + plt.subplot(len(plot_metric_names), 1, i_task + 1) + + for method_name, method_dir in zip(method_names, method_dirs): + metric_fname = os.path.join(method_dir, task_name, 'metrics', metric_name) + if not os.path.isfile('%s.csv' % metric_fname): + print('Skipping', metric_fname) + continue + metric = load_metrics(metric_fname) + plot_metric(metric, context_frames + 1, color=get_color(os.path.basename(method_dir)), label=method_name) + + plt.grid(axis='y') + plt.axvline(x=training_sequence_length, linewidth=1, color='k') + fontsize = 12 if args.mode == 'rebuttal' else 15 + legend_fontsize = 10 if args.mode == 'rebuttal' else 15 + labelsize = 10 + if args.mode == 'paper': + plt.xlabel('Time Step', fontsize=fontsize) + plt.ylabel({ + 'psnr': 'Average PSNR', + 'ssim': 'Average SSIM', + 'ssim_scikit': 'Average SSIM', + 'ssim_finn': 'Average SSIM', + 'ssim_mcnet': 'Average SSIM', + 'vgg_csim': 'Average VGG cosine similarity', + }[metric_name], fontsize=fontsize) + plt.xlim((context_frames + 1, metric.shape[1] + context_frames)) + plt.tick_params(labelsize=labelsize) + + if args.mode == 'paper': + if i_task == 1: + # plt.title({ + # 'bair': 'Action-conditioned BAIR Dataset', + # 'bair_action_free': 'Action-free BAIR Dataset', + # 'kth': 'KTH Dataset', + # }[dataset_name], fontsize=16) + if len(method_names) <= 4 and sum([len(method_name) for method_name in method_names]) < 90: + ncol = len(method_names) + else: + ncol = (len(method_names) + 1) // 2 + # ncol = 2 + plt.legend(bbox_to_anchor=(0.5, -0.12), loc='upper center', ncol=ncol, fontsize=legend_fontsize) + elif args.mode == 'rebuttal': + if i_task == 0: + # plt.legend(fontsize=legend_fontsize) + plt.legend(bbox_to_anchor=(0.4, -0.12), loc='upper center', fontsize=legend_fontsize) + plt.ylim(ymin=0.8) + plt.xlim((context_frames + 1, metric.shape[1] + context_frames)) + i_task += 1 + fig.tight_layout(rect=(0, 0.1, 1, 1)) + + if args.save: + plt.show(block=False) + print("Saving to", os.path.join(plots_dir, args.plot_fname)) + plt.savefig(os.path.join(plots_dir, args.plot_fname), bbox_inches='tight') + else: + plt.show() + + +if __name__ == '__main__': + main() diff --git a/scripts/plot_results_all.sh b/scripts/plot_results_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..11045ca4831cea4c01ef355dd2c333418d83daeb --- /dev/null +++ b/scripts/plot_results_all.sh @@ -0,0 +1,80 @@ +python scripts/plot_results.py results_test/bair_action_free --method_dirs \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_vae_l2 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sv2p_time_invariant \ + svg_lp \ + --save --use_tex --plot_fname metrics_all.pdf + +python scripts/plot_results.py results_test/bair --method_dirs \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_vae_l2 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sna_l1 \ + sna_l2 \ + sv2p_time_variant \ + --save --use_tex --plot_fname metrics_all.pdf + +python scripts/plot_results.py results_test/kth --method_dirs \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sv2p_time_variant \ + sv2p_time_invariant \ + svg_fp_resized_data_loader \ + --save --use_tex --plot_fname metrics_all.pdf + + +python scripts/plot_results.py 
results_test/bair_action_free --method_dirs \ + sv2p_time_invariant \ + svg_lp \ + ours_vae_gan \ + --save --use_tex --plot_fname metrics.pdf; \ +python scripts/plot_results.py results_test/bair_action_free --method_dirs \ + ours_deterministic \ + ours_vae \ + ours_gan \ + ours_vae_gan \ + --save --use_tex --plot_fname metrics_ablation.pdf; \ +python scripts/plot_results.py results_test/bair_action_free --method_dirs \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + ours_vae_l1 \ + ours_vae_l2 \ + --save --use_tex --plot_fname metrics_ablation_l1_l2.pdf; \ +python scripts/plot_results.py results_test/kth --method_dirs \ + sv2p_time_variant \ + svg_fp_resized_data_loader \ + ours_vae_gan \ + --save --use_tex --plot_fname metrics.pdf; \ +python scripts/plot_results.py results_test/kth --method_dirs \ + ours_deterministic \ + ours_vae \ + ours_gan \ + ours_vae_gan \ + --save --use_tex --plot_fname metrics_ablation.pdf; \ +python scripts/plot_results.py results_test/bair --method_dirs \ + sv2p_time_variant \ + ours_deterministic \ + ours_vae \ + ours_gan \ + ours_vae_gan \ + --save --use_tex --plot_fname metrics.pdf; \ +python scripts/plot_results.py results_test/bair -- + +method_dirs \ + sna_l1 \ + sna_l2 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + ours_vae_l1 \ + ours_vae_l2 \ + --save --use_tex --plot_fname metrics_ablation_l1_l2.pdf diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..79f317ab1b6f30ab7708331ae4671de25b18167b --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,367 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import errno +import json +import os +import random +import time + +import numpy as np +import tensorflow as tf + +from video_prediction import datasets, models + + +def add_tag_suffix(summary, tag_suffix): + summary_proto = tf.Summary() + summary_proto.ParseFromString(summary) + summary = summary_proto + + for value in summary.value: + tag_split = value.tag.split('/') + value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:]) + return summary.SerializeToString() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " + "train, val, test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir") + parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified") + parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. " + "default is logs_dir/model_fname, where model_fname consists of " + "information from model and model_hparams") + parser.add_argument("--output_dir_postfix", default="") + parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. 
checkpoint_dir/model-200000)") + parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.') + + parser.add_argument("--dataset", type=str, help="dataset class name") + parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters") + parser.add_argument("--dataset_hparams_dict", type=str, help="a json file of dataset hyperparameters") + parser.add_argument("--model", type=str, help="model class name") + parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") + parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters") + + parser.add_argument("--summary_freq", type=int, default=1000, help="save frequency of summaries (except for image and eval summaries) for train/validation set") + parser.add_argument("--image_summary_freq", type=int, default=5000, help="save frequency of image summaries for train/validation set") + parser.add_argument("--eval_summary_freq", type=int, default=25000, help="save frequency of eval summaries for train/validation set") + parser.add_argument("--accum_eval_summary_freq", type=int, default=100000, help="save frequency of accumulated eval summaries for validation set only") + parser.add_argument("--progress_freq", type=int, default=100, help="display progress every progress_freq steps") + parser.add_argument("--save_freq", type=int, default=5000, help="save frequence of model, 0 to disable") + + parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training") + parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") + parser.add_argument("--seed", type=int) + + args = parser.parse_args() + + if args.seed is not None: + tf.set_random_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + if args.output_dir is None: + list_depth = 0 + model_fname = '' + for t in ('model=%s,%s' % (args.model, args.model_hparams)): + if t == '[': + list_depth += 1 + if t == ']': + list_depth -= 1 + if list_depth and t == ',': + t = '..' + if t in '=,': + t = '.' 
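+            # The character loop above builds a filesystem-safe run name from the model
+            # spec: '=' and ',' become '.', commas nested inside '[...]' become '..',
+            # and the brackets themselves are dropped just below. Hedged example with a
+            # hypothetical hparams string (not taken from this repository):
+            #   'model=savp,batch_size=16'      -> 'model.savp.batch_size.16'
+            #   'model=savp,tv_weight=[1,0.1]'  -> 'model.savp.tv_weight.1..0.1'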
+ if t in '[]': + t = '' + model_fname += t + args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix + + if args.resume: + if args.checkpoint: + raise ValueError('resume and checkpoint cannot both be specified') + args.checkpoint = args.output_dir + + dataset_hparams_dict = {} + model_hparams_dict = {} + if args.dataset_hparams_dict: + with open(args.dataset_hparams_dict) as f: + dataset_hparams_dict.update(json.loads(f.read())) + if args.model_hparams_dict: + with open(args.model_hparams_dict) as f: + model_hparams_dict.update(json.loads(f.read())) + if args.checkpoint: + checkpoint_dir = os.path.normpath(args.checkpoint) + if not os.path.isdir(args.checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % args.checkpoint) + options = json.loads(f.read()) + args.dataset = args.dataset or options['dataset'] + args.model = args.model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: + dataset_hparams_dict.update(json.loads(f.read())) + except FileNotFoundError: + print("dataset_hparams.json was not loaded because it does not exist") + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict.update(json.loads(f.read())) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + VideoDataset = datasets.get_dataset_class(args.dataset) + train_dataset = VideoDataset( + args.input_dir, + mode='train', + hparams_dict=dataset_hparams_dict, + hparams=args.dataset_hparams) + val_dataset = VideoDataset( + args.val_input_dir or args.input_dir, + mode='val', + hparams_dict=dataset_hparams_dict, + hparams=args.dataset_hparams) + if val_dataset.hparams.long_sequence_length != val_dataset.hparams.sequence_length: + # the longer dataset is only used for the accum_eval_metrics + long_val_dataset = VideoDataset( + args.val_input_dir or args.input_dir, + mode='val', + hparams_dict=dataset_hparams_dict, + hparams=args.dataset_hparams) + long_val_dataset.set_sequence_length(val_dataset.hparams.long_sequence_length) + else: + long_val_dataset = None + + variable_scope = tf.get_variable_scope() + variable_scope.set_use_resource(True) + + VideoPredictionModel = models.get_model_class(args.model) + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': train_dataset.hparams.context_frames,#Bing: TODO what is context_frames? 
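+        # context_frames: the number of ground-truth frames passed to the model at the
+        # start of each sequence; sequence_length: the total number of frames per example
+        # (context + predicted). See BaseVideoDataset.get_default_hparams_dict().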
+ 'sequence_length': train_dataset.hparams.sequence_length,#Bing: TODO what is sequence_frames + 'repeat': train_dataset.hparams.time_shift, + }) + model = VideoPredictionModel( + hparams_dict=hparams_dict, + hparams=args.model_hparams, + aggregate_nccl=args.aggregate_nccl) + + batch_size = model.hparams.batch_size + train_tf_dataset = train_dataset.make_dataset(batch_size)#Bing: adopt the meteo data prepartion here + train_iterator = train_tf_dataset.make_one_shot_iterator()#Bing:for era5, the problem happen in sess.run(feches) should come from here + # The `Iterator.string_handle()` method returns a tensor that can be evaluated + # and used to feed the `handle` placeholder. + train_handle = train_iterator.string_handle() + val_tf_dataset = val_dataset.make_dataset(batch_size) + val_iterator = val_tf_dataset.make_one_shot_iterator() + val_handle = val_iterator.string_handle() + iterator = tf.data.Iterator.from_string_handle( + train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes) + inputs = iterator.get_next() + #Bing for debug + with tf.Session() as sess: + for i in range(2): + print(sess.run(tf.shape(inputs["images"]))) + + # inputs comes from the training dataset by default, unless train_handle is remapped to the val_handles + model.build_graph(inputs) + + if long_val_dataset is not None: + # separately build a model for the longer sequence. + # this is needed because the model doesn't support dynamic shapes. + long_hparams_dict = dict(hparams_dict) + long_hparams_dict['sequence_length'] = long_val_dataset.hparams.sequence_length + # use smaller batch size for longer model to prevenet running out of memory + long_hparams_dict['batch_size'] = model.hparams.batch_size // 2 + long_model = VideoPredictionModel( + mode="test", # to not build the losses and discriminators + hparams_dict=long_hparams_dict, + hparams=args.model_hparams, + aggregate_nccl=args.aggregate_nccl) + tf.get_variable_scope().reuse_variables() + long_model.build_graph(long_val_dataset.make_batch(batch_size)) + else: + long_model = None + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + with open(os.path.join(args.output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys=True, indent=4)) + with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4)) + with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) + + with tf.name_scope("parameter_count"): + # exclude trainable variables that are replicas (used in multi-gpu setting) + trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables) + parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables]) + + saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2) + + # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero + if (args.summary_freq != 0 or args.image_summary_freq != 0 or + args.eval_summary_freq != 0 or args.accum_eval_summary_freq != 0): + summary_writer = tf.summary.FileWriter(args.output_dir) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) + global_step = tf.train.get_or_create_global_step() + max_steps = model.hparams.max_steps + with tf.Session(config=config) 
as sess: + print("parameter_count =", sess.run(parameter_count)) + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + #coord = tf.train.Coordinator() + #threads = tf.train.start_queue_runners(sess = sess, coord = coord) + print("Init done: {sess.run(tf.local_variables_initializer())}%") + model.restore(sess, args.checkpoint) + print("Restore processed finished") + sess.run(model.post_init_ops) + print("Model run started") + val_handle_eval = sess.run(val_handle) + print("val handle done") + sess.graph.finalize() + print("graph inalize done") + start_step = sess.run(global_step) + print("global step done") + + def should(step, freq): + if freq is None: + return (step + 1) == (max_steps - start_step) + else: + return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step)) + + def should_eval(step, freq): + # never run eval summaries at the beginning since it's expensive, unless it's the last iteration + return should(step, freq) and (step >= 0 or (step + 1) == (max_steps - start_step)) + + # start at one step earlier to log everything without doing any training + # step is relative to the start_step + for step in range(-1, max_steps - start_step): + if step == 1: + # skip step -1 and 0 for timing purposes (for warmstarting) + start_time = time.time() + + fetches = {"global_step": global_step} + if step >= 0: + fetches["train_op"] = model.train_op + if should(step, args.progress_freq): + fetches['d_loss'] = model.d_loss + fetches['g_loss'] = model.g_loss + fetches['d_losses'] = model.d_losses + fetches['g_losses'] = model.g_losses + if isinstance(model.learning_rate, tf.Tensor): + fetches["learning_rate"] = model.learning_rate + if should(step, args.summary_freq): + fetches["summary"] = model.summary_op + if should(step, args.image_summary_freq): + fetches["image_summary"] = model.image_summary_op + if should_eval(step, args.eval_summary_freq): + fetches["eval_summary"] = model.eval_summary_op + + run_start_time = time.time() + results = sess.run(fetches) #fetch the elements in dictinoary fetch + + run_elapsed_time = time.time() - run_start_time + if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}: + print('running train_op took too long (%0.1fs)' % run_elapsed_time) + + if (should(step, args.summary_freq) or + should(step, args.image_summary_freq) or + should_eval(step, args.eval_summary_freq)): + val_fetches = {"global_step": global_step} + if should(step, args.summary_freq): + val_fetches["summary"] = model.summary_op + if should(step, args.image_summary_freq): + val_fetches["image_summary"] = model.image_summary_op + if should_eval(step, args.eval_summary_freq): + val_fetches["eval_summary"] = model.eval_summary_op + val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval}) + for name, summary in val_results.items(): + if name == 'global_step': + continue + val_results[name] = add_tag_suffix(summary, '_1') + + if should(step, args.summary_freq): + print("recording summary") + summary_writer.add_summary(results["summary"], results["global_step"]) + summary_writer.add_summary(val_results["summary"], val_results["global_step"]) + print("done") + if should(step, args.image_summary_freq): + print("recording image summary") + summary_writer.add_summary(results["image_summary"], results["global_step"]) + summary_writer.add_summary(val_results["image_summary"], val_results["global_step"]) + print("done") + if should_eval(step, args.eval_summary_freq): + print("recording 
eval summary") + summary_writer.add_summary(results["eval_summary"], results["global_step"]) + summary_writer.add_summary(val_results["eval_summary"], val_results["global_step"]) + print("done") + if should_eval(step, args.accum_eval_summary_freq): + val_datasets = [val_dataset] + val_models = [model] + if long_model is not None: + val_datasets.append(long_val_dataset) + val_models.append(long_model) + for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)): + sess.run(val_model.accum_eval_metrics_reset_op) + # traverse (roughly up to rounding based on the batch size) all the validation dataset + accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size + val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op} + for update_step in range(accum_eval_summary_num_updates): + print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates)) + val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval}) + accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1)) + print("recording accum eval summary") + summary_writer.add_summary(accum_eval_summary, val_results["global_step"]) + print("done") + if (should(step, args.summary_freq) or should(step, args.image_summary_freq) or + should_eval(step, args.eval_summary_freq) or should_eval(step, args.accum_eval_summary_freq)): + summary_writer.flush() + if should(step, args.progress_freq): + # global_step will have the correct step count if we resume from a checkpoint + # global step is read before it's incremented + steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size + train_epoch = results["global_step"] / steps_per_epoch + print("progress global step %d epoch %0.1f" % (results["global_step"] + 1, train_epoch)) + if step > 0: + elapsed_time = time.time() - start_time + average_time = elapsed_time / step + images_per_sec = batch_size / average_time + remaining_time = (max_steps - (start_step + step + 1)) * average_time + print(" image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" % + (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24)) + + if results['d_losses']: + print("d_loss", results["d_loss"]) + for name, loss in results['d_losses'].items(): + print(" ", name, loss) + if results['g_losses']: + print("g_loss", results["g_loss"]) + for name, loss in results['g_losses'].items(): + print(" ", name, loss) + if isinstance(model.learning_rate, tf.Tensor): + print("learning_rate", results["learning_rate"]) + + if should(step, args.save_freq): + print("saving model to", args.output_dir) + saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step) + print("done") + + +if __name__ == '__main__': + main() diff --git a/scripts/train_all.sh b/scripts/train_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..c695a8b2453956dcac54c5440c0058e5598fa03d --- /dev/null +++ b/scripts/train_all.sh @@ -0,0 +1,40 @@ +# BAIR action-free robot pushing dataset +for model in \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + ours_vae_l1 \ + ours_vae_l2 \ + ours_gan \ + ours_savp \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset bair --model savp --model_hparams_dict hparams/bair_action_free/${model}/model_hparams.json --output_dir logs/bair_action_free/${model} +done + +# KTH human actions dataset +for model in \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ 
+ ours_vae_l1 \ + ours_gan \ + ours_savp \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/kth --dataset kth --model savp --model_hparams_dict hparams/kth/${model}/model_hparams.json --output_dir logs/kth/${model} +done + +# BAIR action-conditioned robot pushing dataset +for model in \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + ours_vae_l1 \ + ours_vae_l2 \ + ours_gan \ + ours_savp \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset bair --dataset_hparams use_state=True --model savp --model_hparams_dict hparams/bair/${model}/model_hparams.json --output_dir logs/bair/${model} +done +for model in \ + sna_l1 \ + sna_l2 \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset bair --dataset_hparams use_state=True --model sna --model_hparams_dict hparams/bair/${model}/model_hparams.json --output_dir logs/bair/${model} +done diff --git a/scripts/train_dummy.py b/scripts/train_dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..2f892f69c901f1eaa0a7ce2e57a3d0f6f131a7f9 --- /dev/null +++ b/scripts/train_dummy.py @@ -0,0 +1,274 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import errno +import json +import os +import random +import time +import numpy as np +import tensorflow as tf +from video_prediction import datasets, models + + +def add_tag_suffix(summary, tag_suffix): + summary_proto = tf.Summary() + summary_proto.ParseFromString(summary) + summary = summary_proto + + for value in summary.value: + tag_split = value.tag.split('/') + value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:]) + return summary.SerializeToString() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " + "train, val, test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir") + parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified") + parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. " + "default is logs_dir/model_fname, where model_fname consists of " + "information from model and model_hparams") + parser.add_argument("--output_dir_postfix", default="") + parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. 
checkpoint_dir/model-200000)") + parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.') + + parser.add_argument("--dataset", type=str, help="dataset class name") + parser.add_argument("--model", type=str, help="model class name") + parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") + parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters") + + # parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training") + parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") + parser.add_argument("--seed", type=int) + + args = parser.parse_args() + + if args.seed is not None: + tf.set_random_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + if args.output_dir is None: + list_depth = 0 + model_fname = '' + for t in ('model=%s,%s' % (args.model, args.model_hparams)): + if t == '[': + list_depth += 1 + if t == ']': + list_depth -= 1 + if list_depth and t == ',': + t = '..' + if t in '=,': + t = '.' + if t in '[]': + t = '' + model_fname += t + args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix + + if args.resume: + if args.checkpoint: + raise ValueError('resume and checkpoint cannot both be specified') + args.checkpoint = args.output_dir + + + model_hparams_dict = {} + if args.model_hparams_dict: + with open(args.model_hparams_dict) as f: + model_hparams_dict.update(json.loads(f.read())) + if args.checkpoint: + checkpoint_dir = os.path.normpath(args.checkpoint) + if not os.path.isdir(args.checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % args.checkpoint) + options = json.loads(f.read()) + args.dataset = args.dataset or options['dataset'] + args.model = args.model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict.update(json.loads(f.read())) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + VideoDataset = datasets.get_dataset_class(args.dataset) + train_dataset = VideoDataset( + args.input_dir, + mode='train') + val_dataset = VideoDataset( + args.val_input_dir or args.input_dir, + mode='val') + + variable_scope = tf.get_variable_scope() + variable_scope.set_use_resource(True) + + VideoPredictionModel = models.get_model_class(args.model) + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': train_dataset.hparams.context_frames, + 'sequence_length': train_dataset.hparams.sequence_length, + 'repeat': train_dataset.hparams.time_shift, + }) + model = VideoPredictionModel( + hparams_dict=hparams_dict, + hparams=args.model_hparams) + + batch_size = model.hparams.batch_size + train_tf_dataset = train_dataset.make_dataset_v2(batch_size) + train_iterator = train_tf_dataset.make_one_shot_iterator() + # The `Iterator.string_handle()` method returns a 
tensor that can be evaluated + # and used to feed the `handle` placeholder. + train_handle = train_iterator.string_handle() + val_tf_dataset = val_dataset.make_dataset_v2(batch_size) + val_iterator = val_tf_dataset.make_one_shot_iterator() + val_handle = val_iterator.string_handle() + #iterator = tf.data.Iterator.from_string_handle( + # train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes) + inputs = train_iterator.get_next() + val = val_iterator.get_next() + + model.build_graph(inputs) + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + with open(os.path.join(args.output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys=True, indent=4)) + with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4)) + with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) + + with tf.name_scope("parameter_count"): + # exclude trainable variables that are replicas (used in multi-gpu setting) + trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables) + parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables]) + + saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2) + + # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero + summary_writer = tf.summary.FileWriter(args.output_dir) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) + + + max_steps = model.hparams.max_steps + print ("max_steps",max_steps) + with tf.Session(config=config) as sess: + print("parameter_count =", sess.run(parameter_count)) + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + #coord = tf.train.Coordinator() + #threads = tf.train.start_queue_runners(sess = sess, coord = coord) + print("Init done: {sess.run(tf.local_variables_initializer())}%") + model.restore(sess, args.checkpoint) + + #sess.run(model.post_init_ops) + + #val_handle_eval = sess.run(val_handle) + #print ("val_handle_val",val_handle_eval) + #print("val handle done") + sess.graph.finalize() + start_step = sess.run(model.global_step) + + + # start at one step earlier to log everything without doing any training + # step is relative to the start_step + for step in range(-1, max_steps - start_step): + global_step = sess.run(model.global_step) + print ("global_step:", global_step) + val_handle_eval = sess.run(val_handle) + + if step == 1: + # skip step -1 and 0 for timing purposes (for warmstarting) + start_time = time.time() + + fetches = {"global_step":model.global_step} + fetches["train_op"] = model.train_op + + # fetches["latent_loss"] = model.latent_loss + fetches["total_loss"] = model.total_loss + if model.__class__.__name__ == "McNetVideoPredictionModel": + fetches["L_p"] = model.L_p + fetches["L_gdl"] = model.L_gdl + fetches["L_GAN"] =model.L_GAN + + + + fetches["summary"] = model.summary_op + + run_start_time = time.time() + #Run training results + #X = inputs["images"].eval(session=sess) + + results = sess.run(fetches) + + run_elapsed_time = time.time() - run_start_time + if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}: + print('running train_op took too long (%0.1fs)' % 
run_elapsed_time) + + #Run testing results + #val_fetches = {"global_step":global_step} + val_fetches = {} + #val_fetches["latent_loss"] = model.latent_loss + #val_fetches["total_loss"] = model.total_loss + val_fetches["summary"] = model.summary_op + val_results = sess.run(val_fetches,feed_dict={train_handle: val_handle_eval}) + + summary_writer.add_summary(results["summary"]) + summary_writer.add_summary(val_results["summary"]) + + + + + val_datasets = [val_dataset] + val_models = [model] + + # for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)): + # sess.run(val_model.accum_eval_metrics_reset_op) + # # traverse (roughly up to rounding based on the batch size) all the validation dataset + # accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size + # val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op} + # for update_step in range(accum_eval_summary_num_updates): + # print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates)) + # val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval}) + # accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1)) + # print("recording accum eval summary") + # summary_writer.add_summary(accum_eval_summary, val_results["global_step"]) + summary_writer.flush() + + # global_step will have the correct step count if we resume from a checkpoint + # global step is read before it's incremented + steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size + #train_epoch = results["global_step"] / steps_per_epoch + train_epoch = global_step/steps_per_epoch + print("progress global step %d epoch %0.1f" % (global_step + 1, train_epoch)) + if step > 0: + elapsed_time = time.time() - start_time + average_time = elapsed_time / step + images_per_sec = batch_size / average_time + remaining_time = (max_steps - (start_step + step + 1)) * average_time + print("image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" % + (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24)) + + + print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"])) + + print("saving model to", args.output_dir) + saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)##Bing: cheat here a little bit because of the global step issue + print("done") + +if __name__ == '__main__': + main() diff --git a/scripts/train_v2.py b/scripts/train_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..f83192d6af7d3953c666f40cc9e6d3766a78e92e --- /dev/null +++ b/scripts/train_v2.py @@ -0,0 +1,362 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import errno +import json +import os +import random +import time + +import numpy as np +import tensorflow as tf + +from video_prediction import datasets, models + + +def add_tag_suffix(summary, tag_suffix): + summary_proto = tf.Summary() + summary_proto.ParseFromString(summary) + summary = summary_proto + + for value in summary.value: + tag_split = value.tag.split('/') + value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:]) + return summary.SerializeToString() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " + "train, val, 
test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir") + parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified") + parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. " + "default is logs_dir/model_fname, where model_fname consists of " + "information from model and model_hparams") + parser.add_argument("--output_dir_postfix", default="") + parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)") + parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.') + + parser.add_argument("--dataset", type=str, help="dataset class name") + parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters") + parser.add_argument("--dataset_hparams_dict", type=str, help="a json file of dataset hyperparameters") + parser.add_argument("--model", type=str, help="model class name") + parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") + parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters") + + parser.add_argument("--summary_freq", type=int, default=1000, help="save frequency of summaries (except for image and eval summaries) for train/validation set") + parser.add_argument("--image_summary_freq", type=int, default=5000, help="save frequency of image summaries for train/validation set") + parser.add_argument("--eval_summary_freq", type=int, default=25000, help="save frequency of eval summaries for train/validation set") + parser.add_argument("--accum_eval_summary_freq", type=int, default=100000, help="save frequency of accumulated eval summaries for validation set only") + parser.add_argument("--progress_freq", type=int, default=100, help="display progress every progress_freq steps") + parser.add_argument("--save_freq", type=int, default=5000, help="save frequence of model, 0 to disable") + + parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training") + parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") + parser.add_argument("--seed", type=int) + + args = parser.parse_args() + + if args.seed is not None: + tf.set_random_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + if args.output_dir is None: + list_depth = 0 + model_fname = '' + for t in ('model=%s,%s' % (args.model, args.model_hparams)): + if t == '[': + list_depth += 1 + if t == ']': + list_depth -= 1 + if list_depth and t == ',': + t = '..' + if t in '=,': + t = '.' 
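+            # As in scripts/train.py, this loop derives the default output directory name
+            # from the model spec, so e.g. (hypothetical flags) '--model savp
+            # --model_hparams batch_size=16' would land under logs/model.savp.batch_size.16
+            # unless --output_dir is given explicitly.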
+ if t in '[]': + t = '' + model_fname += t + args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix + + if args.resume: + if args.checkpoint: + raise ValueError('resume and checkpoint cannot both be specified') + args.checkpoint = args.output_dir + + dataset_hparams_dict = {} + model_hparams_dict = {} + if args.dataset_hparams_dict: + with open(args.dataset_hparams_dict) as f: + dataset_hparams_dict.update(json.loads(f.read())) + if args.model_hparams_dict: + with open(args.model_hparams_dict) as f: + model_hparams_dict.update(json.loads(f.read())) + if args.checkpoint: + checkpoint_dir = os.path.normpath(args.checkpoint) + if not os.path.isdir(args.checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % args.checkpoint) + options = json.loads(f.read()) + args.dataset = args.dataset or options['dataset'] + args.model = args.model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: + dataset_hparams_dict.update(json.loads(f.read())) + except FileNotFoundError: + print("dataset_hparams.json was not loaded because it does not exist") + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict.update(json.loads(f.read())) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + VideoDataset = datasets.get_dataset_class(args.dataset) + train_dataset = VideoDataset( + args.input_dir, + mode='train', + hparams_dict=dataset_hparams_dict, + hparams=args.dataset_hparams) + val_dataset = VideoDataset( + args.val_input_dir or args.input_dir, + mode='val', + hparams_dict=dataset_hparams_dict, + hparams=args.dataset_hparams) + if val_dataset.hparams.long_sequence_length != val_dataset.hparams.sequence_length: + # the longer dataset is only used for the accum_eval_metrics + long_val_dataset = VideoDataset( + args.val_input_dir or args.input_dir, + mode='val', + hparams_dict=dataset_hparams_dict, + hparams=args.dataset_hparams) + long_val_dataset.set_sequence_length(val_dataset.hparams.long_sequence_length) + else: + long_val_dataset = None + + variable_scope = tf.get_variable_scope() + variable_scope.set_use_resource(True) + + VideoPredictionModel = models.get_model_class(args.model) + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': train_dataset.hparams.context_frames, + 'sequence_length': train_dataset.hparams.sequence_length, + 'repeat': train_dataset.hparams.time_shift, + }) + model = VideoPredictionModel( + hparams_dict=hparams_dict, + hparams=args.model_hparams, + aggregate_nccl=args.aggregate_nccl) + + batch_size = model.hparams.batch_size + train_tf_dataset = train_dataset.make_dataset_v2(batch_size)#Bing: adopt the meteo data prepartion here + train_iterator = train_tf_dataset.make_one_shot_iterator()#Bing:for era5, the problem happen in sess.run(feches) should come from here + # The `Iterator.string_handle()` method returns a tensor that can be evaluated + # and used to feed the `handle` placeholder. 
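+    # Feedable-iterator sketch (standard TF1 pattern; the names below are illustrative
+    # and not defined in this file):
+    #   handle = tf.placeholder(tf.string, shape=[])
+    #   iterator = tf.data.Iterator.from_string_handle(
+    #       handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
+    #   inputs = iterator.get_next()
+    #   # switch data pipelines per sess.run call:
+    #   sess.run(fetches, feed_dict={handle: sess.run(train_iterator.string_handle())})
+    #   sess.run(fetches, feed_dict={handle: sess.run(val_iterator.string_handle())})
+    # In this script the from_string_handle iterator is commented out and `inputs` comes
+    # directly from train_iterator.get_next() below.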
+ train_handle = train_iterator.string_handle() + val_tf_dataset = val_dataset.make_dataset_v2(batch_size) + val_iterator = val_tf_dataset.make_one_shot_iterator() + val_handle = val_iterator.string_handle() + #iterator = tf.data.Iterator.from_string_handle( + # train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes) + inputs = train_iterator.get_next() + + # inputs comes from the training dataset by default, unless train_handle is remapped to the val_handles + model.build_graph(inputs, finetune=True) + + if long_val_dataset is not None: + # separately build a model for the longer sequence. + # this is needed because the model doesn't support dynamic shapes. + long_hparams_dict = dict(hparams_dict) + long_hparams_dict['sequence_length'] = long_val_dataset.hparams.sequence_length + # use smaller batch size for longer model to prevenet running out of memory + long_hparams_dict['batch_size'] = model.hparams.batch_size // 2 + long_model = VideoPredictionModel( + mode="test", # to not build the losses and discriminators + hparams_dict=long_hparams_dict, + hparams=args.model_hparams, + aggregate_nccl=args.aggregate_nccl) + tf.get_variable_scope().reuse_variables() + long_model.build_graph(long_val_dataset.make_batch(batch_size)) + else: + long_model = None + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + with open(os.path.join(args.output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys=True, indent=4)) + with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4)) + with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) + + with tf.name_scope("parameter_count"): + # exclude trainable variables that are replicas (used in multi-gpu setting) + trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables) + parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables]) + + saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2) + + # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero + if (args.summary_freq != 0 or args.image_summary_freq != 0 or + args.eval_summary_freq != 0 or args.accum_eval_summary_freq != 0): + summary_writer = tf.summary.FileWriter(args.output_dir) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) + global_step = tf.train.get_or_create_global_step() + max_steps = model.hparams.max_steps + with tf.Session(config=config) as sess: + print("parameter_count =", sess.run(parameter_count)) + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + #coord = tf.train.Coordinator() + #threads = tf.train.start_queue_runners(sess = sess, coord = coord) + print("Init done: {sess.run(tf.local_variables_initializer())}%") + model.restore(sess, args.checkpoint) + print("Restore processed finished") + sess.run(model.post_init_ops) + print("Model run started") + val_handle_eval = sess.run(val_handle) + print("val handle done") + sess.graph.finalize() + print("graph inalize done") + start_step = sess.run(global_step) + print("global step done") + + def should(step, freq): + if freq is None: + return (step + 1) == (max_steps - start_step) + else: + return freq 
and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step)) + + def should_eval(step, freq): + # never run eval summaries at the beginning since it's expensive, unless it's the last iteration + return should(step, freq) and (step >= 0 or (step + 1) == (max_steps - start_step)) + + # start at one step earlier to log everything without doing any training + # step is relative to the start_step + for step in range(-1, max_steps - start_step): + if step == 1: + # skip step -1 and 0 for timing purposes (for warmstarting) + start_time = time.time() + + fetches = {"global_step": global_step} + if step >= 0: + fetches["train_op"] = model.train_op + if should(step, args.progress_freq): + fetches['d_loss'] = model.d_loss + fetches['g_loss'] = model.g_loss + fetches['d_losses'] = model.d_losses + fetches['g_losses'] = model.g_losses + if isinstance(model.learning_rate, tf.Tensor): + fetches["learning_rate"] = model.learning_rate + if should(step, args.summary_freq): + fetches["summary"] = model.summary_op + if should(step, args.image_summary_freq): + fetches["image_summary"] = model.image_summary_op + if should_eval(step, args.eval_summary_freq): + fetches["eval_summary"] = model.eval_summary_op + + run_start_time = time.time() + results = sess.run(fetches) #fetch the elements in dictinoary fetch + + run_elapsed_time = time.time() - run_start_time + if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}: + print('running train_op took too long (%0.1fs)' % run_elapsed_time) + + if (should(step, args.summary_freq) or + should(step, args.image_summary_freq) or + should_eval(step, args.eval_summary_freq)): + val_fetches = {"global_step": global_step} + if should(step, args.summary_freq): + val_fetches["summary"] = model.summary_op + if should(step, args.image_summary_freq): + val_fetches["image_summary"] = model.image_summary_op + if should_eval(step, args.eval_summary_freq): + val_fetches["eval_summary"] = model.eval_summary_op + val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval}) + for name, summary in val_results.items(): + if name == 'global_step': + continue + val_results[name] = add_tag_suffix(summary, '_1') + + if should(step, args.summary_freq): + print("recording summary") + summary_writer.add_summary(results["summary"], results["global_step"]) + summary_writer.add_summary(val_results["summary"], val_results["global_step"]) + print("done") + if should(step, args.image_summary_freq): + print("recording image summary") + summary_writer.add_summary(results["image_summary"], results["global_step"]) + summary_writer.add_summary(val_results["image_summary"], val_results["global_step"]) + print("done") + if should_eval(step, args.eval_summary_freq): + print("recording eval summary") + summary_writer.add_summary(results["eval_summary"], results["global_step"]) + summary_writer.add_summary(val_results["eval_summary"], val_results["global_step"]) + print("done") + if should_eval(step, args.accum_eval_summary_freq): + val_datasets = [val_dataset] + val_models = [model] + if long_model is not None: + val_datasets.append(long_val_dataset) + val_models.append(long_model) + for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)): + sess.run(val_model.accum_eval_metrics_reset_op) + # traverse (roughly up to rounding based on the batch size) all the validation dataset + accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size + val_fetches = {"global_step": 
global_step, "accum_eval_summary": val_model.accum_eval_summary_op} + for update_step in range(accum_eval_summary_num_updates): + print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates)) + val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval}) + accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1)) + print("recording accum eval summary") + summary_writer.add_summary(accum_eval_summary, val_results["global_step"]) + print("done") + if (should(step, args.summary_freq) or should(step, args.image_summary_freq) or + should_eval(step, args.eval_summary_freq) or should_eval(step, args.accum_eval_summary_freq)): + summary_writer.flush() + if should(step, args.progress_freq): + # global_step will have the correct step count if we resume from a checkpoint + # global step is read before it's incremented + steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size + train_epoch = results["global_step"] / steps_per_epoch + print("progress global step %d epoch %0.1f" % (results["global_step"] + 1, train_epoch)) + if step > 0: + elapsed_time = time.time() - start_time + average_time = elapsed_time / step + images_per_sec = batch_size / average_time + remaining_time = (max_steps - (start_step + step + 1)) * average_time + print(" image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" % + (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24)) + + if results['d_losses']: + print("d_loss", results["d_loss"]) + for name, loss in results['d_losses'].items(): + print(" ", name, loss) + if results['g_losses']: + print("g_loss", results["g_loss"]) + for name, loss in results['g_losses'].items(): + print(" ", name, loss) + if isinstance(model.learning_rate, tf.Tensor): + print("learning_rate", results["learning_rate"]) + + if should(step, args.save_freq): + print("saving model to", args.output_dir) + saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step) + print("done") + +if __name__ == '__main__': + main() diff --git a/video_prediction/.DS_Store b/video_prediction/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..c43ee62a22d53cc75b181ee2096e26474c200495 Binary files /dev/null and b/video_prediction/.DS_Store differ diff --git a/video_prediction/__init__.py b/video_prediction/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3089b251bbbcbe086a03a4ea63571ce9427b2cb9 --- /dev/null +++ b/video_prediction/__init__.py @@ -0,0 +1,3 @@ +from . import losses +from . import metrics +from . 
import ops diff --git a/video_prediction/datasets/__init__.py b/video_prediction/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..736b8202172051f36586db5579c545863c72e14d --- /dev/null +++ b/video_prediction/datasets/__init__.py @@ -0,0 +1,29 @@ +from .base_dataset import BaseVideoDataset +from .base_dataset import VideoDataset, SequenceExampleVideoDataset, VarLenFeatureVideoDataset +from .google_robot_dataset import GoogleRobotVideoDataset +from .sv2p_dataset import SV2PVideoDataset +from .softmotion_dataset import SoftmotionVideoDataset +from .kth_dataset import KTHVideoDataset +from .ucf101_dataset import UCF101VideoDataset +from .cartgripper_dataset import CartgripperVideoDataset +from .era5_dataset_v2 import ERA5Dataset_v2 +from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly + +def get_dataset_class(dataset): + dataset_mappings = { + 'google_robot': 'GoogleRobotVideoDataset', + 'sv2p': 'SV2PVideoDataset', + 'softmotion': 'SoftmotionVideoDataset', + 'bair': 'SoftmotionVideoDataset', # alias of softmotion + 'kth': 'KTHVideoDataset', + 'ucf101': 'UCF101VideoDataset', + 'cartgripper': 'CartgripperVideoDataset', + "era5":"ERA5Dataset_v2", + "era5_anomaly":"ERA5Dataset_v2_anomaly", + } + dataset_class = dataset_mappings.get(dataset, dataset) + print("datset_class",dataset_class) + dataset_class = globals().get(dataset_class) + if dataset_class is None or not issubclass(dataset_class, BaseVideoDataset): + raise ValueError('Invalid dataset %s' % dataset) + return dataset_class diff --git a/video_prediction/datasets/base_dataset.py b/video_prediction/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb6b00550e178ce845f585f9f07a255220c26da --- /dev/null +++ b/video_prediction/datasets/base_dataset.py @@ -0,0 +1,521 @@ +import glob +import os +import random +import re +from collections import OrderedDict + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.training import HParams + + +class BaseVideoDataset(object): + def __init__(self, input_dir, mode='train', num_epochs=None, seed=None, + hparams_dict=None, hparams=None): + """ + Args: + input_dir: either a directory containing subdirectories train, + val, test, etc, or a directory containing the tfrecords. + mode: either train, val, or test + num_epochs: if None, dataset is iterated indefinitely. + seed: random seed for the op that samples subsequences. + hparams_dict: a dict of `name=value` pairs, where `name` must be + defined in `self.get_default_hparams()`. + hparams: a string of comma separated list of `name=value` pairs, + where `name` must be defined in `self.get_default_hparams()`. + These values overrides any values in hparams_dict (if any). + Note: + self.input_dir is the directory containing the tfrecords. 
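+        Example:
+            (hypothetical paths and hyperparameter values, for illustration only)
+            dataset = KTHVideoDataset('data/kth', mode='train',
+                                      hparams='sequence_length=12,use_state=False')
+            images = dataset.make_batch(batch_size=16)['images']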
+ """ + + self.input_dir = os.path.normpath(os.path.expanduser(input_dir)) + self.mode = mode + self.num_epochs = num_epochs + self.seed = seed + + if self.mode not in ('train', 'val', 'test'): + raise ValueError('Invalid mode %s' % self.mode) + + if not os.path.exists(self.input_dir): + raise FileNotFoundError("input_dir %s does not exist" % self.input_dir) + self.filenames = None + # look for tfrecords in input_dir and input_dir/mode directories + for input_dir in [self.input_dir, os.path.join(self.input_dir, self.mode)]: + filenames = glob.glob(os.path.join(input_dir, '*.tfrecord*')) + if filenames: + self.input_dir = input_dir + self.filenames = sorted(filenames) # ensures order is the same across systems + break + if not self.filenames: + raise FileNotFoundError('No tfrecords were found in %s.' % self.input_dir) + self.dataset_name = os.path.basename(os.path.split(self.input_dir)[0]) + + self.state_like_names_and_shapes = OrderedDict() + self.action_like_names_and_shapes = OrderedDict() + + self.hparams = self.parse_hparams(hparams_dict, hparams) + #Bing: add this for anomaly + if os.path.exists(input_dir+"_mean"): + input_mean_dir = input_dir+"_mean" + self.filenames_mean = sorted(glob.glob(os.path.join(input_mean_dir, '*.tfrecord*'))) + else: + self.filenames_mean = None + + + def get_default_hparams_dict(self): + """ + Returns: + A dict with the following hyperparameters. + + crop_size: crop image into a square with sides of this length. + scale_size: resize image to this size after it has been cropped. + context_frames: the number of ground-truth frames to pass in at + start. + sequence_length: the number of frames in the video sequence, so + state-like sequences are of length sequence_length and + action-like sequences are of length sequence_length - 1. + This number includes the context frames. + long_sequence_length: the number of frames for the long version. + The default is the same as sequence_length. + frame_skip: number of frames to skip in between outputted frames, + so frame_skip=0 denotes no skipping. + time_shift: shift in time by multiples of this, so time_shift=1 + denotes all possible shifts. time_shift=0 denotes no shifting. + It is ignored (equiv. to time_shift=0) when mode != 'train'. + force_time_shift: whether to do the shift in time regardless of + mode. + shuffle_on_val: whether to shuffle the samples regardless if mode + is 'train' or 'val'. Shuffle never happens when mode is 'test'. + use_state: whether to load and return state and actions. 
+ """ + hparams = dict( + crop_size=0, + scale_size=0, + context_frames=1, + sequence_length=0, + long_sequence_length=0, + frame_skip=0, + time_shift=1, + force_time_shift=False, + shuffle_on_val=False, + use_state=False, + ) + return hparams + + def get_default_hparams(self): + return HParams(**self.get_default_hparams_dict()) + + def parse_hparams(self, hparams_dict, hparams): + parsed_hparams = self.get_default_hparams().override_from_dict(hparams_dict or {}) + if hparams: + if not isinstance(hparams, (list, tuple)): + hparams = [hparams] + for hparam in hparams: + parsed_hparams.parse(hparam) + if parsed_hparams.long_sequence_length == 0: + parsed_hparams.long_sequence_length = parsed_hparams.sequence_length + return parsed_hparams + + @property + def jpeg_encoding(self): + raise NotImplementedError + + def set_sequence_length(self, sequence_length): + self.hparams.sequence_length = sequence_length + + def filter(self, serialized_example): + return tf.convert_to_tensor(True) + + def parser(self, serialized_example): + """ + Parses a single tf.train.Example or tf.train.SequenceExample into + images, states, actions, etc tensors. + """ + + + raise NotImplementedError + + def make_dataset(self, batch_size): + filenames = self.filenames + shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) + if shuffle: + random.shuffle(filenames) + + dataset = tf.data.TFRecordDataset(filenames, buffer_size= 8 * 1024 * 1024) #todo: what is buffer_size + dataset = dataset.filter(self.filter) + if shuffle: + dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1024, count=self.num_epochs)) + else: + dataset = dataset.repeat(self.num_epochs) + + def _parser(serialized_example): + state_like_seqs, action_like_seqs = self.parser(serialized_example) + seqs = OrderedDict(list(state_like_seqs.items()) + list(action_like_seqs.items())) + return seqs + + num_parallel_calls = None if shuffle else 1 # for reproducibility (e.g. 
sampled subclips from the test set) + dataset = dataset.apply(tf.contrib.data.map_and_batch( + _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) # Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs + dataset = dataset.prefetch(batch_size) #Bing: Take the data to buffer inorder to save the waiting time for GPU + return dataset + + def make_batch(self, batch_size): + dataset = self.make_dataset(batch_size) + iterator = dataset.make_one_shot_iterator() + return iterator.get_next() + + + def decode_and_preprocess_images(self, image_buffers, image_shape): + def decode_and_preprocess_image(image_buffer): + print("image buffer", tf.shape(image_buffer)) + + image_buffer = tf.reshape(image_buffer,[],name="reshape_1") + + if self.jpeg_encoding: + image = tf.image.decode_jpeg(image_buffer) + print("14********image decode_jpeg********", image) + else: + image = tf.decode_raw(image_buffer, tf.uint8) + print("15 ********image decode_raw********", tf.shape(image)) + print("16 ******** image shape", image_shape) + + image = tf.reshape(image, image_shape, name="reshape_4") ##Bing:the bug #issue 1 is here + crop_size = self.hparams.crop_size + scale_size = self.hparams.scale_size + if crop_size or scale_size: + if not crop_size: + crop_size = min(image_shape[0], image_shape[1]) + image = tf.image.resize_image_with_crop_or_pad(image, crop_size, crop_size) + image = tf.reshape(image, [crop_size, crop_size, 3],"reshape_3") + if scale_size: + # upsample with bilinear interpolation but downsample with area interpolation + if crop_size < scale_size: + image = tf.image.resize_images(image, [scale_size, scale_size], + method=tf.image.ResizeMethod.BILINEAR) + elif crop_size > scale_size: + image = tf.image.resize_images(image, [scale_size, scale_size], + method=tf.image.ResizeMethod.AREA) + else: + # image remains unchanged + pass + return image + + if not isinstance(image_buffers, (list, tuple)): + image_buffers = tf.unstack(image_buffers) + print("17 **************image buffer", image_buffers[0]) + images = [decode_and_preprocess_image(image_buffer) for image_buffer in image_buffers] + images = tf.image.convert_image_dtype(images, dtype=tf.float32) + return images + + def slice_sequences(self, state_like_seqs, action_like_seqs, example_sequence_length): + """ + Slices sequences of length `example_sequence_length` into subsequences + of length `sequence_length`. The dicts of sequences are updated + in-place and the same dicts are returned. + """ + # handle random shifting and frame skip + sequence_length = self.hparams.sequence_length # desired sequence length + frame_skip = self.hparams.frame_skip + time_shift = self.hparams.time_shift + print("22***********example sequence_length",example_sequence_length) + if (time_shift and self.mode == 'train') or self.hparams.force_time_shift: + print("23***********I am here") + assert time_shift > 0 and isinstance(time_shift, int) + if isinstance(example_sequence_length, tf.Tensor): + example_sequence_length = tf.cast(example_sequence_length, tf.int32) + num_shifts = ((example_sequence_length - 1) - (sequence_length - 1) * (frame_skip + 1)) // time_shift + assert_message = ('example_sequence_length has to be at least %d when ' + 'sequence_length=%d, frame_skip=%d.' 
% + ((sequence_length - 1) * (frame_skip + 1) + 1, + sequence_length, frame_skip)) + with tf.control_dependencies([tf.assert_greater_equal(num_shifts, 0, + data=[example_sequence_length, num_shifts], message=assert_message)]): + t_start = tf.random_uniform([], 0, num_shifts + 1, dtype=tf.int32, seed=self.seed) * time_shift + else: + t_start = 0 + print("20:**********************sequence_len: {}, t_start:{}, frame_skip:{}".format(sequence_length,tf.shape(t_start),frame_skip)) + state_like_t_slice = slice(t_start, t_start + (sequence_length - 1) * (frame_skip + 1) + 1, frame_skip + 1) + action_like_t_slice = slice(t_start, t_start + (sequence_length - 1) * (frame_skip + 1)) + + for example_name, seq in state_like_seqs.items(): + print("21*****************seq*******",seq) + seq = tf.convert_to_tensor(seq)[state_like_t_slice] + print("25**************ses.shape", [self.hparams.sequence_length] + seq.shape.as_list()[1:]) + seq.set_shape([sequence_length] + seq.shape.as_list()[1:]) + state_like_seqs[example_name] = seq + for example_name, seq in action_like_seqs.items(): + seq = tf.convert_to_tensor(seq)[action_like_t_slice] + seq.set_shape([(sequence_length - 1) * (frame_skip + 1)] + seq.shape.as_list()[1:]) + # concatenate actions of skipped frames into single macro actions + seq = tf.reshape(seq, [sequence_length - 1, -1]) + action_like_seqs[example_name] = seq + return state_like_seqs, action_like_seqs + + def num_examples_per_epoch(self): + raise NotImplementedError + + + +class VideoDataset(BaseVideoDataset): + """ + This class supports reading tfrecords where a sequence is stored as + multiple tf.train.Example and each of them is stored under a different + feature name (which is indexed by the time step). + """ + def __init__(self, *args, **kwargs): + super(VideoDataset, self).__init__(*args, **kwargs) + self._max_sequence_length = None + self._dict_message = None + + def _check_or_infer_shapes(self): + """ + Should be called after state_like_names_and_shapes and + action_like_names_and_shapes have been finalized. + """ + state_like_names_and_shapes = OrderedDict([(k, list(v)) for k, v in self.state_like_names_and_shapes.items()]) + action_like_names_and_shapes = OrderedDict([(k, list(v)) for k, v in self.action_like_names_and_shapes.items()]) + from google.protobuf.json_format import MessageToDict + example = next(tf.python_io.tf_record_iterator(self.filenames[0])) + self._dict_message = MessageToDict(tf.train.Example.FromString(example)) + for example_name, name_and_shape in (list(state_like_names_and_shapes.items()) + + list(action_like_names_and_shapes.items())): + name, shape = name_and_shape + feature = self._dict_message['features']['feature'] + names = [name_ for name_ in feature.keys() if re.search(name.replace('%d', '\d+'), name_) is not None] + if not names: + raise ValueError('Could not found any feature with name pattern %s.' % name) + if example_name in self.state_like_names_and_shapes: + sequence_length = len(names) + else: + sequence_length = len(names) + 1 + if self._max_sequence_length is None: + self._max_sequence_length = sequence_length + else: + self._max_sequence_length = min(sequence_length, self._max_sequence_length) + name = names[0] + feature = feature[name] + list_type, = feature.keys() + if list_type == 'floatList': + inferred_shape = (len(feature[list_type]['value']),) + if shape is None: + name_and_shape[1] = inferred_shape + else: + if inferred_shape != shape: + raise ValueError('Inferred shape for feature %s is %r but instead got shape %r.' 
% + (name, inferred_shape, shape)) + elif list_type == 'bytesList': + image_str, = feature[list_type]['value'] + # try to infer image shape + inferred_shape = None + if not self.jpeg_encoding: + spatial_size = len(image_str) // 4 + height = width = int(np.sqrt(spatial_size)) # assume square image + if len(image_str) == (height * width * 4): + inferred_shape = (height, width, 3) + if shape is None: + if inferred_shape is not None: + name_and_shape[1] = inferred_shape + else: + raise ValueError('Unable to infer shape for feature %s of size %d.' % (name, len(image_str))) + else: + if inferred_shape is not None and inferred_shape != shape: + raise ValueError('Inferred shape for feature %s is %r but instead got shape %r.' % + (name, inferred_shape, shape)) + else: + raise NotImplementedError + self.state_like_names_and_shapes = OrderedDict([(k, tuple(v)) for k, v in state_like_names_and_shapes.items()]) + self.action_like_names_and_shapes = OrderedDict([(k, tuple(v)) for k, v in action_like_names_and_shapes.items()]) + + # set sequence_length to the longest possible if it is not specified + if not self.hparams.sequence_length: + self.hparams.sequence_length = (self._max_sequence_length - 1) // (self.hparams.frame_skip + 1) + 1 + + def set_sequence_length(self, sequence_length): + if not sequence_length: + sequence_length = (self._max_sequence_length - 1) // (self.hparams.frame_skip + 1) + 1 + self.hparams.sequence_length = sequence_length + + def parser(self, serialized_example): + """ + Parses a single tf.train.Example into images, states, actions, etc tensors. + """ + features = dict() + for i in range(self._max_sequence_length): + for example_name, (name, shape) in self.state_like_names_and_shapes.items(): + if example_name == 'images': # special handling for image + features[name % i] = tf.FixedLenFeature([1], tf.string) + else: + features[name % i] = tf.FixedLenFeature(shape, tf.float32) + for i in range(self._max_sequence_length - 1): + for example_name, (name, shape) in self.action_like_names_and_shapes.items(): + features[name % i] = tf.FixedLenFeature(shape, tf.float32) + + # check that the features are in the tfrecord + for name in features.keys(): + if name not in self._dict_message['features']['feature']: + raise ValueError('Feature with name %s not found in tfrecord. 
Possible feature names are:\n%s' % + (name, '\n'.join(sorted(self._dict_message['features']['feature'].keys())))) + + # parse all the features of all time steps together + features = tf.parse_single_example(serialized_example, features=features) + + + state_like_seqs = OrderedDict([(example_name, []) for example_name in self.state_like_names_and_shapes]) + action_like_seqs = OrderedDict([(example_name, []) for example_name in self.action_like_names_and_shapes]) + for i in range(self._max_sequence_length): + for example_name, (name, shape) in self.state_like_names_and_shapes.items(): + state_like_seqs[example_name].append(features[name % i]) + for i in range(self._max_sequence_length - 1): + for example_name, (name, shape) in self.action_like_names_and_shapes.items(): + action_like_seqs[example_name].append(features[name % i]) + + # for this class, it's much faster to decode and preprocess the entire sequence before sampling a slice + _, image_shape = self.state_like_names_and_shapes['images'] + state_like_seqs['images'] = self.decode_and_preprocess_images(state_like_seqs['images'], image_shape) + + state_like_seqs, action_like_seqs = \ + self.slice_sequences(state_like_seqs, action_like_seqs, self._max_sequence_length) + return state_like_seqs, action_like_seqs + + +class SequenceExampleVideoDataset(BaseVideoDataset): + """ + This class supports reading tfrecords where an entire sequence is stored as + a single tf.train.SequenceExample. + """ + def parser(self, serialized_example): + """ + Parses a single tf.train.SequenceExample into images, states, actions, etc tensors. + """ + sequence_features = dict() + for example_name, (name, shape) in self.state_like_names_and_shapes.items(): + if example_name == 'images': # special handling for image + sequence_features[name] = tf.FixedLenSequenceFeature([1], tf.string) + else: + sequence_features[name] = tf.FixedLenSequenceFeature(shape, tf.float32) + for example_name, (name, shape) in self.action_like_names_and_shapes.items(): + sequence_features[name] = tf.FixedLenSequenceFeature(shape, tf.float32) + + _, sequence_features = tf.parse_single_sequence_example( + serialized_example, sequence_features=sequence_features) + + state_like_seqs = OrderedDict() + action_like_seqs = OrderedDict() + for example_name, (name, shape) in self.state_like_names_and_shapes.items(): + state_like_seqs[example_name] = sequence_features[name] + for example_name, (name, shape) in self.action_like_names_and_shapes.items(): + action_like_seqs[example_name] = sequence_features[name] + + # the sequence_length of this example is determined by the shortest sequence + example_sequence_length = [] + for example_name, seq in state_like_seqs.items(): + example_sequence_length.append(tf.shape(seq)[0]) + for example_name, seq in action_like_seqs.items(): + example_sequence_length.append(tf.shape(seq)[0] + 1) + example_sequence_length = tf.reduce_min(example_sequence_length) + #bing + state_like_seqs, action_like_seqs = \ + self.slice_sequences(state_like_seqs, action_like_seqs, example_sequence_length) + + # decode and preprocess images on the sampled slice only + _, image_shape = self.state_like_names_and_shapes['images'] + state_like_seqs['images'] = self.decode_and_preprocess_images(state_like_seqs['images'], image_shape) + return state_like_seqs, action_like_seqs + + +class VarLenFeatureVideoDataset(BaseVideoDataset): + """ + This class supports reading tfrecords where an entire sequence is stored as + a single tf.train.Example. 
+ + https://github.com/tensorflow/tensorflow/issues/15977 + """ + def filter(self, serialized_example): + features = dict() + features['sequence_length'] = tf.FixedLenFeature((), tf.int64) + features = tf.parse_single_example(serialized_example, features=features) + example_sequence_length = features['sequence_length'] + return tf.greater_equal(example_sequence_length, self.hparams.sequence_length) + + def parser(self, serialized_example): + """ + Parses a single tf.train.SequenceExample into images, states, actions, etc tensors. + """ + print("1.***parser function from class VarLenFeatureVideoDatase") + features = dict() + features['sequence_length'] = tf.FixedLenFeature((), tf.int64) + for example_name, (name, shape) in self.state_like_names_and_shapes.items(): + if example_name == 'images': + #Bing + #features[name] = tf.FixedLenFeature([1], tf.string) + features[name] = tf.VarLenFeature(tf.string) + else: + features[name] = tf.VarLenFeature(tf.float32) + for example_name, (name, shape) in self.action_like_names_and_shapes.items(): + features[name] = tf.VarLenFeature(tf.float32) + + features = tf.parse_single_example(serialized_example, features=features) + example_sequence_length = features['sequence_length'] + + state_like_seqs = OrderedDict() + action_like_seqs = OrderedDict() + for example_name, (name, shape) in self.state_like_names_and_shapes.items(): + if example_name == 'images': + seq = tf.sparse_tensor_to_dense(features[name], '') + else: + seq = tf.sparse_tensor_to_dense(features[name]) + seq = tf.reshape(seq, [example_sequence_length] + list(shape)) + + state_like_seqs[example_name] = seq + + + for example_name, (name, shape) in self.action_like_names_and_shapes.items(): + + seq = tf.sparse_tensor_to_dense(features[name]) + seq = tf.reshape(seq, [example_sequence_length - 1] + list(shape)) + action_like_seqs[example_name] = seq + + #Bing: I replce the self.slice_sequence to the following three lines , the program works, but I need to figure it out what happend inside this function + state_like_seqs, action_like_seqs = \ + self.slice_sequences(state_like_seqs, action_like_seqs, example_sequence_length) + # seq = tf.convert_to_tensor(seq) + # print("25**************ses.shape",[self.hparams.sequence_length] + seq.shape.as_list()[1:]) + # seq.set_shape([self.hparams.sequence_length] + seq.shape.as_list()[1:]) + # state_like_seqs[example_name] = seq + #print("11**********Slide sequences**************** ", action_like_seqs) + # decode and preprocess images on the sampled slice only + _, image_shape = self.state_like_names_and_shapes['images'] + + state_like_seqs['images'] = self.decode_and_preprocess_images(state_like_seqs['images'], image_shape) + return state_like_seqs, action_like_seqs + + +if __name__ == '__main__': + import cv2 + from video_prediction import datasets + + datasets = [ + datasets.SV2PVideoDataset('data/shape', mode='val'), + datasets.SV2PVideoDataset('data/humans', mode='val'), + datasets.SoftmotionVideoDataset('data/bair', mode='val'), + datasets.KTHVideoDataset('data/kth', mode='val'), + datasets.KTHVideoDataset('data/kth_128', mode='val'), + datasets.UCF101VideoDataset('data/ucf101', mode='val'), + ] + batch_size = 4 + + sess = tf.Session() + + for dataset in datasets: + inputs = dataset.make_batch(batch_size) + images = inputs['images'] + images = tf.reshape(images, [-1] + images.get_shape().as_list()[2:]) + images = sess.run(images) + images = (images * 255).astype(np.uint8) + for image in images: + if image.shape[-1] == 1: + image = np.tile(image, [1, 1, 
3]) + else: + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + cv2.imshow(dataset.input_dir, image) + cv2.waitKey(50) diff --git a/video_prediction/datasets/cartgripper_dataset.py b/video_prediction/datasets/cartgripper_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..68275cb75945c9327290d3759e4e9912d5f5c4a3 --- /dev/null +++ b/video_prediction/datasets/cartgripper_dataset.py @@ -0,0 +1,24 @@ +import itertools + +from .base_dataset import VideoDataset +from .softmotion_dataset import SoftmotionVideoDataset + + +class CartgripperVideoDataset(SoftmotionVideoDataset): + def __init__(self, *args, **kwargs): + VideoDataset.__init__(self, *args, **kwargs) + self.state_like_names_and_shapes['images'] = '%d/image_view0/encoded', (48, 64, 3) + if self.hparams.use_state: + self.state_like_names_and_shapes['states'] = '%d/endeffector_pos', (6,) + self.action_like_names_and_shapes['actions'] = '%d/action', (3,) + self._check_or_infer_shapes() + + def get_default_hparams_dict(self): + default_hparams = super(CartgripperVideoDataset, self).get_default_hparams_dict() + hparams = dict( + context_frames=2, + sequence_length=15, + time_shift=3, + use_state=True, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) diff --git a/video_prediction/datasets/era5_dataset_v2.py b/video_prediction/datasets/era5_dataset_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..9e32e0c638b4b39c588389e906ba29be5144ee35 --- /dev/null +++ b/video_prediction/datasets/era5_dataset_v2.py @@ -0,0 +1,331 @@ +import argparse +import glob +import itertools +import os +import pickle +import random +import re +import hickle as hkl +import numpy as np +import json +import tensorflow as tf +from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset +# ML 2020/04/14: hack for getting functions of process_netCDF_v2: +from os import path +import sys +sys.path.append(path.abspath('../../workflow_parallel_frame_prediction/')) +import DataPreprocess.process_netCDF_v2 +from DataPreprocess.process_netCDF_v2 import get_unique_vars +from DataPreprocess.process_netCDF_v2 import Calc_data_stat +#from base_dataset import VarLenFeatureVideoDataset +from collections import OrderedDict +from tensorflow.contrib.training import HParams + +class ERA5Dataset_v2(VarLenFeatureVideoDataset): + def __init__(self, *args, **kwargs): + super(ERA5Dataset_v2, self).__init__(*args, **kwargs) + from google.protobuf.json_format import MessageToDict + example = next(tf.python_io.tf_record_iterator(self.filenames[0])) + dict_message = MessageToDict(tf.train.Example.FromString(example)) + feature = dict_message['features']['feature'] + self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels']) + self.image_shape = self.video_shape[1:] + self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape + + def get_default_hparams_dict(self): + default_hparams = super(ERA5Dataset_v2, self).get_default_hparams_dict() + hparams = dict( + context_frames=10,#Bing: Todo oriignal is 10 + sequence_length=20,#bing: TODO original is 20, + long_sequence_length=20, + force_time_shift=True, + shuffle_on_val=True, + use_state=False, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + + @property + def jpeg_encoding(self): + return False + + + + def num_examples_per_epoch(self): + with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file: + 
sequence_lengths = sequence_lengths_file.readlines() + sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] + return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) + + + def filter(self, serialized_example): + return tf.convert_to_tensor(True) + + + def make_dataset_v2(self, batch_size): + def parser(serialized_example): + seqs = OrderedDict() + keys_to_features = { + 'width': tf.FixedLenFeature([], tf.int64), + 'height': tf.FixedLenFeature([], tf.int64), + 'sequence_length': tf.FixedLenFeature([], tf.int64), + 'channels': tf.FixedLenFeature([],tf.int64), + # 'images/encoded': tf.FixedLenFeature([], tf.string) + 'images/encoded': tf.VarLenFeature(tf.float32) + } + + # for i in range(20): + # keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string) + parsed_features = tf.parse_single_example(serialized_example, keys_to_features) + print ("Parse features", parsed_features) + seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) + # width = tf.sparse_tensor_to_dense(parsed_features["width"]) + # height = tf.sparse_tensor_to_dense(parsed_features["height"]) + # channels = tf.sparse_tensor_to_dense(parsed_features["channels"]) + # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"]) + images = [] + # for i in range(20): + # images.append(parsed_features["images/encoded"].values[i]) + # images = parsed_features["images/encoded"] + # images = tf.map_fn(lambda i: tf.image.decode_jpeg(parsed_features["images/encoded"].values[i]),offsets) + # seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"], '') + # Parse the string into an array of pixels corresponding to the image + # images = tf.decode_raw(parsed_features["images/encoded"],tf.int32) + + # images = seq + print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2])) + images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new") + seqs["images"] = images + return seqs + filenames = self.filenames + print ("FILENAMES",filenames) + #TODO: + #temporal_filenames = self.temporal_filenames + shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) + if shuffle: + random.shuffle(filenames) + dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) # todo: what is buffer_size + print("files", self.filenames) + print("mode", self.mode) + dataset = dataset.filter(self.filter) + if shuffle: + dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs)) + else: + dataset = dataset.repeat(self.num_epochs) + + num_parallel_calls = None if shuffle else 1 + dataset = dataset.apply(tf.contrib.data.map_and_batch( + parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) + #dataset = dataset.map(parser) + # num_parallel_calls = None if shuffle else 1 # for reproducibility (e.g. 
sampled subclips from the test set) + # dataset = dataset.apply(tf.contrib.data.map_and_batch( + # _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) # Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs + dataset = dataset.prefetch(batch_size) # Bing: Take the data to buffer inorder to save the waiting time for GPU + + return dataset + + + + def make_batch(self, batch_size): + dataset = self.make_dataset_v2(batch_size) + iterator = dataset.make_one_shot_iterator() + return iterator.get_next() + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _bytes_list_feature(values): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) + +def _floats_feature(value): + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + +def save_tf_record(output_fname, sequences): + print('saving sequences to %s' % output_fname) + with tf.python_io.TFRecordWriter(output_fname) as writer: + for sequence in sequences: + num_frames = len(sequence) + height, width, channels = sequence[0].shape + encoded_sequence = np.array([list(image) for image in sequence]) + + features = tf.train.Features(feature={ + 'sequence_length': _int64_feature(num_frames), + 'height': _int64_feature(height), + 'width': _int64_feature(width), + 'channels': _int64_feature(channels), + 'images/encoded': _floats_feature(encoded_sequence.flatten()), + }) + example = tf.train.Example(features=features) + writer.write(example.SerializeToString()) + +class Norm_data: + """ + Class for normalizing data. The statistical data for normalization (minimum, maximum, average, standard deviation etc.) is expected to be available from a statistics-dictionary + created with the calc_data_stat-class (see 'process_netCDF_v2.py'. + """ + + ### set known norms and the requested statistics (to be retrieved from statistics.json) here ### + known_norms = {} + known_norms["minmax"] = ["min","max"] + known_norms["znorm"] = ["avg","sigma"] + + def __init__(self,varnames): + """Initialize the instance by setting the variable names to be handled and the status (for sanity checks only) as attributes.""" + varnames_uni, _, nvars = get_unique_vars(varnames) + + self.varnames = varnames_uni + self.status_ok= False + + def check_and_set_norm(self,stat_dict,norm): + """ + Checks if the statistics-dictionary provides the required data for selected normalization method and expands the instance's attributes accordingly. + Example: minmax-normalization requires the minimum and maximum value of a variable named var1. + If the requested values are provided by the statistics-dictionary, the instance gets the attributes 'var1min' and 'var1max',respectively. + """ + + # some sanity checks + if not norm in self.known_norms.keys(): # valid normalization requested? + print("Please select one of the following known normalizations: ") + for norm_avail in self.known_norms.keys(): + print(norm_avail) + raise ValueError("Passed normalization '"+norm+"' is unknown.") + + if not all(items in stat_dict for items in self.varnames): # all variables found in dictionary? 
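# Note on the Norm_data class: 'known_norms' acts as a small registry that maps
# each normalization name to the statistics it requires ("minmax" -> min/max,
# "znorm" -> avg/sigma). One observation on the methods further below: norm_var
# and denorm_var divide respectively multiply by sigma**2, while a conventional
# z-score uses sigma itself; that is only consistent if 'sigma' already stores
# the variance, otherwise it may be worth double-checking.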
+ print("Keys in stat_dict:") + print(stat_dict.keys()) + + print("Requested variables:") + print(self.varnames) + raise ValueError("Could not find all requested variables in statistics dictionary.") + + # create all attributes for the instance + for varname in self.varnames: + for stat_name in self.known_norms[norm]: + #setattr(self,varname+stat_name,stat_dict[varname][0][stat_name]) + setattr(self,varname+stat_name,Calc_data_stat.get_stat_vars(stat_dict,stat_name,varname)) + + self.status_ok = True # set status for normalization -> ready + + def norm_var(self,data,varname,norm): + """ + Performs given normalization on input data (given that the instance is already set up) + """ + + # some sanity checks + if not self.status_ok: raise ValueError("Norm_data-instance needs to be initialized and checked first.") # status ready? + + if not norm in self.known_norms.keys(): # valid normalization requested? + print("Please select one of the following known normalizations: ") + for norm_avail in self.known_norms.keys(): + print(norm_avail) + raise ValueError("Passed normalization '"+norm+"' is unknown.") + + # do the normalization and return + if norm == "minmax": + return((data[...] - getattr(self,varname+"min"))/(getattr(self,varname+"max") - getattr(self,varname+"min"))) + elif norm == "znorm": + return((data[...] - getattr(self,varname+"avg"))/getattr(self,varname+"sigma")**2) + + def denorm_var(self,data,varname,norm): + """ + Performs given denormalization on input data (given that the instance is already set up), i.e. inverse method to norm_var + """ + + # some sanity checks + if not self.status_ok: raise ValueError("Norm_data-instance needs to be initialized and checked first.") # status ready? + + if not norm in self.known_norms.keys(): # valid normalization requested? + print("Please select one of the following known normalizations: ") + for norm_avail in self.known_norms.keys(): + print(norm_avail) + raise ValueError("Passed normalization '"+norm+"' is unknown.") + + # do the denormalization and return + if norm == "minmax": + return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"min")) + elif norm == "znorm": + return(data[...] 
* getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg")) + + +def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,seq_length=20,sequences_per_file=128,height=64,width=64,channels=3,**kwargs):#Bing: original 128 + # ML 2020/04/08: + # Include vars_in for more flexible data handling (normalization and reshaping) + # and optional keyword argument for kind of normalization + + if 'norm' in kwargs: + norm = kwargs.get("norm") + else: + norm = "minmax" + print("Make use of default minmax-normalization...") + + output_dir = os.path.join(output_dir,partition_name) + os.makedirs(output_dir,exist_ok=True) + + norm_cls = Norm_data(vars_in) # init normalization-instance + nvars = len(vars_in) + + # open statistics file and feed it to norm-instance + with open(os.path.join(input_dir,"statistics.json")) as js_file: + norm_cls.check_and_set_norm(json.load(js_file),norm) + + sequences = [] + sequence_iter = 0 + sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') + X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl")) + X_possible_starts = [i for i in range(len(X_train) - seq_length)] + for X_start in X_possible_starts: + print("Interation", sequence_iter) + X_end = X_start + seq_length + #seq = X_train[X_start:X_end, :, :,:] + seq = X_train[X_start:X_end,:,:] + #print("*****len of seq ***.{}".format(len(seq))) + #seq = list(np.array(seq).reshape((len(seq), 64, 64, 3))) + seq = list(np.array(seq).reshape((seq_length, height, width, nvars))) + if not sequences: + last_start_sequence_iter = sequence_iter + print("reading sequences starting at sequence %d" % sequence_iter) + sequences.append(seq) + sequence_iter += 1 + sequence_lengths_file.write("%d\n" % len(seq)) + + if len(sequences) == sequences_per_file: + ###Normalization should adpot the selected variables, here we used duplicated channel temperature variables + sequences = np.array(sequences) + ### normalization + for i in range(nvars): + sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm) + + output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) + output_fname = os.path.join(output_dir, output_fname) + save_tf_record(output_fname, list(sequences)) + sequences = [] + sequence_lengths_file.close() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking") + parser.add_argument("output_dir", type=str) + # ML 2020/04/08 S + # Add vars for ensuring proper normalization and reshaping of sequences + parser.add_argument("-vars","--variables",dest="variables", nargs='+', type=str, help="Names of input variables.") + parser.add_argument("-height",type=int,default=64) + parser.add_argument("-width",type = int,default=64) + parser.add_argument("-seq_length",type=int,default=20) + args = parser.parse_args() + current_path = os.getcwd() + #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits" + #output_dir = "/Users/gongbing/PycharmProjects/video_prediction/data/era5" + partition_names = ['train','val', 'test'] #64,64,3 val has issue# + + for partition_name in partition_names: + read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,vars_in=args.variables,partition_name=partition_name, seq_length=args.seq_length,height=args.height,width=args.width,sequences_per_file=2) #Bing: Todo need check the N_seq + 
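# Illustrative, standalone sketch (toy sizes): save_tf_record above stores a whole
# (seq_length, height, width, channels) sequence as one flattened float list under
# 'images/encoded', and the parser in make_dataset_v2 simply reshapes that flat
# vector back. The NumPy round trip looks like this:
import numpy as np

seq_length, height, width, channels = 20, 64, 64, 3
sequence = np.random.rand(seq_length, height, width, channels).astype(np.float32)

flat = sequence.flatten()                                     # what goes into the FloatList
restored = flat.reshape(seq_length, height, width, channels)  # what the parser recovers

assert np.array_equal(sequence, restored)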
#ead_frames_and_save_tf_records(output_dir = output_dir, input_dir = input_dir,partition_name = partition_name, N_seq=20) #Bing: TODO: first try for N_seq is 10, but it met loading data issue. let's try 5 + +if __name__ == '__main__': + main() + diff --git a/video_prediction/datasets/era5_dataset_v2_anomaly.py b/video_prediction/datasets/era5_dataset_v2_anomaly.py new file mode 100644 index 0000000000000000000000000000000000000000..0daa13b332a50f848b1d41fb8bd5c75079c88b7e --- /dev/null +++ b/video_prediction/datasets/era5_dataset_v2_anomaly.py @@ -0,0 +1,274 @@ +import argparse +import glob +import itertools +import os +import pickle +import random +import re +import netCDF4 +import hickle as hkl +import numpy as np +import tensorflow as tf +import pandas as pd +from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset +from collections import OrderedDict +from tensorflow.contrib.training import HParams + +units = "hours since 2000-01-01 00:00:00" +calendar = "gregorian" + +class ERA5Dataset_v2_anomaly(VarLenFeatureVideoDataset): + def __init__(self, *args, **kwargs): + super(ERA5Dataset_v2_anomaly, self).__init__(*args, **kwargs) + from google.protobuf.json_format import MessageToDict + example = next(tf.python_io.tf_record_iterator(self.filenames[0])) + dict_message = MessageToDict(tf.train.Example.FromString(example)) + feature = dict_message['features']['feature'] + image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels']) + self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape + + def get_default_hparams_dict(self): + default_hparams = super(ERA5Dataset_v2_anomaly, self).get_default_hparams_dict() + hparams = dict( + context_frames=10, + sequence_length=20, + long_sequence_length=40, + force_time_shift=True, + shuffle_on_val=True, + use_state=False, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + @property + def jpeg_encoding(self): + return False + + + def num_examples_per_epoch(self): + with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file: + sequence_lengths = sequence_lengths_file.readlines() + sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] + return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) + + + def filter(self, serialized_example): + return tf.convert_to_tensor(True) + + + + def make_dataset_v2(self, batch_size): + def parser(serialized_example): + seqs = OrderedDict() + keys_to_features = { + # 'width': tf.FixedLenFeature([], tf.int64), + # 'height': tf.FixedLenFeature([], tf.int64), + 'sequence_length': tf.FixedLenFeature([], tf.int64), + # 'channels': tf.FixedLenFeature([],tf.int64), + # 'images/encoded': tf.FixedLenFeature([], tf.string) + 'images/encoded': tf.VarLenFeature(tf.float32) + } + # for i in range(20): + # keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string) + parsed_features = tf.parse_single_example(serialized_example, keys_to_features) + seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) + images = [] + # for i in range(20): + # images.append(parsed_features["images/encoded"].values[i]) + # images = parsed_features["images/encoded"] + # images = tf.map_fn(lambda i: tf.image.decode_jpeg(parsed_features["images/encoded"].values[i]),offsets) + # seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"], '') + # Parse the string into an array of pixels corresponding to the image + # images = 
tf.decode_raw(parsed_features["images/encoded"],tf.int32) + + # images = seq + images = tf.reshape(seq, [20, 64, 64, 1], name = "reshape_new") + seqs["images"] = images + return seqs + filenames = self.filenames + filenames_mean = self.filenames_mean + shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) + if shuffle: + random.shuffle(filenames) + dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8 * 1024 * 1024) # todo: what is buffer_size + dataset = dataset.filter(self.filter) + #Bing: for Anomaly + dataset_mean = tf.data.TFRecordDataset(filenames_mean, buffer_size = 8 * 1024 * 1024) + dataset_mean = dataset_mean.filter(self.filter) + if shuffle: + dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size = 1024, count = self.num_epochs)) + dataset_mean = dataset_mean.apply(tf.contrib.data.shuffle_and_repeat(buffer_size = 1024, count = self.num_epochs)) + else: + dataset = dataset.repeat(self.num_epochs) + dataset_mean = dataset_mean.repeat(self.num_epochs) + + num_parallel_calls = None if shuffle else 1 + dataset = dataset.apply(tf.contrib.data.map_and_batch( + parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) + dataset_mean = dataset_mean.apply(tf.contrib.data.map_and_batch( + parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) + #dataset = dataset.map(parser) + # num_parallel_calls = None if shuffle else 1 # for reproducibility (e.g. sampled subclips from the test set) + # dataset = dataset.apply(tf.contrib.data.map_and_batch( + # _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) # Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs + dataset = dataset.prefetch(batch_size) # Bing: Take the data to buffer inorder to save the waiting time for GPU + dataset_mean = dataset_mean.prefetch(batch_size) + return dataset, dataset_mean + + def make_batch_v2(self, batch_size): + dataset, dataset_mean = self.make_dataset_v2(batch_size) + iterator = dataset.make_one_shot_iterator() + interator2 = dataset_mean.make_one_shot_iterator() + return iterator.get_next(), interator2.get_next() + + + def make_data_mean(self,batch_size): + pass + + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _bytes_list_feature(values): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) + +def _floats_feature(value): + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + + +def save_tf_record(output_fname, sequences): + print('saving sequences to %s' % output_fname) + with tf.python_io.TFRecordWriter(output_fname) as writer: + for sequence in sequences: + num_frames = len(sequence) + height, width, channels = sequence[0].shape + encoded_sequence = np.array([list(image) for image in sequence]) + + features = tf.train.Features(feature={ + 'sequence_length': _int64_feature(num_frames), + 'height': _int64_feature(height), + 'width': _int64_feature(width), + 'channels': _int64_feature(channels), + 'images/encoded': _floats_feature(encoded_sequence.flatten()), + }) + example = tf.train.Example(features=features) + writer.write(example.SerializeToString()) + + +def extract_anomaly_one_pixel(X, X_timestamps,pixel): + print("Processing Pixel {}, {}".format(pixel[0],pixel[1])) + dates = 
[x.date() for x in X_timestamps] + df = pd.DataFrame(data = X[:, pixel[0], pixel[1]], index = dates) + df_mean = df.groupby(df.index).mean() + df2 = pd.merge(df, df_mean, left_index = True, right_index = True) + df2.columns = ["Real","Daily_mean"] + df2["Anomaly"] = df2["Real"] - df2["Daily_mean"] + daily_mean = df2["Daily_mean"].values + anomaly = df2["Anomaly"].values + return daily_mean, anomaly + +def extract_anomaly_all_pixels(X, X_timestamps): + #daily_mean, anomaly = extract_anomaly_one_pixel(X, X_timestamps, pixel = [0, 0]) + daily_mean_pixels = np.zeros((X.shape[0], X.shape[1], X.shape[2])) + anomaly_pixels = np.zeros((X.shape[0], X.shape[1], X.shape[2])) + #daily_mean_all_pixels = [extract_anomaly_one_pixel(X, X_timestamps, pixel = [i, j])[0] for i in range(X.shape[1]) for j in range(X.shape[2])] + #anomaly_all_pixels = [extract_anomaly_one_pixel(X, X_timestamps, pixel = [i, j])[1] for i in range(X.shape[1]) for j in range(X.shape[2])] + for i in range(X.shape[1]): + for j in range(X.shape[2]): + daily_mean, anomaly = extract_anomaly_one_pixel(X, X_timestamps, pixel = [i, j]) + daily_mean_pixels[:,i,j] = daily_mean + anomaly_pixels[:,i,j] = anomaly + return daily_mean_pixels, anomaly_pixels + + +def read_frames_and_save_tf_records(output_dir, input_dir, partition_name, N_seq, sequences_per_file=128):#Bing: original 128 + output_orig_dir = os.path.join(output_dir,partition_name + "_orig") + output_time_dir = os.path.join(output_dir,partition_name + "_time") + output_mean_dir = os.path.join(output_dir,partition_name + "_mean") + output_anomaly_dir = os.path.join(output_dir, partition_name ) + + + if not os.path.exists(output_orig_dir): os.mkdir(output_orig_dir) + if not os.path.exists(output_time_dir): os.mkdir(output_time_dir) + if not os.path.exists(output_mean_dir): os.mkdir(output_mean_dir) + if not os.path.exists(output_anomaly_dir): os.mkdir(output_anomaly_dir) + sequences = [] + sequences_time = [] + sequences_mean = [] + sequences_anomaly = [] + + sequence_iter = 0 + sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') + X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl")) + X_time = hkl.load(os.path.join(input_dir, "Time_time_" + partition_name + ".hkl")) + print ("X shape", X_train.shape) + X_timestamps = [netCDF4.num2date(x, units = units, calendar = calendar) for x in X_time] + + print("X_time example", X_time[:10]) + print("X_time after to date", X_timestamps[:10]) + daily_mean_all_pixels, anomaly_all_pixels = extract_anomaly_all_pixels(X_train, X_timestamps) + + X_possible_starts = [i for i in range(len(X_train) - N_seq)] + for X_start in X_possible_starts: + print("Interation", sequence_iter) + X_end = X_start + N_seq + #seq = X_train[X_start:X_end, :, :,:] + seq = X_train[X_start:X_end,:,:] + seq_time = X_time[X_start:X_end] + seq_mean = daily_mean_all_pixels[X_start:X_end,:,:] + seq_anomaly = anomaly_all_pixels[X_start:X_end,:,:] + #print("*****len of seq ***.{}".format(len(seq))) + seq = list(np.array(seq).reshape((len(seq), 64, 64, 1))) + seq_time = list(np.array(seq_time)) + seq_mean = list(np.array(seq_mean).reshape((len(seq_mean), 64, 64, 1))) + seq_anomaly = list(np.array(seq_anomaly).reshape((len(seq_anomaly), 64, 64, 1))) + if not sequences: + last_start_sequence_iter = sequence_iter + print("reading sequences starting at sequence %d" % sequence_iter) + sequences.append(seq) + sequences_time.append(seq_time) + sequences_mean.append(seq_mean) + sequences_anomaly.append(seq_anomaly) + sequence_iter += 1 + 
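# Illustrative, standalone sketch (invented values): extract_anomaly_one_pixel
# above boils down to grouping a pixel's time series by calendar date, taking the
# per-day mean, and keeping the residual as the anomaly. The same logic with
# pandas' groupby/transform instead of the merge used above:
import pandas as pd

dates = pd.to_datetime(["2017-01-01", "2017-01-01", "2017-01-02", "2017-01-02"]).date
df = pd.DataFrame({"Real": [1.0, 3.0, 10.0, 14.0]}, index=dates)
df["Daily_mean"] = df.groupby(df.index)["Real"].transform("mean")
df["Anomaly"] = df["Real"] - df["Daily_mean"]
print(df["Anomaly"].tolist())  # [-1.0, 1.0, -2.0, 2.0]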
sequence_lengths_file.write("%d\n" % len(seq)) + + if len(sequences) == sequences_per_file: + output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) + output_orig_fname = os.path.join(output_orig_dir, output_fname) + output_time_fname = os.path.join(output_time_dir,'sequence_{0}_to_{1}.hkl'.format(last_start_sequence_iter, sequence_iter - 1)) + output_mean_fname = os.path.join(output_mean_dir, output_fname) + output_anomaly_fname = os.path.join(output_anomaly_dir, output_fname) + + save_tf_record(output_orig_fname, sequences) + hkl.dump(sequences_time, output_time_fname) + #save_tf_record(output_time_fname,sequences_time) + save_tf_record(output_mean_fname, sequences_mean) + save_tf_record(output_anomaly_fname, sequences_anomaly) + sequences[:] = [] + sequences_time[:] = [] + sequences_mean[:] = [] + sequences_anomaly[:] = [] + sequence_lengths_file.close() + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_dir", type=str, help="directory containing the preprocessed data files X_<partition>.hkl and Time_time_<partition>.hkl") + parser.add_argument("output_dir", type=str) + # parser.add_argument("image_size_h", type=int) + # parser.add_argument("image_size_v", type = int) + args = parser.parse_args() + current_path = os.getcwd() + #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits" + #output_dir = "/Users/gongbing/PycharmProjects/video_prediction/data/era5" + partition_names = ['train', 'val', 'test'] + for partition_name in partition_names: + read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,partition_name=partition_name, N_seq=20) #Bing: TODO: need to check N_seq + #read_frames_and_save_tf_records(output_dir = output_dir, input_dir = input_dir,partition_name = partition_name, N_seq=20) #Bing: TODO: first try for N_seq was 10, but it ran into a data loading issue. 
let's try 5 + +if __name__ == '__main__': + main() + diff --git a/video_prediction/datasets/google_robot_dataset.py b/video_prediction/datasets/google_robot_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fd2e975cf5bd83e7c342d7effbead68c1063498a --- /dev/null +++ b/video_prediction/datasets/google_robot_dataset.py @@ -0,0 +1,40 @@ +import itertools +import os + +from .base_dataset import VideoDataset + + +class GoogleRobotVideoDataset(VideoDataset): + """ + https://sites.google.com/site/brainrobotdata/home/push-dataset + """ + def __init__(self, *args, **kwargs): + super(GoogleRobotVideoDataset, self).__init__(*args, **kwargs) + self.state_like_names_and_shapes['images'] = 'move/%d/image/encoded', (512, 640, 3) + if self.hparams.use_state: + self.state_like_names_and_shapes['states'] = 'move/%d/endeffector/vec_pitch_yaw', (5,) + self.action_like_names_and_shapes['actions'] = 'move/%d/commanded_pose/vec_pitch_yaw', (5,) + self._check_or_infer_shapes() + + def get_default_hparams_dict(self): + default_hparams = super(GoogleRobotVideoDataset, self).get_default_hparams_dict() + hparams = dict( + context_frames=2, + sequence_length=15, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + def num_examples_per_epoch(self): + if os.path.basename(self.input_dir) == 'push_train': + count = 51615 + elif os.path.basename(self.input_dir) == 'push_testseen': + count = 1038 + elif os.path.basename(self.input_dir) == 'push_testnovel': + count = 995 + else: + raise NotImplementedError + return count + + @property + def jpeg_encoding(self): + return True diff --git a/video_prediction/datasets/kth_dataset.py b/video_prediction/datasets/kth_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..40fb6bf57b8219fc7e75c3759df9e6b38fffeb30 --- /dev/null +++ b/video_prediction/datasets/kth_dataset.py @@ -0,0 +1,159 @@ +import argparse +import glob +import itertools +import os +import pickle +import random +import re +import tensorflow as tf +import numpy as np +import skimage.io + +from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset + + +class KTHVideoDataset(VarLenFeatureVideoDataset): + def __init__(self, *args, **kwargs): + super(KTHVideoDataset, self).__init__(*args, **kwargs) + from google.protobuf.json_format import MessageToDict + example = next(tf.python_io.tf_record_iterator(self.filenames[0])) + dict_message = MessageToDict(tf.train.Example.FromString(example)) + feature = dict_message['features']['feature'] + image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels']) + + self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape + + def get_default_hparams_dict(self): + default_hparams = super(KTHVideoDataset, self).get_default_hparams_dict() + hparams = dict( + context_frames=10, + sequence_length=20, + long_sequence_length=40, + force_time_shift=True, + shuffle_on_val=True, + use_state=False, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + @property + def jpeg_encoding(self): + return False + + def num_examples_per_epoch(self): + with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file: + sequence_lengths = sequence_lengths_file.readlines() + sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] + return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) + + +def _bytes_feature(value): + return 
tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _bytes_list_feature(values): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) + + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def partition_data(input_dir): + # List files and corresponding person IDs + fnames = glob.glob(os.path.join(input_dir, '*/*')) + fnames = [fname for fname in fnames if os.path.isdir(fname)] + print("frames",fnames[0]) + + persons = [re.match('person(\d+)_\w+_\w+', os.path.split(fname)[1]).group(1) for fname in fnames] + persons = np.array([int(person) for person in persons]) + + train_mask = persons <= 16 + + train_fnames = [fnames[i] for i in np.where(train_mask)[0]] + test_fnames = [fnames[i] for i in np.where(~train_mask)[0]] + + random.shuffle(train_fnames) + + pivot = int(0.95 * len(train_fnames)) + train_fnames, val_fnames = train_fnames[:pivot], train_fnames[pivot:] + return train_fnames, val_fnames, test_fnames + + +def save_tf_record(output_fname, sequences): + print('saving sequences to %s' % output_fname) + with tf.python_io.TFRecordWriter(output_fname) as writer: + for sequence in sequences: + num_frames = len(sequence) + height, width, channels = sequence[0].shape + encoded_sequence = [image.tostring() for image in sequence] + features = tf.train.Features(feature={ + 'sequence_length': _int64_feature(num_frames), + 'height': _int64_feature(height), + 'width': _int64_feature(width), + 'channels': _int64_feature(channels), + 'images/encoded': _bytes_list_feature(encoded_sequence), + }) + example = tf.train.Example(features=features) + writer.write(example.SerializeToString()) + + +def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequences_per_file=128): + partition_name = os.path.split(output_dir)[1] #Get the folder name train, val or test + sequences = [] + sequence_iter = 0 + sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') + for video_iter, video_dir in enumerate(video_dirs): #Interate group (e.g. walking) each person + meta_partition_name = partition_name if partition_name == 'test' else 'train' + meta_fname = os.path.join(os.path.split(video_dir)[0], '%s_meta%dx%d.pkl' % + (meta_partition_name, image_size, image_size)) + with open(meta_fname, "rb") as f: + data = pickle.load(f) # The data has 62 items, each item is a dict, with three keys. 
"vid","n", and "files", Each file has 4 channels, each channel has n sequence images with 64*64 png + + vid = os.path.split(video_dir)[1] + (d,) = [d for d in data if d['vid'] == vid] + for frame_fnames_iter, frame_fnames in enumerate(d['files']): + frame_fnames = [os.path.join(video_dir, frame_fname) for frame_fname in frame_fnames] + frames = skimage.io.imread_collection(frame_fnames) + # they are grayscale images, so just keep one of the channels + frames = [frame[..., 0:1] for frame in frames] + + if not sequences: #The length of the sequence in sequences could be different + last_start_sequence_iter = sequence_iter + print("reading sequences starting at sequence %d" % sequence_iter) + + sequences.append(frames) + sequence_iter += 1 + sequence_lengths_file.write("%d\n" % len(frames)) + + if (len(sequences) == sequences_per_file or + (video_iter == (len(video_dirs) - 1) and frame_fnames_iter == (len(d['files']) - 1))): + output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) + output_fname = os.path.join(output_dir, output_fname) + save_tf_record(output_fname, sequences) + sequences[:] = [] + sequence_lengths_file.close() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, help="directory containing the processed directories " + "boxing, handclapping, handwaving, " + "jogging, running, walking") + parser.add_argument("--output_dir", type=str) + parser.add_argument("--image_size", type=int) + args = parser.parse_args() + + partition_names = ['train', 'val', 'test'] + print("input dir", args.input_dir) + partition_fnames = partition_data(args.input_dir) + print("partiotion_fnames[0]", partition_fnames[0]) + + for partition_name, partition_fnames in zip(partition_names, partition_fnames): + partition_dir = os.path.join(args.output_dir, partition_name) + if not os.path.exists(partition_dir): + os.makedirs(partition_dir) + read_frames_and_save_tf_records(partition_dir, partition_fnames, args.image_size) + + +if __name__ == '__main__': + main() diff --git a/video_prediction/datasets/softmotion_dataset.py b/video_prediction/datasets/softmotion_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..106869ca46d857d56fdd8d0f0d20ae5e2b69c3b5 --- /dev/null +++ b/video_prediction/datasets/softmotion_dataset.py @@ -0,0 +1,82 @@ +import itertools +import os +import re + +import tensorflow as tf + +from video_prediction.utils import tf_utils +from .base_dataset import VideoDataset + + +class SoftmotionVideoDataset(VideoDataset): + """ + https://sites.google.com/view/sna-visual-mpc + """ + def __init__(self, *args, **kwargs): + super(SoftmotionVideoDataset, self).__init__(*args, **kwargs) + # infer name of image feature and check if object_pos feature is present + from google.protobuf.json_format import MessageToDict + example = next(tf.python_io.tf_record_iterator(self.filenames[0])) + dict_message = MessageToDict(tf.train.Example.FromString(example)) + feature = dict_message['features']['feature'] + image_names = set() + for name in feature.keys(): + m = re.search('\d+/(\w+)/encoded', name) + if m: + image_names.add(m.group(1)) + # look for image_aux1 and image_view0 in that order of priority + image_name = None + for name in ['image_aux1', 'image_view0']: + if name in image_names: + image_name = name + break + if not image_name: + if len(image_names) == 1: + image_name = image_names.pop() + else: + raise ValueError('The examples have images under more than one name.') + 
self.state_like_names_and_shapes['images'] = '%%d/%s/encoded' % image_name, None + if self.hparams.use_state: + self.state_like_names_and_shapes['states'] = '%d/endeffector_pos', (3,) + self.action_like_names_and_shapes['actions'] = '%d/action', (4,) + if any([re.search('\d+/object_pos', name) for name in feature.keys()]): + self.state_like_names_and_shapes['object_pos'] = '%d/object_pos', None # shape is (2 * num_designated_pixels) + self._check_or_infer_shapes() + + def get_default_hparams_dict(self): + default_hparams = super(SoftmotionVideoDataset, self).get_default_hparams_dict() + hparams = dict( + context_frames=2, + sequence_length=12, + long_sequence_length=30, + time_shift=2, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + @property + def jpeg_encoding(self): + return False + + def parser(self, serialized_example): + state_like_seqs, action_like_seqs = super(SoftmotionVideoDataset, self).parser(serialized_example) + if 'object_pos' in state_like_seqs: + object_pos = state_like_seqs['object_pos'] + height, width, _ = self.state_like_names_and_shapes['images'][1] + object_pos = tf.reshape(object_pos, [object_pos.shape[0].value, -1, 2]) + pix_distribs = tf.stack([tf_utils.pixel_distribution(object_pos_, height, width) + for object_pos_ in tf.unstack(object_pos, axis=1)], axis=-1) + state_like_seqs['pix_distribs'] = pix_distribs + return state_like_seqs, action_like_seqs + + def num_examples_per_epoch(self): + # extract information from filename to count the number of trajectories in the dataset + count = 0 + for filename in self.filenames: + match = re.search('traj_(\d+)_to_(\d+).tfrecords', os.path.basename(filename)) + start_traj_iter = int(match.group(1)) + end_traj_iter = int(match.group(2)) + count += end_traj_iter - start_traj_iter + 1 + + # alternatively, the dataset size can be determined like this, but it's very slow + # count = sum(sum(1 for _ in tf.python_io.tf_record_iterator(filename)) for filename in filenames) + return count diff --git a/video_prediction/datasets/sv2p_dataset.py b/video_prediction/datasets/sv2p_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed072297a7c853b2b389d7c1816885743bad0cf --- /dev/null +++ b/video_prediction/datasets/sv2p_dataset.py @@ -0,0 +1,65 @@ +import itertools +import os + +from .base_dataset import VideoDataset + + +class SV2PVideoDataset(VideoDataset): + def __init__(self, *args, **kwargs): + super(SV2PVideoDataset, self).__init__(*args, **kwargs) + self.dataset_name = os.path.basename(os.path.split(self.input_dir)[0]) + self.state_like_names_and_shapes['images'] = 'image_%d', (64, 64, 3) + if self.dataset_name == 'shape': + if self.hparams.use_state: + self.state_like_names_and_shapes['states'] = 'state_%d', (2,) + self.action_like_names_and_shapes['actions'] = 'action_%d', (2,) + elif self.dataset_name == 'humans': + if self.hparams.use_state: + raise ValueError('SV2PVideoDataset does not have states, use_state should be False') + else: + raise NotImplementedError + self._check_or_infer_shapes() + + def get_default_hparams_dict(self): + default_hparams = super(SV2PVideoDataset, self).get_default_hparams_dict() + if self.dataset_name == 'shape': + hparams = dict( + context_frames=1, + sequence_length=6, + time_shift=0, + use_state=False, + ) + elif self.dataset_name == 'humans': + hparams = dict( + context_frames=10, + sequence_length=20, + use_state=False, + ) + else: + raise NotImplementedError + return dict(itertools.chain(default_hparams.items(), hparams.items())) 
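# Illustrative, standalone sketch (toy values): the dict(itertools.chain(...))
# idiom used in get_default_hparams_dict merges the parent defaults with the
# subclass overrides, the later dict winning on duplicate keys:
import itertools

defaults = {'context_frames': 2, 'sequence_length': 15, 'use_state': False}
overrides = {'context_frames': 1, 'sequence_length': 6}
merged = dict(itertools.chain(defaults.items(), overrides.items()))
print(merged)  # {'context_frames': 1, 'sequence_length': 6, 'use_state': False}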
+ + def num_examples_per_epoch(self): + if self.dataset_name == 'shape': + if os.path.basename(self.input_dir) == 'train': + count = 43415 + elif os.path.basename(self.input_dir) == 'val': + count = 2898 + else: # shape dataset doesn't have a test set + raise NotImplementedError + elif self.dataset_name == 'humans': + if os.path.basename(self.input_dir) == 'train': + count = 23910 + elif os.path.basename(self.input_dir) == 'val': + count = 10472 + elif os.path.basename(self.input_dir) == 'test': + count = 7722 + else: + raise NotImplementedError + else: + raise NotImplementedError + return count + + @property + def jpeg_encoding(self): + return True diff --git a/video_prediction/datasets/ucf101_dataset.py b/video_prediction/datasets/ucf101_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cba078ab4f66acb3fd80546a016fb9dd94b1d551 --- /dev/null +++ b/video_prediction/datasets/ucf101_dataset.py @@ -0,0 +1,212 @@ +import argparse +import glob +import itertools +import os +import random +import re +from multiprocessing import Pool +import cv2 +import tensorflow as tf + +from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset + + +class UCF101VideoDataset(VarLenFeatureVideoDataset): + def __init__(self, *args, **kwargs): + super(UCF101VideoDataset, self).__init__(*args, **kwargs) + self.state_like_names_and_shapes['images'] = 'images/encoded', (240, 320, 3) + + def get_default_hparams_dict(self): + default_hparams = super(UCF101VideoDataset, self).get_default_hparams_dict() + hparams = dict( + context_frames=4, + sequence_length=8, + random_crop_size=0, + use_state=False, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + @property + def jpeg_encoding(self): + return True + + def decode_and_preprocess_images(self, image_buffers, image_shape): + if self.hparams.crop_size: + raise NotImplementedError + if self.hparams.scale_size: + raise NotImplementedError + image_buffers = tf.reshape(image_buffers, [-1]) + if not isinstance(image_buffers, (list, tuple)): + image_buffers = tf.unstack(image_buffers) + image_size = tf.image.extract_jpeg_shape(image_buffers[0])[:2] # should be the same as image_shape[:2] + if self.hparams.random_crop_size: + random_crop_size = [self.hparams.random_crop_size] * 2 + crop_y = tf.random_uniform([], minval=0, maxval=image_size[0] - random_crop_size[0], dtype=tf.int32) + crop_x = tf.random_uniform([], minval=0, maxval=image_size[1] - random_crop_size[1], dtype=tf.int32) + crop_window = [crop_y, crop_x] + random_crop_size + images = [tf.image.decode_and_crop_jpeg(image_buffer, crop_window) for image_buffer in image_buffers] + images = tf.image.convert_image_dtype(images, dtype=tf.float32) + images.set_shape([None] + random_crop_size + [image_shape[-1]]) + else: + images = [tf.image.decode_jpeg(image_buffer) for image_buffer in image_buffers] + images = tf.image.convert_image_dtype(images, dtype=tf.float32) + images.set_shape([None] + list(image_shape)) + # TODO: only random crop for training + return images + + def num_examples_per_epoch(self): + # extract information from filename to count the number of trajectories in the dataset + count = 0 + for filename in self.filenames: + match = re.search('sequence_(\d+)_to_(\d+).tfrecords', os.path.basename(filename)) + start_traj_iter = int(match.group(1)) + end_traj_iter = int(match.group(2)) + count += end_traj_iter - start_traj_iter + 1 + + # alternatively, the dataset size can be determined like this, but it's very slow + # count = sum(sum(1 for 
_ in tf.python_io.tf_record_iterator(filename)) for filename in filenames) + return count + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _bytes_list_feature(values): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) + + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def partition_data(input_dir, train_test_list_dir): + train_list_fnames = glob.glob(os.path.join(train_test_list_dir, 'trainlist*.txt')) + test_list_fnames = glob.glob(os.path.join(train_test_list_dir, 'testlist*.txt')) + test_list_fnames_mathieu = [os.path.join(train_test_list_dir, 'testlist01.txt')] + + def read_fnames(list_fnames): + fnames = [] + for list_fname in sorted(list_fnames): + with open(list_fname, 'r') as f: + while True: + fname = f.readline() + if not fname: + break + fnames.append(fname.split('\n')[0].split(' ')[0]) + return fnames + + train_fnames = read_fnames(train_list_fnames) + test_fnames = read_fnames(test_list_fnames) + test_fnames_mathieu = read_fnames(test_list_fnames_mathieu) + + train_fnames = [os.path.join(input_dir, train_fname) for train_fname in train_fnames] + test_fnames = [os.path.join(input_dir, test_fname) for test_fname in test_fnames] + test_fnames_mathieu = [os.path.join(input_dir, test_fname) for test_fname in test_fnames_mathieu] + # only use every 10 videos as in Mathieu et al. + test_fnames_mathieu = test_fnames_mathieu[::10] + + random.shuffle(train_fnames) + + pivot = int(0.95 * len(train_fnames)) + train_fnames, val_fnames = train_fnames[:pivot], train_fnames[pivot:] + return train_fnames, val_fnames, test_fnames, test_fnames_mathieu + + +def read_video(fname): + if not os.path.isfile(fname): + raise FileNotFoundError + vidcap = cv2.VideoCapture(fname) + frames, (success, image) = [], vidcap.read() + while success: + frames.append(image) + success, image = vidcap.read() + return frames + + +def save_tf_record(output_fname, sequences, preprocess_image): + print('saving sequences to %s' % output_fname) + with tf.python_io.TFRecordWriter(output_fname) as writer: + for sequence in sequences: + num_frames = len(sequence) + height, width, channels = sequence[0].shape + encoded_sequence = [preprocess_image(image) for image in sequence] + features = tf.train.Features(feature={ + 'sequence_length': _int64_feature(num_frames), + 'height': _int64_feature(height), + 'width': _int64_feature(width), + 'channels': _int64_feature(channels), + 'images/encoded': _bytes_list_feature(encoded_sequence), + }) + example = tf.train.Example(features=features) + writer.write(example.SerializeToString()) + + +def read_videos_and_save_tf_records(output_dir, fnames, start_sequence_iter=None, + end_sequence_iter=None, sequences_per_file=128): + print('started process with PID:', os.getpid()) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + if start_sequence_iter is None: + start_sequence_iter = 0 + if end_sequence_iter is None: + end_sequence_iter = len(fnames) + + def preprocess_image(image): + if image.shape != (240, 320, 3): + image = cv2.resize(image, (320, 240), interpolation=cv2.INTER_LINEAR) + return tf.compat.as_bytes(cv2.imencode(".jpg", image)[1].tobytes()) + + print('reading and saving sequences {0} to {1}'.format(start_sequence_iter, end_sequence_iter)) + + sequences = [] + for sequence_iter in range(start_sequence_iter, end_sequence_iter): + if not sequences: + last_start_sequence_iter = sequence_iter + print("reading sequences 
starting at sequence %d" % sequence_iter) + + sequences.append(read_video(fnames[sequence_iter])) + + if len(sequences) == sequences_per_file or sequence_iter == (end_sequence_iter - 1): + output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter) + output_fname = os.path.join(output_dir, output_fname) + save_tf_record(output_fname, sequences, preprocess_image) + sequences[:] = [] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_dir", type=str, help="directory containing the directories of " + "classes, each of which contains avi files.") + parser.add_argument("train_test_list_dir", type=str, help='directory containing trainlist*.txt' + 'and testlist*.txt files.') + parser.add_argument("output_dir", type=str) + parser.add_argument('--num_workers', type=int, default=1, help='number of parallel workers') + args = parser.parse_args() + + partition_names = ['train', 'val', 'test', 'test_mathieu'] + partition_fnames = partition_data(args.input_dir, args.train_test_list_dir) + + for partition_name, partition_fnames in zip(partition_names, partition_fnames): + partition_dir = os.path.join(args.output_dir, partition_name) + if not os.path.exists(partition_dir): + os.makedirs(partition_dir) + + if args.num_workers > 1: + num_seqs_per_worker = len(partition_fnames) // args.num_workers + start_seq_iters = [num_seqs_per_worker * i for i in range(args.num_workers)] + end_seq_iters = [num_seqs_per_worker * (i + 1) - 1 for i in range(args.num_workers)] + end_seq_iters[-1] = len(partition_fnames) + + p = Pool(args.num_workers) + p.starmap(read_videos_and_save_tf_records, zip([partition_dir] * args.num_workers, + [partition_fnames] * args.num_workers, + start_seq_iters, end_seq_iters)) + else: + read_videos_and_save_tf_records(partition_dir, partition_fnames) + + +if __name__ == '__main__': + main() diff --git a/video_prediction/flow_ops.py b/video_prediction/flow_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..caf9bad772621d6fbbf8a28895f16d03bed0561a --- /dev/null +++ b/video_prediction/flow_ops.py @@ -0,0 +1,79 @@ +import tensorflow as tf + + +def image_warp(im, flow): + """Performs a backward warp of an image using the predicted flow. + + Args: + im: Batch of images. [num_batch, height, width, channels] + flow: Batch of flow vectors. [num_batch, height, width, 2] + Returns: + warped: transformed image of the same shape as the input image. + + Implementation taken from here: https://github.com/simonmeister/UnFlow + """ + with tf.variable_scope('image_warp'): + + num_batch, height, width, channels = tf.unstack(tf.shape(im)) + max_x = tf.cast(width - 1, 'int32') + max_y = tf.cast(height - 1, 'int32') + zero = tf.zeros([], dtype='int32') + + # We have to flatten our tensors to vectorize the interpolation + im_flat = tf.reshape(im, [-1, channels]) + flow_flat = tf.reshape(flow, [-1, 2]) + + # Floor the flow, as the final indices are integers + # The fractional part is used to control the bilinear interpolation. 
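+        # Added illustrative sketch (not part of the original implementation):
+        # for a single flow vector (dx, dy) = (1.3, 2.7) the integer offsets
+        # are (1, 2) and the fractional parts are (xw, yw) = (0.3, 0.7), so
+        # the four corner weights computed below become
+        #   wa = (1 - 0.3) * (1 - 0.7) = 0.21   # top left
+        #   wb = (1 - 0.3) * 0.7       = 0.49   # bottom left
+        #   wc = 0.3 * (1 - 0.7)       = 0.09   # top right
+        #   wd = 0.3 * 0.7             = 0.21   # bottom right
+        # They sum to 1, so each warped pixel is a convex combination of its
+        # four nearest source pixels.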
+ flow_floor = tf.to_int32(tf.floor(flow_flat)) + bilinear_weights = flow_flat - tf.floor(flow_flat) + + # Construct base indices which are displaced with the flow + pos_x = tf.tile(tf.range(width), [height * num_batch]) + grid_y = tf.tile(tf.expand_dims(tf.range(height), 1), [1, width]) + pos_y = tf.tile(tf.reshape(grid_y, [-1]), [num_batch]) + + x = flow_floor[:, 0] + y = flow_floor[:, 1] + xw = bilinear_weights[:, 0] + yw = bilinear_weights[:, 1] + + # Compute interpolation weights for 4 adjacent pixels + # expand to num_batch * height * width x 1 for broadcasting in add_n below + wa = tf.expand_dims((1 - xw) * (1 - yw), 1) # top left pixel + wb = tf.expand_dims((1 - xw) * yw, 1) # bottom left pixel + wc = tf.expand_dims(xw * (1 - yw), 1) # top right pixel + wd = tf.expand_dims(xw * yw, 1) # bottom right pixel + + x0 = pos_x + x + x1 = x0 + 1 + y0 = pos_y + y + y1 = y0 + 1 + + x0 = tf.clip_by_value(x0, zero, max_x) + x1 = tf.clip_by_value(x1, zero, max_x) + y0 = tf.clip_by_value(y0, zero, max_y) + y1 = tf.clip_by_value(y1, zero, max_y) + + dim1 = width * height + batch_offsets = tf.range(num_batch) * dim1 + base_grid = tf.tile(tf.expand_dims(batch_offsets, 1), [1, dim1]) + base = tf.reshape(base_grid, [-1]) + + base_y0 = base + y0 * width + base_y1 = base + y1 * width + idx_a = base_y0 + x0 + idx_b = base_y1 + x0 + idx_c = base_y0 + x1 + idx_d = base_y1 + x1 + + Ia = tf.gather(im_flat, idx_a) + Ib = tf.gather(im_flat, idx_b) + Ic = tf.gather(im_flat, idx_c) + Id = tf.gather(im_flat, idx_d) + + warped_flat = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id]) + warped = tf.reshape(warped_flat, [num_batch, height, width, channels]) + warped.set_shape(im.shape) + + return warped diff --git a/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction/layers/BasicConvLSTMCell.py new file mode 100644 index 0000000000000000000000000000000000000000..321f6cc7e05320cf83e1173d8004429edf07ec24 --- /dev/null +++ b/video_prediction/layers/BasicConvLSTMCell.py @@ -0,0 +1,148 @@ + +import tensorflow as tf +from .layer_def import * + +class ConvRNNCell(object): + """Abstract object representing an Convolutional RNN cell. + """ + + def __call__(self, inputs, state, scope=None): + """Run this RNN cell on inputs, starting from the given state. + """ + raise NotImplementedError("Abstract method") + + @property + def state_size(self): + """size(s) of state(s) used by this cell. + """ + raise NotImplementedError("Abstract method") + + @property + def output_size(self): + """Integer or TensorShape: size of outputs produced by this cell.""" + raise NotImplementedError("Abstract method") + + def zero_state(self,input, dtype): + """Return zero-filled state tensor(s). + Args: + batch_size: int, float, or unit Tensor representing the batch size. + dtype: the data type to use for the state. + Returns: + tensor of shape '[batch_size x shape[0] x shape[1] x num_features] + filled with zeros + """ + + shape = self.shape + num_features = self.num_features + #x= tf.placeholder(tf.float32, shape=[input.shape[0], shape[0], shape[1], num_features * 2])#Bing: add this to + zeros = tf.zeros([tf.shape(input)[0], shape[0], shape[1], num_features * 2]) + #zeros = tf.zeros_like(x) + return zeros + + +class BasicConvLSTMCell(ConvRNNCell): + """Basic Conv LSTM recurrent network cell. The + """ + + def __init__(self, shape, filter_size, num_features, forget_bias=1.0, input_size=None, + state_is_tuple=False, activation=tf.nn.tanh): + """Initialize the basic Conv LSTM cell. 
+ Args: + shape: int tuple thats the height and width of the cell + filter_size: int tuple thats the height and width of the filter + num_features: int thats the depth of the cell + forget_bias: float, The bias added to forget gates (see above). + input_size: Deprecated and unused. + state_is_tuple: If True, accepted and returned states are 2-tuples of + the `c_state` and `m_state`. If False, they are concatenated + along the column axis. The latter behavior will soon be deprecated. + activation: Activation function of the inner states. + """ + # if not state_is_tuple: + # logging.warn("%s: Using a concatenated state is slower and will soon be " + # "deprecated. Use state_is_tuple=True.", self) + if input_size is not None: + logging.warn("%s: The input_size parameter is deprecated.", self) + self.shape = shape + self.filter_size = filter_size + self.num_features = num_features + self._forget_bias = forget_bias + self._state_is_tuple = state_is_tuple + self._activation = activation + + @property + def state_size(self): + return (LSTMStateTuple(self._num_units, self._num_units) + if self._state_is_tuple else 2 * self._num_units) + + @property + def output_size(self): + return self._num_units + + def __call__(self, inputs, state, scope=None,reuse=None): + """Long short-term memory cell (LSTM).""" + with tf.variable_scope(scope or type(self).__name__,reuse=reuse): # "BasicLSTMCell" + # Parameters of gates are concatenated into one multiply for efficiency. + if self._state_is_tuple: + c, h = state + else: + c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state) + concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat) + + new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * + self._activation(j)) + new_h = self._activation(new_c) * tf.nn.sigmoid(o) + + if self._state_is_tuple: + new_state = LSTMStateTuple(new_c, new_h) + else: + new_state = tf.concat(axis = 3, values = [new_c, new_h]) + return new_h, new_state + + +def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=None): + """convolution: + Args: + args: a 4D Tensor or a list of 4D, batch x n, Tensors. + filter_size: int tuple of filter height and width. + num_features: int, number of features. + bias_start: starting value to initialize the bias; 0 by default. + scope: VariableScope for the created subgraph; defaults to "Linear". + Returns: + A 4D Tensor with shape [batch h w num_features] + Raises: + ValueError: if some of the arguments has unspecified or wrong shape. + """ + + # Calculate the total size of arguments on dimension 1. + total_arg_size_depth = 0 + shapes = [a.get_shape().as_list() for a in args] + for shape in shapes: + if len(shape) != 4: + raise ValueError("Linear is expecting 4D arguments: %s" % str(shapes)) + if not shape[3]: + raise ValueError("Linear expects shape[4] of arguments: %s" % str(shapes)) + else: + total_arg_size_depth += shape[3] + + dtype = [a.dtype for a in args][0] + + # Now the computation. 
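+    # Added note (sketch, not from the original code): concatenating `args`
+    # along the channel axis and applying one convolution is equivalent to
+    # convolving each argument with its own kernel slice and summing the
+    # results. For the LSTM call above with
+    #   inputs: [batch, H, W, C]   and   h: [batch, H, W, num_features]
+    # the kernel created below has shape
+    #   [filter_h, filter_w, C + num_features, 4 * num_features]
+    # so a single conv2d produces the stacked i, j, f, o pre-activations.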
+ with tf.variable_scope(scope or "Conv"): + matrix = tf.get_variable( + "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype = dtype) + if len(args) == 1: + res = tf.nn.conv2d(args[0], matrix, strides = [1, 1, 1, 1], padding = 'SAME') + else: + res = tf.nn.conv2d(tf.concat(axis = 3, values = args), matrix, strides = [1, 1, 1, 1], padding = 'SAME') + if not bias: + return res + bias_term = tf.get_variable( + "Bias", [num_features], + dtype = dtype, + initializer = tf.constant_initializer( + bias_start, dtype = dtype)) + return res + bias_term diff --git a/video_prediction/layers/__init__.py b/video_prediction/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8530ffd70f0899c8d8e0832d0dcd377b78bbe349 --- /dev/null +++ b/video_prediction/layers/__init__.py @@ -0,0 +1 @@ +from .normalization import fused_instance_norm diff --git a/video_prediction/layers/layer_def.py b/video_prediction/layers/layer_def.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7f4387001c9318507ad809d7176071312742d0 --- /dev/null +++ b/video_prediction/layers/layer_def.py @@ -0,0 +1,160 @@ +"""functions used to construct different architectures +""" + +import tensorflow as tf +import numpy as np + +weight_decay = 0.0005 +def _activation_summary(x): + """Helper to create summaries for activations. + Creates a summary that provides a histogram of activations. + Creates a summary that measure the sparsity of activations. + Args: + x: Tensor + Returns: + nothing + """ + tensor_name = x.op.name + tf.summary.histogram(tensor_name + '/activations', x) + tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x)) + +def _variable_on_cpu(name, shape, initializer): + """Helper to create a Variable stored on CPU memory. + Args: + name: name of the variable + shape: list of ints + initializer: initializer for Variable + Returns: + Variable Tensor + """ + with tf.device('/cpu:0'): + var = tf.get_variable(name, shape, initializer=initializer) + return var + + +def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.layers.xavier_initializer()): + """Helper to create an initialized Variable with weight decay. + Note that the Variable is initialized with a truncated normal distribution. + A weight decay is added only if one is specified. + Args: + name: name of the variable + shape: list of ints + stddev: standard deviation of a truncated Gaussian + wd: add L2Loss weight decay multiplied by this float. If None, weight + decay is not added for this Variable. 
+ Returns: + Variable Tensor + """ + #var = _variable_on_cpu(name, shape,tf.truncated_normal_initializer(stddev = stddev)) + var = _variable_on_cpu(name, shape, initializer) + if wd: + weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name = 'weight_loss') + weight_decay.set_shape([]) + tf.add_to_collection('losses', weight_decay) + return var + + +def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer() , activate="relu"): + print("conv_layer activation function",activate) + + with tf.variable_scope('{0}_conv'.format(idx)) as scope: + + input_channels = inputs.get_shape()[-1] + weights = _variable_with_weight_decay('weights',shape = [kernel_size, kernel_size, + input_channels, num_features], + stddev = 0.01, wd = weight_decay) + biases = _variable_on_cpu('biases', [num_features], initializer) + conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding = 'SAME') + conv_biased = tf.nn.bias_add(conv, biases) + if activate == "linear": + return conv_biased + elif activate == "relu": + conv_rect = tf.nn.relu(conv_biased, name = '{0}_conv'.format(idx)) + elif activate == "elu": + conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx)) + elif activate == "leaky_relu": + conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx)) + else: + raise ("activation function is not correct") + return conv_rect + + +def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer(),activate="relu"): + with tf.variable_scope('{0}_trans_conv'.format(idx)) as scope: + input_channels = inputs.get_shape()[3] + input_shape = inputs.get_shape().as_list() + + + weights = _variable_with_weight_decay('weights', + shape = [kernel_size, kernel_size, num_features, input_channels], + stddev = 0.1, wd = weight_decay) + biases = _variable_on_cpu('biases', [num_features],initializer) + batch_size = tf.shape(inputs)[0] + + output_shape = tf.stack( + [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features]) + print ("output_shape",output_shape) + conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides = [1, stride, stride, 1], padding = 'SAME') + conv_biased = tf.nn.bias_add(conv, biases) + if activate == "linear": + return conv_biased + elif activate == "elu": + return tf.nn.elu(conv_biased, name = '{0}_transpose_conv'.format(idx)) + elif activate == "relu": + return tf.nn.relu(conv_biased, name = '{0}_transpose_conv'.format(idx)) + elif activate == "leaky_relu": + return tf.nn.leaky_relu(conv_biased, name = '{0}_transpose_conv'.format(idx)) + elif activate == "sigmoid": + return tf.nn.sigmoid(conv_biased, name ='sigmoid') + else: + return conv_biased + + +def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01,initializer=tf.contrib.layers.xavier_initializer()): + with tf.variable_scope('{0}_fc'.format(idx)) as scope: + input_shape = inputs.get_shape().as_list() + if flat: + dim = input_shape[1] * input_shape[2] * input_shape[3] + inputs_processed = tf.reshape(inputs, [-1, dim]) + else: + dim = input_shape[1] + inputs_processed = inputs + + weights = _variable_with_weight_decay('weights', shape = [dim, hiddens], stddev = weight_init, + wd = weight_decay) + biases = _variable_on_cpu('biases', [hiddens],initializer) + if activate == "linear": + return tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc') + elif activate == "sigmoid": + return 
tf.nn.sigmoid(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + elif activate == "softmax": + return tf.nn.softmax(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + elif activate == "relu": + return tf.nn.relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + elif activate == "leaky_relu": + return tf.nn.leaky_relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + else: + ip = tf.add(tf.matmul(inputs_processed, weights), biases) + return tf.nn.elu(ip, name = str(idx) + '_fc') + +def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None): + with tf.variable_scope('{0}_bn'.format(idx)) as scope: + #Calculate batch mean and variance + shape = inputs.get_shape().as_list() + scale = tf.get_variable("gamma", shape[-1], initializer=tf.constant_initializer(1.0), trainable=is_training) + beta = tf.get_variable("beta", shape[-1], initializer=tf.constant_initializer(0.0), trainable=is_training) + pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False) + pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False) + + if is_training: + batch_mean, batch_var = tf.nn.moments(inputs,[0]) + train_mean = tf.assign(pop_mean,pop_mean * decay + batch_mean * (1 - decay)) + train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay)) + with tf.control_dependencies([train_mean,train_var]): + return tf.nn.batch_normalization(inputs,batch_mean,batch_var,beta,scale,epsilon) + else: + return tf.nn.batch_normalization(inputs,pop_mean,pop_var,beta,scale,epsilon) + +def bn_layers_wrapper(inputs, is_training): + pass + \ No newline at end of file diff --git a/video_prediction/layers/mcnet_ops.py b/video_prediction/layers/mcnet_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..656f66c0df1cf199fff319f7b81b01594f96332c --- /dev/null +++ b/video_prediction/layers/mcnet_ops.py @@ -0,0 +1,178 @@ +import math +import numpy as np +import tensorflow as tf + +from tensorflow.python.framework import ops +from video_prediction.utils.mcnet_utils import * + + +def batch_norm(inputs, name, train=True, reuse=False): + return tf.contrib.layers.batch_norm(inputs=inputs,is_training=train, + reuse=reuse,scope=name,scale=True) + + +def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name="conv2d", reuse=False, padding='SAME'): + with tf.variable_scope(name, reuse=reuse): + w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], + initializer=tf.contrib.layers.xavier_initializer()) + conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding=padding) + + biases = tf.get_variable('biases', [output_dim], + initializer=tf.constant_initializer(0.0)) + conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + + return conv + + +def deconv2d(input_, output_shape,k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, + name="deconv2d", reuse=False, with_w=False, padding='SAME'): + with tf.variable_scope(name, reuse=reuse): + # filter : [height, width, output_channels, in_channels] + w = tf.get_variable('w', [k_h, k_h, output_shape[-1], + input_.get_shape()[-1]], + initializer=tf.contrib.layers.xavier_initializer()) + + try: + deconv = tf.nn.conv2d_transpose(input_, w, + output_shape=output_shape, + strides=[1, d_h, d_w, 1], + padding=padding) + + # Support for verisons of TensorFlow before 0.7.0 + except AttributeError: + deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, + strides=[1, 
d_h, d_w, 1]) + biases = tf.get_variable('biases', [output_shape[-1]], + initializer=tf.constant_initializer(0.0)) + deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) + + if with_w: + return deconv, w, biases + else: + return deconv + + +def lrelu(x, leak=0.2, name="lrelu"): + with tf.variable_scope(name): + f1 = 0.5 * (1 + leak) + f2 = 0.5 * (1 - leak) + return f1 * x + f2 * abs(x) + + +def relu(x): + return tf.nn.relu(x) + + +def tanh(x): + return tf.nn.tanh(x) + + +def shape2d(a): + """ + a: a int or tuple/list of length 2 + """ + if type(a) == int: + return [a, a] + if isinstance(a, (list, tuple)): + assert len(a) == 2 + return list(a) + raise RuntimeError("Illegal shape: {}".format(a)) + + +def shape4d(a): + # for use with tensorflow + return [1] + shape2d(a) + [1] + + +def UnPooling2x2ZeroFilled(x): + out = tf.concat(axis=3, values=[x, tf.zeros_like(x)]) + out = tf.concat(axis=2, values=[out, tf.zeros_like(out)]) + + sh = x.get_shape().as_list() + if None not in sh[1:]: + out_size = [-1, sh[1] * 2, sh[2] * 2, sh[3]] + return tf.reshape(out, out_size) + else: + sh = tf.shape(x) + return tf.reshape(out, [-1, sh[1] * 2, sh[2] * 2, sh[3]]) + + +def MaxPooling(x, shape, stride=None, padding='VALID'): + """ + MaxPooling on images. + :param input: NHWC tensor. + :param shape: int or [h, w] + :param stride: int or [h, w]. default to be shape. + :param padding: 'valid' or 'same'. default to 'valid' + :returns: NHWC tensor. + """ + padding = padding.upper() + shape = shape4d(shape) + if stride is None: + stride = shape + else: + stride = shape4d(stride) + + return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding) + + +#@layer_register() +def FixedUnPooling(x, shape): + """ + Unpool the input with a fixed mat to perform kronecker product with. + :param input: NHWC tensor + :param shape: int or [h, w] + :returns: NHWC tensor + """ + shape = shape2d(shape) + + # a faster implementation for this special case + return UnPooling2x2ZeroFilled(x) + + +def gdl(gen_frames, gt_frames, alpha): + """ + Calculates the sum of GDL losses between the predicted and gt frames. + @param gen_frames: The predicted frames at each scale. + @param gt_frames: The ground truth frames at each scale + @param alpha: The power to which each gradient term is raised. + @return: The GDL loss. + """ + # create filters [-1, 1] and [[1],[-1]] + # for diffing to the left and down respectively. 
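+    # Added note (sketch, not from the original code): this is the gradient
+    # difference loss, roughly
+    #   GDL = mean( ||d(gt)/dx| - |d(gen)/dx||**alpha
+    #             + ||d(gt)/dy| - |d(gen)/dy||**alpha )
+    # where the horizontal/vertical image gradients are taken with the
+    # [-1, 1] filters built below; matching edge strength rather than only
+    # pixel values penalizes blurry predictions.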
+ pos = tf.constant(np.identity(3), dtype=tf.float32) + neg = -1 * pos + # [-1, 1] + filter_x = tf.expand_dims(tf.stack([neg, pos]), 0) + # [[1],[-1]] + filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) + strides = [1, 1, 1, 1] # stride of (1, 1) + padding = 'SAME' + + gen_dx = tf.abs(tf.nn.conv2d(gen_frames, filter_x, strides, padding=padding)) + gen_dy = tf.abs(tf.nn.conv2d(gen_frames, filter_y, strides, padding=padding)) + gt_dx = tf.abs(tf.nn.conv2d(gt_frames, filter_x, strides, padding=padding)) + gt_dy = tf.abs(tf.nn.conv2d(gt_frames, filter_y, strides, padding=padding)) + + grad_diff_x = tf.abs(gt_dx - gen_dx) + grad_diff_y = tf.abs(gt_dy - gen_dy) + + gdl_loss = tf.reduce_mean((grad_diff_x ** alpha + grad_diff_y ** alpha)) + + # condense into one tensor and avg + return gdl_loss + + +def linear(input_, output_size, name, stddev=0.02, bias_start=0.0, + reuse=False, with_w=False): + shape = input_.get_shape().as_list() + + with tf.variable_scope(name, reuse=reuse): + matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, + tf.random_normal_initializer(stddev=stddev)) + bias = tf.get_variable("bias", [output_size], + initializer=tf.constant_initializer(bias_start)) + if with_w: + return tf.matmul(input_, matrix) + bias, matrix, bias + else: + return tf.matmul(input_, matrix) + bias diff --git a/video_prediction/layers/normalization.py b/video_prediction/layers/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2a9bb9c0cf41eb7ed10f25e606c06791e0d2a3 --- /dev/null +++ b/video_prediction/layers/normalization.py @@ -0,0 +1,196 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Contains the normalization layer classes and their functional aliases.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.contrib.framework.python.ops import variables +from tensorflow.contrib.layers.python.layers import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope + + +DATA_FORMAT_NCHW = 'NCHW' +DATA_FORMAT_NHWC = 'NHWC' + + +def fused_instance_norm(inputs, + center=True, + scale=True, + epsilon=1e-6, + activation_fn=None, + param_initializers=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + data_format=DATA_FORMAT_NHWC, + scope=None): + """Functional interface for the instance normalization layer. + + Reference: https://arxiv.org/abs/1607.08022. + + "Instance Normalization: The Missing Ingredient for Fast Stylization" + Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky + + Args: + inputs: A tensor with 2 or more dimensions, where the first dimension has + `batch_size`. 
The normalization is over all but the last dimension if + `data_format` is `NHWC` and the second dimension if `data_format` is + `NCHW`. + center: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. + scale: If True, multiply by `gamma`. If False, `gamma` is + not used. When the next layer is linear (also e.g. `nn.relu`), this can be + disabled since the scaling can be done by the next layer. + epsilon: Small float added to variance to avoid dividing by zero. + activation_fn: Activation function, default set to None to skip it and + maintain a linear activation. + param_initializers: Optional initializers for beta, gamma, moving mean and + moving variance. + reuse: Whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + variables_collections: Optional collections for the variables. + outputs_collections: Collections to add the outputs. + trainable: If `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + data_format: A string. `NHWC` (default) and `NCHW` are supported. + scope: Optional scope for `variable_scope`. + + Returns: + A `Tensor` representing the output of the operation. + + Raises: + ValueError: If `data_format` is neither `NHWC` nor `NCHW`. + ValueError: If the rank of `inputs` is undefined. + ValueError: If rank or channels dimension of `inputs` is undefined. + """ + inputs = ops.convert_to_tensor(inputs) + inputs_shape = inputs.shape + inputs_rank = inputs.shape.ndims + + if inputs_rank is None: + raise ValueError('Inputs %s has undefined rank.' % inputs.name) + if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): + raise ValueError('data_format has to be either NCHW or NHWC.') + + with variable_scope.variable_scope( + scope, 'InstanceNorm', [inputs], reuse=reuse) as sc: + if data_format == DATA_FORMAT_NCHW: + reduction_axis = 1 + # For NCHW format, rather than relying on implicit broadcasting, we + # explicitly reshape the params to params_shape_broadcast when computing + # the moments and the batch normalization. + params_shape_broadcast = list( + [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)]) + else: + reduction_axis = inputs_rank - 1 + params_shape_broadcast = None + moments_axes = list(range(inputs_rank)) + del moments_axes[reduction_axis] + del moments_axes[0] + params_shape = inputs_shape[reduction_axis:reduction_axis + 1] + if not params_shape.is_fully_defined(): + raise ValueError('Inputs %s has undefined channels dimension %s.' % ( + inputs.name, params_shape)) + + # Allocate parameters for the beta and gamma of the normalization. 
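+    # Added note (sketch, not from the original code): per the Ulyanov et al.
+    # reference above, instance normalization standardizes each (sample,
+    # channel) pair over its spatial dimensions,
+    #   y[n, ..., c] = gamma[c] * (x[n, ..., c] - mean[n, c])
+    #                  / sqrt(var[n, c] + epsilon) + beta[c]
+    # beta/gamma allocated below are the optional per-channel offset and scale.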
+ beta, gamma = None, None + dtype = inputs.dtype.base_dtype + if param_initializers is None: + param_initializers = {} + if center: + beta_collections = utils.get_variable_collections( + variables_collections, 'beta') + beta_initializer = param_initializers.get( + 'beta', init_ops.zeros_initializer()) + beta = variables.model_variable('beta', + shape=params_shape, + dtype=dtype, + initializer=beta_initializer, + collections=beta_collections, + trainable=trainable) + if params_shape_broadcast: + beta = array_ops.reshape(beta, params_shape_broadcast) + if scale: + gamma_collections = utils.get_variable_collections( + variables_collections, 'gamma') + gamma_initializer = param_initializers.get( + 'gamma', init_ops.ones_initializer()) + gamma = variables.model_variable('gamma', + shape=params_shape, + dtype=dtype, + initializer=gamma_initializer, + collections=gamma_collections, + trainable=trainable) + if params_shape_broadcast: + gamma = array_ops.reshape(gamma, params_shape_broadcast) + + if data_format == DATA_FORMAT_NHWC: + inputs = array_ops.transpose(inputs, list(range(1, reduction_axis)) + [0, reduction_axis]) + if data_format == DATA_FORMAT_NCHW: + inputs = array_ops.transpose(inputs, list(range(2, inputs_rank)) + [0, reduction_axis]) + hw, n, c = inputs.shape.as_list()[:-2], inputs.shape[-2].value, inputs.shape[-1].value + inputs = array_ops.reshape(inputs, [1] + hw + [n * c]) + if inputs.shape.ndims != 4: + # combine all the spatial dimensions into only two, e.g. [D, H, W] -> [DH, W] + if inputs.shape.ndims > 4: + inputs_ndims4_shape = [1, hw[0], -1, n * c] + else: + inputs_ndims4_shape = [1, 1, -1, n * c] + inputs = array_ops.reshape(inputs, inputs_ndims4_shape) + beta = array_ops.reshape(array_ops.tile(beta[None, :], [n, 1]), [-1]) + gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [n, 1]), [-1]) + + outputs, _, _ = nn.fused_batch_norm( + inputs, gamma, beta, epsilon=epsilon, + data_format=DATA_FORMAT_NHWC, name='instancenorm') + + outputs = array_ops.reshape(outputs, hw + [n, c]) + if data_format == DATA_FORMAT_NHWC: + outputs = array_ops.transpose(outputs, [inputs_rank - 2] + list(range(inputs_rank - 2)) + [inputs_rank - 1]) + if data_format == DATA_FORMAT_NCHW: + outputs = array_ops.transpose(outputs, [inputs_rank - 2, inputs_rank - 1] + list(range(inputs_rank - 2))) + + # if data_format == DATA_FORMAT_NHWC: + # inputs = array_ops.transpose(inputs, [0, reduction_axis] + list(range(1, reduction_axis))) + # inputs_nchw_shape = inputs.shape + # inputs = array_ops.reshape(inputs, [1, -1] + inputs_nchw_shape.as_list()[2:]) + # if inputs.shape.ndims != 4: + # # combine all the spatial dimensions into only two, e.g. 
[D, H, W] -> [DH, W] + # if inputs.shape.ndims > 4: + # inputs_ndims4_shape = inputs.shape.as_list()[:2] + [-1, inputs_nchw_shape.as_list()[-1]] + # else: + # inputs_ndims4_shape = inputs.shape.as_list()[:2] + [1, -1] + # inputs = array_ops.reshape(inputs, inputs_ndims4_shape) + # beta = array_ops.reshape(array_ops.tile(beta[None, :], [inputs_nchw_shape[0].value, 1]), [-1]) + # gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [inputs_nchw_shape[0].value, 1]), [-1]) + # + # outputs, _, _ = nn.fused_batch_norm( + # inputs, gamma, beta, epsilon=epsilon, + # data_format=DATA_FORMAT_NCHW, name='instancenorm') + # + # outputs = array_ops.reshape(outputs, inputs_nchw_shape) + # if data_format == DATA_FORMAT_NHWC: + # outputs = array_ops.transpose(outputs, [0] + list(range(2, inputs_rank)) + [1]) + + if activation_fn is not None: + outputs = activation_fn(outputs) + return utils.collect_named_outputs(outputs_collections, sc.name, outputs) diff --git a/video_prediction/losses.py b/video_prediction/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..662da29dbadf091bc59fcd0e7ed62fd71bcf0f81 --- /dev/null +++ b/video_prediction/losses.py @@ -0,0 +1,67 @@ +import tensorflow as tf + +from video_prediction.ops import sigmoid_kl_with_logits + + +def l1_loss(pred, target): + return tf.reduce_mean(tf.abs(target - pred)) + + +def l2_loss(pred, target): + return tf.reduce_mean(tf.square(target - pred)) + + +def normalize_tensor(tensor, eps=1e-10): + norm_factor = tf.norm(tensor, axis=-1, keepdims=True) + return tensor / (norm_factor + eps) + + +def cosine_distance(tensor0, tensor1, keep_axis=None): + tensor0 = normalize_tensor(tensor0) + tensor1 = normalize_tensor(tensor1) + return tf.reduce_mean(tf.reduce_sum(tf.square(tensor0 - tensor1), axis=-1)) / 2.0 + + +def charbonnier_loss(x, epsilon=0.001): + return tf.reduce_mean(tf.sqrt(tf.square(x) + tf.square(epsilon))) + + +def gan_loss(logits, labels, gan_loss_type): + # use 1.0 (or 1.0 - discrim_label_smooth) for real data and 0.0 for fake data + if gan_loss_type == 'GAN': + # discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS))) + # gen_loss = tf.reduce_mean(-tf.log(predict_fake + EPS)) + if labels in (0.0, 1.0): + labels = tf.constant(labels, dtype=logits.dtype, shape=logits.get_shape()) + loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)) + else: + loss = tf.reduce_mean(sigmoid_kl_with_logits(logits, labels)) + elif gan_loss_type == 'LSGAN': + # discrim_loss = tf.reduce_mean((tf.square(predict_real - 1) + tf.square(predict_fake))) + # gen_loss = tf.reduce_mean(tf.square(predict_fake - 1)) + loss = tf.reduce_mean(tf.square(logits - labels)) + elif gan_loss_type == 'SNGAN': + # this is the form of the loss used in the official implementation of the SNGAN paper, but it leads to + # worse results in our video prediction experiments + if labels == 0.0: + loss = tf.reduce_mean(tf.nn.softplus(logits)) + elif labels == 1.0: + loss = tf.reduce_mean(tf.nn.softplus(-logits)) + else: + raise NotImplementedError + else: + raise ValueError('Unknown GAN loss type %s' % gan_loss_type) + return loss + + +def kl_loss(mu, log_sigma_sq, mu2=None, log_sigma2_sq=None): + if mu2 is None and log_sigma2_sq is None: + sigma_sq = tf.exp(log_sigma_sq) + return -0.5 * tf.reduce_mean(tf.reduce_sum(1 + log_sigma_sq - tf.square(mu) - sigma_sq, axis=-1)) + else: + mu1 = mu + log_sigma1_sq = log_sigma_sq + return tf.reduce_mean(tf.reduce_sum( + (log_sigma2_sq - log_sigma1_sq) 
/ 2 + + (tf.exp(log_sigma1_sq) + tf.square(mu1 - mu2)) / (2 * tf.exp(log_sigma2_sq)) + - 1 / 2, axis=-1)) diff --git a/video_prediction/metrics.py b/video_prediction/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..19c057fc7ac59c137fcba64579e4b52a4bd0a55c --- /dev/null +++ b/video_prediction/metrics.py @@ -0,0 +1,26 @@ +import tensorflow as tf +#import lpips_tf + + +def mse(a, b): + return tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1]) + + +def psnr(a, b): + return tf.image.psnr(a, b, 1.0) + + +def ssim(a, b): + return tf.image.ssim(a, b, 1.0) + + +# def lpips(input0, input1): +# if input0.shape[-1].value == 1: +# input0 = tf.tile(input0, [1] * (input0.shape.ndims - 1) + [3]) +# if input1.shape[-1].value == 1: +# input1 = tf.tile(input1, [1] * (input1.shape.ndims - 1) + [3]) +# +# distance = lpips_tf.lpips(input0, input1) +# return -distance + + diff --git a/video_prediction/models/__init__.py b/video_prediction/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7323f3750949b0ddb411d4a98934928537bc53 --- /dev/null +++ b/video_prediction/models/__init__.py @@ -0,0 +1,30 @@ +from .base_model import BaseVideoPredictionModel +from .base_model import VideoPredictionModel +from .non_trainable_model import NonTrainableVideoPredictionModel +from .non_trainable_model import GroundTruthVideoPredictionModel +from .non_trainable_model import RepeatVideoPredictionModel +from .savp_model import SAVPVideoPredictionModel +from .dna_model import DNAVideoPredictionModel +from .sna_model import SNAVideoPredictionModel +from .sv2p_model import SV2PVideoPredictionModel +from .vanilla_vae_model import VanillaVAEVideoPredictionModel +from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel +from .mcnet_model import McNetVideoPredictionModel +def get_model_class(model): + model_mappings = { + 'ground_truth': 'GroundTruthVideoPredictionModel', + 'repeat': 'RepeatVideoPredictionModel', + 'savp': 'SAVPVideoPredictionModel', + 'dna': 'DNAVideoPredictionModel', + 'sna': 'SNAVideoPredictionModel', + 'sv2p': 'SV2PVideoPredictionModel', + 'vae': 'VanillaVAEVideoPredictionModel', + 'convLSTM': 'VanillaConvLstmVideoPredictionModel', + 'mcnet': 'McNetVideoPredictionModel', + + } + model_class = model_mappings.get(model, model) + model_class = globals().get(model_class) + if model_class is None or not issubclass(model_class, BaseVideoPredictionModel): + raise ValueError('Invalid model %s' % model) + return model_class diff --git a/video_prediction/models/base_model.py b/video_prediction/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebe228fcc9c90addf610bed44bb46f090c7e514 --- /dev/null +++ b/video_prediction/models/base_model.py @@ -0,0 +1,878 @@ +import functools +import itertools +import os +import re +from collections import OrderedDict + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.training import HParams +from tensorflow.python.util import nest + +import video_prediction as vp +from video_prediction.utils import tf_utils +from video_prediction.utils.tf_utils import compute_averaged_gradients, reduce_tensors, local_device_setter, \ + replace_read_ops, print_loss_info, transpose_batch_time, add_gif_summaries, add_scalar_summaries, \ + add_plot_and_scalar_summaries, add_summaries + + +class BaseVideoPredictionModel(object): + def __init__(self, mode='train', hparams_dict=None, hparams=None, + num_gpus=None, eval_num_samples=100, + 
eval_num_samples_for_diversity=10, eval_parallel_iterations=1): + """ + Base video prediction model. + + Trainable and non-trainable video prediction models can be derived + from this base class. + + Args: + mode: `'train'` or `'test'`. + hparams_dict: a dict of `name=value` pairs, where `name` must be + defined in `self.get_default_hparams()`. + hparams: a string of comma separated list of `name=value` pairs, + where `name` must be defined in `self.get_default_hparams()`. + These values overrides any values in hparams_dict (if any). + """ + if mode not in ('train', 'test'): + raise ValueError('mode must be train or test, but %s given' % mode) + self.mode = mode + cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0') + if cuda_visible_devices == '': + max_num_gpus = 0 + else: + max_num_gpus = len(cuda_visible_devices.split(',')) + if num_gpus is None: + num_gpus = max_num_gpus + elif num_gpus > max_num_gpus: + raise ValueError('num_gpus=%d is greater than the number of visible devices %d' % (num_gpus, max_num_gpus)) + self.num_gpus = num_gpus + self.eval_num_samples = eval_num_samples + self.eval_num_samples_for_diversity = eval_num_samples_for_diversity + self.eval_parallel_iterations = eval_parallel_iterations + self.hparams = self.parse_hparams(hparams_dict, hparams) + if self.hparams.context_frames == -1: + raise ValueError('Invalid context_frames %r. It might have to be ' + 'specified.' % self.hparams.context_frames) + if self.hparams.sequence_length == -1: + raise ValueError('Invalid sequence_length %r. It might have to be ' + 'specified.' % self.hparams.sequence_length) + + # should be overriden by descendant class if the model is stochastic + self.deterministic = True + + # member variables that should be set by `self.build_graph` + self.inputs = None + self.gen_images = None + self.outputs = None + self.metrics = None + self.eval_outputs = None + self.eval_metrics = None + self.accum_eval_metrics = None + self.saveable_variables = None + self.post_init_ops = None + + def get_default_hparams_dict(self): + """ + The keys of this dict define valid hyperparameters for instances of + this class. A class inheriting from this one should override this + method if it has a different set of hyperparameters. + + Returns: + A dict with the following hyperparameters. + + context_frames: the number of ground-truth frames to pass in at + start. Must be specified during instantiation. + sequence_length: the number of frames in the video sequence, + including the context frames, so this model predicts + `sequence_length - context_frames` future frames. Must be + specified during instantiation. + repeat: the number of repeat actions (if applicable). 
+ """ + hparams = dict( + context_frames=-1, + sequence_length=-1, + repeat=1, + ) + return hparams + + def get_default_hparams(self): + return HParams(**self.get_default_hparams_dict()) + + def parse_hparams(self, hparams_dict, hparams): + parsed_hparams = self.get_default_hparams().override_from_dict(hparams_dict or {}) + if hparams: + if not isinstance(hparams, (list, tuple)): + hparams = [hparams] + for hparam in hparams: + parsed_hparams.parse(hparam) + return parsed_hparams + + def build_graph(self, inputs): + self.inputs = inputs + + def metrics_fn(self, inputs, outputs): + metrics = OrderedDict() + sequence_length = tf.shape(inputs['images'])[0] + context_frames = self.hparams.context_frames + future_length = sequence_length - context_frames + # target_images and pred_images include only the future frames + target_images = inputs['images'][-future_length:] + pred_images = outputs['gen_images'][-future_length:] + metric_fns = [ + ('psnr', vp.metrics.psnr), + ('mse', vp.metrics.mse), + ('ssim', vp.metrics.ssim), + #('lpips', vp.metrics.lpips), #bing : remove lpips metric course the url fetching issue + ] + for metric_name, metric_fn in metric_fns: + metrics[metric_name] = tf.reduce_mean(metric_fn(target_images, pred_images)) + return metrics + + def eval_outputs_and_metrics_fn(self, inputs, outputs, num_samples=None, + num_samples_for_diversity=None, parallel_iterations=None): + num_samples = num_samples or self.eval_num_samples + num_samples_for_diversity = num_samples_for_diversity or self.eval_num_samples_for_diversity + parallel_iterations = parallel_iterations or self.eval_parallel_iterations + + sequence_length, batch_size = inputs['images'].shape[:2].as_list() + if batch_size is None: + batch_size = tf.shape(inputs['images'])[1] + if sequence_length is None: + sequence_length = tf.shape(inputs['images'])[0] + context_frames = self.hparams.context_frames + future_length = sequence_length - context_frames + # the outputs include all the frames, whereas the metrics include only the future frames + eval_outputs = OrderedDict() + eval_metrics = OrderedDict() + metric_fns = [ + ('psnr', vp.metrics.psnr), + ('mse', vp.metrics.mse), + ('ssim', vp.metrics.ssim), + # ('lpips', vp.metrics.lpips), #bing + ] + # images and gen_images include all the frames + images = inputs['images'] + gen_images = outputs['gen_images'] + # target_images and pred_images include only the future frames + target_images = inputs['images'][-future_length:] + pred_images = outputs['gen_images'][-future_length:] + # ground truth is the same for deterministic and stochastic models + eval_outputs['eval_images'] = images + if self.deterministic: + for metric_name, metric_fn in metric_fns: + metric = metric_fn(target_images, pred_images) + eval_metrics['eval_%s/min' % metric_name] = metric + eval_metrics['eval_%s/avg' % metric_name] = metric + eval_metrics['eval_%s/max' % metric_name] = metric + eval_outputs['eval_gen_images'] = gen_images + else: + def where_axis1(cond, x, y): + return transpose_batch_time(tf.where(cond, transpose_batch_time(x), transpose_batch_time(y))) + + def sort_criterion(x): + return tf.reduce_mean(x, axis=0) + + def accum_gen_images_and_metrics_fn(a, unused): + with tf.variable_scope(self.generator_scope, reuse=True): + outputs_sample = self.generator_fn(inputs) + gen_images_sample = outputs_sample['gen_images'] + pred_images_sample = gen_images_sample[-future_length:] + # set the posisbly static shape since it might not have been inferred correctly + pred_images_sample = 
tf.reshape(pred_images_sample, tf.shape(a['eval_pred_images_last'])) + for name, metric_fn in metric_fns: + metric = metric_fn(target_images, pred_images_sample) # time, batch_size + cond_min = tf.less(sort_criterion(metric), sort_criterion(a['eval_%s/min' % name])) + cond_max = tf.greater(sort_criterion(metric), sort_criterion(a['eval_%s/max' % name])) + a['eval_%s/min' % name] = where_axis1(cond_min, metric, a['eval_%s/min' % name]) + a['eval_%s/sum' % name] = metric + a['eval_%s/sum' % name] + a['eval_%s/max' % name] = where_axis1(cond_max, metric, a['eval_%s/max' % name]) + a['eval_gen_images_%s/min' % name] = where_axis1(cond_min, gen_images_sample, a['eval_gen_images_%s/min' % name]) + a['eval_gen_images_%s/sum' % name] = gen_images_sample + a['eval_gen_images_%s/sum' % name] + a['eval_gen_images_%s/max' % name] = where_axis1(cond_max, gen_images_sample, a['eval_gen_images_%s/max' % name]) + #bing + # a['eval_diversity'] = tf.cond( + # tf.logical_and(tf.less(0, a['eval_sample_ind']), + # tf.less_equal(a['eval_sample_ind'], num_samples_for_diversity)), + # lambda: -vp.metrics.lpips(a['eval_pred_images_last'], pred_images_sample) + a['eval_diversity'], + # lambda: a['eval_diversity']) + a['eval_sample_ind'] = 1 + a['eval_sample_ind'] + a['eval_pred_images_last'] = pred_images_sample + return a + + initializer = {} + for name, _ in metric_fns: + initializer['eval_gen_images_%s/min' % name] = tf.zeros_like(gen_images) + initializer['eval_gen_images_%s/sum' % name] = tf.zeros_like(gen_images) + initializer['eval_gen_images_%s/max' % name] = tf.zeros_like(gen_images) + initializer['eval_%s/min' % name] = tf.fill([future_length, batch_size], float('inf')) + initializer['eval_%s/sum' % name] = tf.zeros([future_length, batch_size]) + initializer['eval_%s/max' % name] = tf.fill([future_length, batch_size], float('-inf')) + #initializer['eval_diversity'] = tf.zeros([future_length, batch_size]) + initializer['eval_sample_ind'] = tf.zeros((), dtype=tf.int32) + initializer['eval_pred_images_last'] = tf.zeros_like(pred_images) + + eval_outputs_and_metrics = tf.foldl( + accum_gen_images_and_metrics_fn, tf.zeros([num_samples, 0]), initializer=initializer, back_prop=False, + parallel_iterations=parallel_iterations) + + for name, _ in metric_fns: + eval_outputs['eval_gen_images_%s/min' % name] = eval_outputs_and_metrics['eval_gen_images_%s/min' % name] + eval_outputs['eval_gen_images_%s/avg' % name] = eval_outputs_and_metrics['eval_gen_images_%s/sum' % name] / float(num_samples) + eval_outputs['eval_gen_images_%s/max' % name] = eval_outputs_and_metrics['eval_gen_images_%s/max' % name] + eval_metrics['eval_%s/min' % name] = eval_outputs_and_metrics['eval_%s/min' % name] + eval_metrics['eval_%s/avg' % name] = eval_outputs_and_metrics['eval_%s/sum' % name] / float(num_samples) + eval_metrics['eval_%s/max' % name] = eval_outputs_and_metrics['eval_%s/max' % name] + #eval_metrics['eval_diversity'] = eval_outputs_and_metrics['eval_diversity'] / float(num_samples_for_diversity) + return eval_outputs, eval_metrics + + def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): + if checkpoints: + var_list = self.saveable_variables + # possibly restore from multiple checkpoints. useful if subset of weights + # (e.g. generator or discriminator) are on different checkpoints. 
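+            # Hypothetical usage sketch (paths are placeholders, not from the
+            # original code):
+            #   model.restore(sess, ['/path/to/generator_ckpt',
+            #                        '/path/to/discriminator_ckpt'])
+            # Per the comment above, each checkpoint may hold only a subset of
+            # `self.saveable_variables`, and `global_step` is skipped
+            # automatically when more than one checkpoint is given.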
+ if not isinstance(checkpoints, (list, tuple)): + checkpoints = [checkpoints] + # automatically skip global_step if more than one checkpoint is provided + skip_global_step = len(checkpoints) > 1 + savers = [] + for checkpoint in checkpoints: + print("creating restore saver from checkpoint %s" % checkpoint) + saver, _ = tf_utils.get_checkpoint_restore_saver( + checkpoint, var_list, skip_global_step=skip_global_step, + restore_to_checkpoint_mapping=restore_to_checkpoint_mapping) + savers.append(saver) + restore_op = [saver.saver_def.restore_op_name for saver in savers] + sess.run(restore_op) + + +class VideoPredictionModel(BaseVideoPredictionModel): + def __init__(self, + generator_fn, + discriminator_fn=None, + generator_scope='generator', + discriminator_scope='discriminator', + aggregate_nccl=False, + mode='train', + hparams_dict=None, + hparams=None, + **kwargs): + """ + Trainable video prediction model with CPU and multi-GPU support. + + If num_gpus <= 1, the devices for the ops in `self.build_graph` are + automatically chosen by TensorFlow (i.e. `tf.device` is not specified), + otherwise they are explicitly chosen. + + Args: + generator_fn: callable that takes in inputs and returns a dict of + tensors. + discriminator_fn: callable that takes in fake/real data (and + optionally conditioned on inputs) and returns a dict of + tensors. + hparams_dict: a dict of `name=value` pairs, where `name` must be + defined in `self.get_default_hparams()`. + hparams: a string of comma separated list of `name=value` pairs, + where `name` must be defined in `self.get_default_hparams()`. + These values overrides any values in hparams_dict (if any). + """ + super(VideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) + self.generator_fn = functools.partial(generator_fn, mode=self.mode, hparams=self.hparams) + self.discriminator_fn = functools.partial(discriminator_fn, mode=self.mode, hparams=self.hparams) if discriminator_fn else None + self.generator_scope = generator_scope + self.discriminator_scope = discriminator_scope + self.aggregate_nccl = aggregate_nccl + + if any(self.hparams.lr_boundaries): + global_step = tf.train.get_or_create_global_step() + lr_values = list(self.hparams.lr * 0.1 ** np.arange(len(self.hparams.lr_boundaries) + 1)) + self.learning_rate = tf.train.piecewise_constant(global_step, self.hparams.lr_boundaries, lr_values) + elif any(self.hparams.decay_steps): + lr, end_lr = self.hparams.lr, self.hparams.end_lr + start_step, end_step = self.hparams.decay_steps + if start_step == end_step: + schedule = tf.cond(tf.less(tf.train.get_or_create_global_step(), start_step), + lambda: 0.0, lambda: 1.0) + else: + step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step) + schedule = tf.to_float(step - start_step) / tf.to_float(end_step - start_step) + self.learning_rate = lr + (end_lr - lr) * schedule + else: + self.learning_rate = self.hparams.lr + + if self.hparams.kl_weight: + if self.hparams.kl_anneal == 'none': + self.kl_weight = tf.constant(self.hparams.kl_weight, tf.float32) + elif self.hparams.kl_anneal == 'sigmoid': + k = self.hparams.kl_anneal_k + if k == -1.0: + raise ValueError('Invalid kl_anneal_k %d when kl_anneal is sigmoid.' 
% k) + iter_num = tf.train.get_or_create_global_step() + self.kl_weight = self.hparams.kl_weight / (1 + k * tf.exp(-tf.to_float(iter_num) / k)) + elif self.hparams.kl_anneal == 'linear': + start_step, end_step = self.hparams.kl_anneal_steps + step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step) + self.kl_weight = self.hparams.kl_weight * tf.to_float(step - start_step) / tf.to_float(end_step - start_step) + else: + raise NotImplementedError + else: + self.kl_weight = None + + # member variables that should be set by `self.build_graph` + # (in addition to the ones in the base class) + self.gen_images_enc = None + self.g_losses = None + self.d_losses = None + self.g_loss = None + self.d_loss = None + self.g_vars = None + self.d_vars = None + self.train_op = None + self.summary_op = None + self.image_summary_op = None + self.eval_summary_op = None + self.accum_eval_summary_op = None + self.accum_eval_metrics_reset_op = None + + def get_default_hparams_dict(self): + """ + The keys of this dict define valid hyperparameters for instances of + this class. A class inheriting from this one should override this + method if it has a different set of hyperparameters. + + Returns: + A dict with the following hyperparameters. + + batch_size: batch size for training. + lr: learning rate. if decay steps is non-zero, this is the + learning rate for steps <= decay_step. + end_lr: learning rate for steps >= end_decay_step if decay_steps + is non-zero, ignored otherwise. + decay_steps: (decay_step, end_decay_step) tuple. + max_steps: number of training steps. + beta1: momentum term of Adam. + beta2: momentum term of Adam. + context_frames: the number of ground-truth frames to pass in at + start. Must be specified during instantiation. + sequence_length: the number of frames in the video sequence, + including the context frames, so this model predicts + `sequence_length - context_frames` future frames. Must be + specified during instantiation. + """ + default_hparams = super(VideoPredictionModel, self).get_default_hparams_dict() + hparams = dict( + batch_size=16, + lr=0.001, + end_lr=0.0, + decay_steps=(200000, 300000), + lr_boundaries=(0,), + max_steps=350000, + beta1=0.9, + beta2=0.999, + context_frames=-1, + sequence_length=-1, + clip_length=10, #Bing: TODO What is the clip_length, original is 10, + l1_weight=0.0, + l2_weight=1.0, + vgg_cdist_weight=0.0, + feature_l2_weight=0.0, + ae_l2_weight=0.0, + state_weight=0.0, + tv_weight=0.0, + image_sn_gan_weight=0.0, + image_sn_vae_gan_weight=0.0, + images_sn_gan_weight=0.0, + images_sn_vae_gan_weight=0.0, + video_sn_gan_weight=0.0, + video_sn_vae_gan_weight=0.0, + gan_feature_l2_weight=0.0, + gan_feature_cdist_weight=0.0, + vae_gan_feature_l2_weight=0.0, + vae_gan_feature_cdist_weight=0.0, + gan_loss_type='LSGAN', + joint_gan_optimization=False, + kl_weight=0.0, + kl_anneal='linear', + kl_anneal_k=-1.0, + kl_anneal_steps=(50000, 100000), + z_l1_weight=0.0, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + def tower_fn(self, inputs): + """ + This method doesn't have side-effects. `inputs`, `targets`, and + `outputs` are batch-major but internal calculations use time-major + tensors. 
+ """ + # batch-major to time-major + inputs = nest.map_structure(transpose_batch_time, inputs) + + with tf.variable_scope(self.generator_scope): + gen_outputs = self.generator_fn(inputs) + + if self.discriminator_fn: + with tf.variable_scope(self.discriminator_scope) as discrim_scope: + discrim_outputs = self.discriminator_fn(inputs, gen_outputs) + # post-update discriminator tensors (i.e. after the discriminator weights have been updated) + with tf.variable_scope(discrim_scope, reuse=True): + discrim_outputs_post = self.discriminator_fn(inputs, gen_outputs) + else: + discrim_outputs = {} + discrim_outputs_post = {} + + outputs = [gen_outputs, discrim_outputs] + total_num_outputs = sum([len(output) for output in outputs]) + outputs = OrderedDict(itertools.chain(*[output.items() for output in outputs])) + assert len(outputs) == total_num_outputs # ensure no output is lost because of repeated keys + + if isinstance(self.learning_rate, tf.Tensor): + outputs['learning_rate'] = self.learning_rate + if isinstance(self.kl_weight, tf.Tensor): + outputs['kl_weight'] = self.kl_weight + + if self.mode == 'train': + with tf.name_scope("discriminator_loss"): + d_losses = self.discriminator_loss_fn(inputs, outputs) + print_loss_info(d_losses, inputs, outputs) + with tf.name_scope("generator_loss"): + g_losses = self.generator_loss_fn(inputs, outputs) + print_loss_info(g_losses, inputs, outputs) + if discrim_outputs_post: + outputs_post = OrderedDict(itertools.chain(gen_outputs.items(), discrim_outputs_post.items())) + # generator losses after the discriminator weights have been updated + g_losses_post = self.generator_loss_fn(inputs, outputs_post) + else: + g_losses_post = g_losses + else: + d_losses = {} + g_losses = {} + g_losses_post = {} + with tf.name_scope("metrics"): + metrics = self.metrics_fn(inputs, outputs) + with tf.name_scope("eval_outputs_and_metrics"): + eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs) + + # time-major to batch-major + outputs_tuple = (outputs, eval_outputs) + outputs_tuple = nest.map_structure(transpose_batch_time, outputs_tuple) + losses_tuple = (d_losses, g_losses, g_losses_post) + losses_tuple = nest.map_structure(tf.convert_to_tensor, losses_tuple) + loss_tuple = tuple(tf.accumulate_n([loss * weight for loss, weight in losses.values()]) + if losses else tf.zeros(()) for losses in losses_tuple) + metrics_tuple = (metrics, eval_metrics) + metrics_tuple = nest.map_structure(transpose_batch_time, metrics_tuple) + return outputs_tuple, losses_tuple, loss_tuple, metrics_tuple + + def build_graph(self, inputs,finetune=False): + BaseVideoPredictionModel.build_graph(self, inputs) + + global_step = tf.train.get_or_create_global_step() + # Capture the variables created from here until the train_op for the + # saveable_variables. Note that if variables are being reused (e.g. + # they were created by a previously built model), those variables won't + # be captured here. 
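+        # Added note (sketch, not from the original code): the saveable set is
+        # obtained by a before/after difference, i.e.
+        #   before = tf.global_variables()
+        #   ... build towers, optimizers and the train_op ...
+        #   new_vars = [v for v in tf.global_variables() if v not in before]
+        #   self.saveable_variables = [global_step] + new_vars
+        # which is what `original_global_variables` below is used for.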
+ original_global_variables = tf.global_variables() + + + # ########Bing: fine-tune####### + # variables_to_restore = tf.contrib.framework.get_variables_to_restore( + # exclude = ["discriminator/video/sn_fc4/dense/bias"]) + # init_fn = tf.contrib.framework.assign_from_checkpoint_fn(checkpoint) + # restore_variables = tf.contrib.framework.get_variables("discriminator/video/sn_fc4/dense/bias") + # restore_init = tf.variables_initializer(restore_variables) + # restore_optimizer = tf.train.GradientDescentOptimizer( + # learning_rate = 0.001) # TODO: need to change the learning rate + # ###Bing: fine-tune####### + # skip_vars = {" discriminator_encoder/video_sn_fc4/dense/bias"} + + if self.num_gpus <= 1: # cpu or 1 gpu + outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(self.inputs) + self.outputs, self.eval_outputs = outputs_tuple + self.d_losses, self.g_losses, g_losses_post = losses_tuple + self.d_loss, self.g_loss, g_loss_post = loss_tuple + self.metrics, self.eval_metrics = metrics_tuple + + self.d_vars = tf.trainable_variables(self.discriminator_scope) + self.g_vars = tf.trainable_variables(self.generator_scope) + g_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) + d_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) + + if finetune: + ##Bing: fine-tune + #self.g_vars = tf.contrib.framework.get_variables("discriminator/video/sn_fc4/dense/bias")#generator/encoder/layer_3/conv2d/kernel/Adam_1 + self.g_vars = tf.contrib.framework.get_variables("discriminator/encoder/video/sn_conv3_0/conv3d/kernel") + self.g_vars_init = tf.variables_initializer(self.g_vars) + g_optimizer = tf.train.AdamOptimizer(0.00001) + print("############Bing: Fine Tune here##########") + + if self.mode == 'train' and (self.d_losses or self.g_losses): + with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + if self.d_losses: + with tf.name_scope('d_compute_gradients'): + d_gradvars = d_optimizer.compute_gradients(self.d_loss, var_list=self.d_vars) + with tf.name_scope('d_apply_gradients'): + d_train_op = d_optimizer.apply_gradients(d_gradvars) + + else: + d_train_op = tf.no_op() + with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []): + if g_losses_post: + if not self.hparams.joint_gan_optimization: + replace_read_ops(g_loss_post, self.d_vars) + with tf.name_scope('g_compute_gradients'): + g_gradvars = g_optimizer.compute_gradients(g_loss_post, var_list=self.g_vars) + with tf.name_scope('g_apply_gradients'): + g_train_op = g_optimizer.apply_gradients(g_gradvars) + # #######Bing; finetune######## + # with tf.name_scope("finetune_gradients"): + # finetune_grads_vars = finetune_vars_optimizer.compute_gradients(self.d_loss, var_list = self.finetune_vars) + # with tf.name_scope("finetune_apply_gradients"): + # finetune_train_op = finetune_vars_optimizer.apply_gradients(finetune_grads_vars) + else: + g_train_op = tf.no_op() + with tf.control_dependencies([g_train_op]): + train_op = tf.assign_add(global_step, 1) + self.train_op = train_op + else: + self.train_op = None + + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [global_step] + global_variables + self.post_init_ops = [] + else: + if tf.get_variable_scope().name: + # This is because how variable scope works with empty strings when it's not the root scope, causing + # repeated forward slashes. 
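+                # Added note (not from the original code): replica 0 is built
+                # under an empty scope name and replicas i > 0 under 'v<i>/'
+                # (later matched by the pattern r'v(\d+)/.*'); opening an empty
+                # variable scope inside a non-root scope yields names with
+                # repeated slashes, hence the NotImplementedError just below.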
+ raise NotImplementedError('Unable to handle multi-gpu model created within a non-root variable scope.') + + tower_inputs = [OrderedDict() for _ in range(self.num_gpus)] + for name, input in self.inputs.items(): + input_splits = tf.split(input, self.num_gpus) # assumes batch_size is divisible by num_gpus + for i in range(self.num_gpus): + tower_inputs[i][name] = input_splits[i] + + tower_outputs_tuple = [] + tower_d_losses = [] + tower_g_losses = [] + tower_g_losses_post = [] + tower_d_loss = [] + tower_g_loss = [] + tower_g_loss_post = [] + tower_metrics_tuple = [] + for i in range(self.num_gpus): + worker_device = '/gpu:%d' % i + if self.aggregate_nccl: + scope_name = '' if i == 0 else 'v%d' % i + scope_reuse = False + device_setter = worker_device + else: + scope_name = '' + scope_reuse = i > 0 + device_setter = local_device_setter(worker_device=worker_device) + with tf.variable_scope(scope_name, reuse=scope_reuse): + with tf.device(device_setter): + outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(tower_inputs[i]) + tower_outputs_tuple.append(outputs_tuple) + d_losses, g_losses, g_losses_post = losses_tuple + tower_d_losses.append(d_losses) + tower_g_losses.append(g_losses) + tower_g_losses_post.append(g_losses_post) + d_loss, g_loss, g_loss_post = loss_tuple + tower_d_loss.append(d_loss) + tower_g_loss.append(g_loss) + tower_g_loss_post.append(g_loss_post) + tower_metrics_tuple.append(metrics_tuple) + self.d_vars = tf.trainable_variables(self.discriminator_scope) + self.g_vars = tf.trainable_variables(self.generator_scope) + + if self.aggregate_nccl: + scope_replica = lambda scope, i: ('' if i == 0 else 'v%d/' % i) + scope + tower_d_vars = [tf.trainable_variables( + scope_replica(self.discriminator_scope, i)) for i in range(self.num_gpus)] + tower_g_vars = [tf.trainable_variables( + scope_replica(self.generator_scope, i)) for i in range(self.num_gpus)] + assert self.d_vars == tower_d_vars[0] + assert self.g_vars == tower_g_vars[0] + tower_d_optimizer = [tf.train.AdamOptimizer( + self.learning_rate, self.hparams.beta1, self.hparams.beta2) for _ in range(self.num_gpus)] + tower_g_optimizer = [tf.train.AdamOptimizer( + self.learning_rate, self.hparams.beta1, self.hparams.beta2) for _ in range(self.num_gpus)] + + if self.mode == 'train' and (any(tower_d_losses) or any(tower_g_losses)): + tower_d_gradvars = [] + tower_g_gradvars = [] + tower_d_train_op = [] + tower_g_train_op = [] + with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + if any(tower_d_losses): + for i in range(self.num_gpus): + with tf.device('/gpu:%d' % i): + with tf.name_scope(scope_replica('d_compute_gradients', i)): + d_gradvars = tower_d_optimizer[i].compute_gradients( + tower_d_loss[i], var_list=tower_d_vars[i]) + tower_d_gradvars.append(d_gradvars) + + all_d_grads, all_d_vars = tf_utils.split_grad_list(tower_d_gradvars) + all_d_grads = tf_utils.allreduce_grads(all_d_grads, average=True) + tower_d_gradvars = tf_utils.merge_grad_list(all_d_grads, all_d_vars) + + for i in range(self.num_gpus): + with tf.device('/gpu:%d' % i): + with tf.name_scope(scope_replica('d_apply_gradients', i)): + d_train_op = tower_d_optimizer[i].apply_gradients(tower_d_gradvars[i]) + tower_d_train_op.append(d_train_op) + d_train_op = tf.group(*tower_d_train_op) + else: + d_train_op = tf.no_op() + with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []): + if any(tower_g_losses_post): + for i in range(self.num_gpus): + with tf.device('/gpu:%d' % i): + if not 
self.hparams.joint_gan_optimization: + replace_read_ops(tower_g_loss_post[i], tower_d_vars[i]) + + with tf.name_scope(scope_replica('g_compute_gradients', i)): + g_gradvars = tower_g_optimizer[i].compute_gradients( + tower_g_loss_post[i], var_list=tower_g_vars[i]) + tower_g_gradvars.append(g_gradvars) + + all_g_grads, all_g_vars = tf_utils.split_grad_list(tower_g_gradvars) + all_g_grads = tf_utils.allreduce_grads(all_g_grads, average=True) + tower_g_gradvars = tf_utils.merge_grad_list(all_g_grads, all_g_vars) + + for i, g_gradvars in enumerate(tower_g_gradvars): + with tf.device('/gpu:%d' % i): + with tf.name_scope(scope_replica('g_apply_gradients', i)): + g_train_op = tower_g_optimizer[i].apply_gradients(g_gradvars) + tower_g_train_op.append(g_train_op) + g_train_op = tf.group(*tower_g_train_op) + else: + g_train_op = tf.no_op() + with tf.control_dependencies([g_train_op]): + train_op = tf.assign_add(global_step, 1) + self.train_op = train_op + else: + self.train_op = None + + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + tower_saveable_vars = [[] for _ in range(self.num_gpus)] + for var in global_variables: + m = re.match('v(\d+)/.*', var.name) + i = int(m.group(1)) if m else 0 + tower_saveable_vars[i].append(var) + self.saveable_variables = [global_step] + tower_saveable_vars[0] + + post_init_ops = [] + for i, saveable_vars in enumerate(tower_saveable_vars[1:], 1): + assert len(saveable_vars) == len(tower_saveable_vars[0]) + for var, var0 in zip(saveable_vars, tower_saveable_vars[0]): + assert var.name == 'v%d/%s' % (i, var0.name) + post_init_ops.append(var.assign(var0.read_value())) + self.post_init_ops = post_init_ops + else: # not self.aggregate_nccl (i.e. aggregation in cpu) + g_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) + d_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2) + + if self.mode == 'train' and (any(tower_d_losses) or any(tower_g_losses)): + with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + if any(tower_d_losses): + with tf.name_scope('d_compute_gradients'): + d_gradvars = compute_averaged_gradients( + d_optimizer, tower_d_loss, var_list=self.d_vars) + with tf.name_scope('d_apply_gradients'): + d_train_op = d_optimizer.apply_gradients(d_gradvars) + else: + d_train_op = tf.no_op() + with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []): + if any(tower_g_losses_post): + for g_loss_post in tower_g_loss_post: + if not self.hparams.joint_gan_optimization: + replace_read_ops(g_loss_post, self.d_vars) + with tf.name_scope('g_compute_gradients'): + g_gradvars = compute_averaged_gradients( + g_optimizer, tower_g_loss_post, var_list=self.g_vars) + with tf.name_scope('g_apply_gradients'): + g_train_op = g_optimizer.apply_gradients(g_gradvars) + else: + g_train_op = tf.no_op() + with tf.control_dependencies([g_train_op]): + train_op = tf.assign_add(global_step, 1) + self.train_op = train_op + else: + self.train_op = None + + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [global_step] + global_variables + self.post_init_ops = [] + + # Device that runs the ops to apply global gradient updates. 
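+ # The per-tower outputs, losses and metrics computed above are consolidated on the CPU below via reduce_tensors, + # so the summaries and the reported d_loss/g_loss reflect all towers rather than a single GPU.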
+ consolidation_device = '/cpu:0' + with tf.device(consolidation_device): + with tf.name_scope('consolidation'): + self.outputs, self.eval_outputs = reduce_tensors(tower_outputs_tuple) + self.d_losses = reduce_tensors(tower_d_losses, shallow=True) + self.g_losses = reduce_tensors(tower_g_losses, shallow=True) + self.metrics, self.eval_metrics = reduce_tensors(tower_metrics_tuple) + self.d_loss = reduce_tensors(tower_d_loss) + self.g_loss = reduce_tensors(tower_g_loss) + + original_local_variables = set(tf.local_variables()) + self.accum_eval_metrics = OrderedDict() + for name, eval_metric in self.eval_metrics.items(): + _, self.accum_eval_metrics['accum_' + name] = tf.metrics.mean_tensor(eval_metric) + local_variables = set(tf.local_variables()) - original_local_variables + self.accum_eval_metrics_reset_op = tf.group([tf.assign(v, tf.zeros_like(v)) for v in local_variables]) + + original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) + add_summaries(self.inputs) + add_summaries(self.outputs) + add_scalar_summaries(self.d_losses) + add_scalar_summaries(self.g_losses) + add_scalar_summaries(self.metrics) + if self.d_losses: + add_scalar_summaries({'d_loss': self.d_loss}) + if self.g_losses: + add_scalar_summaries({'g_loss': self.g_loss}) + if self.d_losses and self.g_losses: + add_scalar_summaries({'loss': self.d_loss + self.g_loss}) + summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries + # split summaries into non-image summaries and image summaries + self.summary_op = tf.summary.merge(list(summaries - set(tf.get_collection(tf_utils.IMAGE_SUMMARIES)))) + self.image_summary_op = tf.summary.merge(list(summaries & set(tf.get_collection(tf_utils.IMAGE_SUMMARIES)))) + + original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) + add_gif_summaries(self.eval_outputs) + add_plot_and_scalar_summaries( + {name: tf.reduce_mean(metric, axis=0) for name, metric in self.eval_metrics.items()}, + x_offset=self.hparams.context_frames + 1) + summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries + self.eval_summary_op = tf.summary.merge(list(summaries)) + + original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) + add_plot_and_scalar_summaries( + {name: tf.reduce_mean(metric, axis=0) for name, metric in self.accum_eval_metrics.items()}, + x_offset=self.hparams.context_frames + 1) + summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries + self.accum_eval_summary_op = tf.summary.merge(list(summaries)) + + def generator_loss_fn(self, inputs, outputs): + hparams = self.hparams + gen_losses = OrderedDict() + if hparams.l1_weight or hparams.l2_weight or hparams.vgg_cdist_weight: + gen_images = outputs.get('gen_images_enc', outputs['gen_images']) + target_images = inputs['images'][1:] + if hparams.l1_weight: + gen_l1_loss = vp.losses.l1_loss(gen_images, target_images) + gen_losses["gen_l1_loss"] = (gen_l1_loss, hparams.l1_weight) + if hparams.l2_weight: + gen_l2_loss = vp.losses.l2_loss(gen_images, target_images) + gen_losses["gen_l2_loss"] = (gen_l2_loss, hparams.l2_weight) + if hparams.vgg_cdist_weight: + gen_vgg_cdist_loss = vp.metrics.vgg_cosine_distance(gen_images, target_images) + gen_losses['gen_vgg_cdist_loss'] = (gen_vgg_cdist_loss, hparams.vgg_cdist_weight) + if hparams.feature_l2_weight: + gen_features = outputs.get('gen_features_enc', outputs['gen_features']) + target_features = outputs['features'][1:] + gen_feature_l2_loss = vp.losses.l2_loss(gen_features, target_features) + 
gen_losses["gen_feature_l2_loss"] = (gen_feature_l2_loss, hparams.feature_l2_weight) + if hparams.ae_l2_weight: + gen_images_dec = outputs.get('gen_images_dec_enc', outputs['gen_images_dec']) # they both should be the same + target_images = inputs['images'] + gen_ae_l2_loss = vp.losses.l2_loss(gen_images_dec, target_images) + gen_losses["gen_ae_l2_loss"] = (gen_ae_l2_loss, hparams.ae_l2_weight) + if hparams.state_weight: + gen_states = outputs.get('gen_states_enc', outputs['gen_states']) + target_states = inputs['states'][1:] + gen_state_loss = vp.losses.l2_loss(gen_states, target_states) + gen_losses["gen_state_loss"] = (gen_state_loss, hparams.state_weight) + if hparams.tv_weight: + gen_flows = outputs.get('gen_flows_enc', outputs['gen_flows']) + flow_diff1 = gen_flows[..., 1:, :, :, :] - gen_flows[..., :-1, :, :, :] + flow_diff2 = gen_flows[..., :, 1:, :, :] - gen_flows[..., :, :-1, :, :] + # sum over the multiple transformations but take the mean for the other dimensions + gen_tv_loss = (tf.reduce_mean(tf.reduce_sum(tf.abs(flow_diff1), axis=(-2, -1))) + + tf.reduce_mean(tf.reduce_sum(tf.abs(flow_diff2), axis=(-2, -1)))) + gen_losses['gen_tv_loss'] = (gen_tv_loss, hparams.tv_weight) + gan_weights = {'_image_sn': hparams.image_sn_gan_weight, + '_images_sn': hparams.images_sn_gan_weight, + '_video_sn': hparams.video_sn_gan_weight} + for infix, gan_weight in gan_weights.items(): + if gan_weight: + gen_gan_loss = vp.losses.gan_loss(outputs['discrim%s_logits_fake' % infix], 1.0, hparams.gan_loss_type) + gen_losses["gen%s_gan_loss" % infix] = (gen_gan_loss, gan_weight) + if gan_weight and (hparams.gan_feature_l2_weight or hparams.gan_feature_cdist_weight): + i_feature = 0 + discrim_features_fake = [] + discrim_features_real = [] + while True: + discrim_feature_fake = outputs.get('discrim%s_feature%d_fake' % (infix, i_feature)) + discrim_feature_real = outputs.get('discrim%s_feature%d_real' % (infix, i_feature)) + if discrim_feature_fake is None or discrim_feature_real is None: + break + discrim_features_fake.append(discrim_feature_fake) + discrim_features_real.append(discrim_feature_real) + i_feature += 1 + if hparams.gan_feature_l2_weight: + gen_gan_feature_l2_loss = sum([vp.losses.l2_loss(discrim_feature_fake, discrim_feature_real) + for discrim_feature_fake, discrim_feature_real in zip(discrim_features_fake, discrim_features_real)]) + gen_losses["gen%s_gan_feature_l2_loss" % infix] = (gen_gan_feature_l2_loss, hparams.gan_feature_l2_weight) + if hparams.gan_feature_cdist_weight: + gen_gan_feature_cdist_loss = sum([vp.losses.cosine_distance(discrim_feature_fake, discrim_feature_real) + for discrim_feature_fake, discrim_feature_real in zip(discrim_features_fake, discrim_features_real)]) + gen_losses["gen%s_gan_feature_cdist_loss" % infix] = (gen_gan_feature_cdist_loss, hparams.gan_feature_cdist_weight) + vae_gan_weights = {'_image_sn': hparams.image_sn_vae_gan_weight, + '_images_sn': hparams.images_sn_vae_gan_weight, + '_video_sn': hparams.video_sn_vae_gan_weight} + for infix, vae_gan_weight in vae_gan_weights.items(): + if vae_gan_weight: + gen_vae_gan_loss = vp.losses.gan_loss(outputs['discrim%s_logits_enc_fake' % infix], 1.0, hparams.gan_loss_type) + gen_losses["gen%s_vae_gan_loss" % infix] = (gen_vae_gan_loss, vae_gan_weight) + if vae_gan_weight and (hparams.vae_gan_feature_l2_weight or hparams.vae_gan_feature_cdist_weight): + i_feature = 0 + discrim_features_enc_fake = [] + discrim_features_enc_real = [] + while True: + discrim_feature_enc_fake = 
outputs.get('discrim%s_feature%d_enc_fake' % (infix, i_feature)) + discrim_feature_enc_real = outputs.get('discrim%s_feature%d_enc_real' % (infix, i_feature)) + if discrim_feature_enc_fake is None or discrim_feature_enc_real is None: + break + discrim_features_enc_fake.append(discrim_feature_enc_fake) + discrim_features_enc_real.append(discrim_feature_enc_real) + i_feature += 1 + if hparams.vae_gan_feature_l2_weight: + gen_vae_gan_feature_l2_loss = sum([vp.losses.l2_loss(discrim_feature_enc_fake, discrim_feature_enc_real) + for discrim_feature_enc_fake, discrim_feature_enc_real in zip(discrim_features_enc_fake, discrim_features_enc_real)]) + gen_losses["gen%s_vae_gan_feature_l2_loss" % infix] = (gen_vae_gan_feature_l2_loss, hparams.vae_gan_feature_l2_weight) + if hparams.vae_gan_feature_cdist_weight: + gen_vae_gan_feature_cdist_loss = sum([vp.losses.cosine_distance(discrim_feature_enc_fake, discrim_feature_enc_real) + for discrim_feature_enc_fake, discrim_feature_enc_real in zip(discrim_features_enc_fake, discrim_features_enc_real)]) + gen_losses["gen%s_vae_gan_feature_cdist_loss" % infix] = (gen_vae_gan_feature_cdist_loss, hparams.vae_gan_feature_cdist_weight) + if hparams.kl_weight: + gen_kl_loss = vp.losses.kl_loss(outputs['zs_mu_enc'], outputs['zs_log_sigma_sq_enc'], + outputs.get('zs_mu_prior'), outputs.get('zs_log_sigma_sq_prior')) + gen_losses["gen_kl_loss"] = (gen_kl_loss, self.kl_weight) # possibly annealed kl_weight + return gen_losses + + def discriminator_loss_fn(self, inputs, outputs): + hparams = self.hparams + discrim_losses = OrderedDict() + gan_weights = {'_image_sn': hparams.image_sn_gan_weight, + '_images_sn': hparams.images_sn_gan_weight, + '_video_sn': hparams.video_sn_gan_weight} + for infix, gan_weight in gan_weights.items(): + if gan_weight: + discrim_gan_loss_real = vp.losses.gan_loss(outputs['discrim%s_logits_real' % infix], 1.0, hparams.gan_loss_type) + discrim_gan_loss_fake = vp.losses.gan_loss(outputs['discrim%s_logits_fake' % infix], 0.0, hparams.gan_loss_type) + discrim_gan_loss = discrim_gan_loss_real + discrim_gan_loss_fake + discrim_losses["discrim%s_gan_loss" % infix] = (discrim_gan_loss, gan_weight) + vae_gan_weights = {'_image_sn': hparams.image_sn_vae_gan_weight, + '_images_sn': hparams.images_sn_vae_gan_weight, + '_video_sn': hparams.video_sn_vae_gan_weight} + for infix, vae_gan_weight in vae_gan_weights.items(): + if vae_gan_weight: + discrim_vae_gan_loss_real = vp.losses.gan_loss(outputs['discrim%s_logits_enc_real' % infix], 1.0, hparams.gan_loss_type) + discrim_vae_gan_loss_fake = vp.losses.gan_loss(outputs['discrim%s_logits_enc_fake' % infix], 0.0, hparams.gan_loss_type) + discrim_vae_gan_loss = discrim_vae_gan_loss_real + discrim_vae_gan_loss_fake + discrim_losses["discrim%s_vae_gan_loss" % infix] = (discrim_vae_gan_loss, vae_gan_weight) + return discrim_losses diff --git a/video_prediction/models/dna_model.py b/video_prediction/models/dna_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c4fa8b97bc523382adfa14f564aa30193920ed48 --- /dev/null +++ b/video_prediction/models/dna_model.py @@ -0,0 +1,476 @@ +# Copyright 2016 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Model architecture for predictive model, including CDNA, DNA, and STP.""" + +import itertools + +import numpy as np +import tensorflow as tf +import tensorflow.contrib.slim as slim +from tensorflow.contrib.layers.python import layers as tf_layers + +from video_prediction.models import VideoPredictionModel +from .sna_model import basic_conv_lstm_cell + + +# Amount to use when lower bounding tensors +RELU_SHIFT = 1e-12 + + +def construct_model(images, + actions=None, + states=None, + iter_num=-1.0, + kernel_size=(5, 5), + k=-1, + use_state=True, + num_masks=10, + stp=False, + cdna=True, + dna=False, + context_frames=2, + pix_distributions=None): + """Build convolutional lstm video predictor using STP, CDNA, or DNA. + + Args: + images: tensor of ground truth image sequences + actions: tensor of action sequences + states: tensor of ground truth state sequences + iter_num: tensor of the current training iteration (for sched. sampling) + k: constant used for scheduled sampling. -1 to feed in own prediction. + use_state: True to include state and action in prediction + num_masks: the number of different pixel motion predictions (and + the number of masks for each of those predictions) + stp: True to use Spatial Transformer Predictor (STP) + cdna: True to use Convoluational Dynamic Neural Advection (CDNA) + dna: True to use Dynamic Neural Advection (DNA) + context_frames: number of ground truth frames to pass in before + feeding in own predictions + Returns: + gen_images: predicted future image frames + gen_states: predicted future states + + Raises: + ValueError: if more than one network option specified or more than 1 mask + specified for DNA model. + """ + DNA_KERN_SIZE = kernel_size[0] + + if stp + cdna + dna != 1: + raise ValueError('More than one, or no network option specified.') + batch_size, img_height, img_width, color_channels = images[0].get_shape()[0:4] + lstm_func = basic_conv_lstm_cell + + # Generated robot states and images. + gen_states, gen_images = [], [] + gen_pix_distrib = [] + gen_masks = [] + current_state = states[0] + + if k == -1: + feedself = True + else: + # Scheduled sampling: + # Calculate number of ground-truth frames to pass in. + num_ground_truth = tf.to_int32( + tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k))))) + feedself = False + + # LSTM state sizes and states. + lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32])) + lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None + lstm_state5, lstm_state6, lstm_state7 = None, None, None + + for t, action in enumerate(actions): + # Reuse variables after the first timestep. + reuse = bool(gen_images) + + done_warm_start = len(gen_images) > context_frames - 1 + with slim.arg_scope( + [lstm_func, slim.layers.conv2d, slim.layers.fully_connected, + tf_layers.layer_norm, slim.layers.conv2d_transpose], + reuse=reuse): + + if feedself and done_warm_start: + # Feed in generated image. 
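+ # (rollout logic: past the warm start with k == -1 the model is fed back its own previous prediction; + # with scheduled sampling a mix of ground-truth and generated frames is drawn via num_ground_truth; + # within the first context_frames the ground-truth frame is always used)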
+ prev_image = gen_images[-1] + if pix_distributions is not None: + prev_pix_distrib = gen_pix_distrib[-1] + elif done_warm_start: + # Scheduled sampling + prev_image = scheduled_sample(images[t], gen_images[-1], batch_size, + num_ground_truth) + else: + # Always feed in ground_truth + prev_image = images[t] + if pix_distributions is not None: + prev_pix_distrib = pix_distributions[t] + # prev_pix_distrib = tf.expand_dims(prev_pix_distrib, -1) + + # Predicted state is always fed back in + state_action = tf.concat(axis=1, values=[action, current_state]) + + enc0 = slim.layers.conv2d( + prev_image, + 32, [5, 5], + stride=2, + scope='scale1_conv1', + normalizer_fn=tf_layers.layer_norm, + normalizer_params={'scope': 'layer_norm1'}) + + hidden1, lstm_state1 = lstm_func( + enc0, lstm_state1, lstm_size[0], scope='state1') + hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2') + hidden2, lstm_state2 = lstm_func( + hidden1, lstm_state2, lstm_size[1], scope='state2') + hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3') + enc1 = slim.layers.conv2d( + hidden2, hidden2.get_shape()[3], [3, 3], stride=2, scope='conv2') + + hidden3, lstm_state3 = lstm_func( + enc1, lstm_state3, lstm_size[2], scope='state3') + hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4') + hidden4, lstm_state4 = lstm_func( + hidden3, lstm_state4, lstm_size[3], scope='state4') + hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5') + enc2 = slim.layers.conv2d( + hidden4, hidden4.get_shape()[3], [3, 3], stride=2, scope='conv3') + + # Pass in state and action. + smear = tf.reshape( + state_action, + [int(batch_size), 1, 1, int(state_action.get_shape()[1])]) + smear = tf.tile( + smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1]) + if use_state: + enc2 = tf.concat(axis=3, values=[enc2, smear]) + enc3 = slim.layers.conv2d( + enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4') + + hidden5, lstm_state5 = lstm_func( + enc3, lstm_state5, lstm_size[4], scope='state5') # last 8x8 + hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6') + enc4 = slim.layers.conv2d_transpose( + hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1') + + hidden6, lstm_state6 = lstm_func( + enc4, lstm_state6, lstm_size[5], scope='state6') # 16x16 + hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7') + # Skip connection. + hidden6 = tf.concat(axis=3, values=[hidden6, enc1]) # both 16x16 + + enc5 = slim.layers.conv2d_transpose( + hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2') + hidden7, lstm_state7 = lstm_func( + enc5, lstm_state7, lstm_size[6], scope='state7') # 32x32 + hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8') + + # Skip connection. + hidden7 = tf.concat(axis=3, values=[hidden7, enc0]) # both 32x32 + + enc6 = slim.layers.conv2d_transpose( + hidden7, + hidden7.get_shape()[3], 3, stride=2, scope='convt3', + normalizer_fn=tf_layers.layer_norm, + normalizer_params={'scope': 'layer_norm9'}) + + if dna: + # Using largest hidden state for predicting untied conv kernels. + enc7 = slim.layers.conv2d_transpose( + enc6, DNA_KERN_SIZE ** 2, 1, stride=1, scope='convt4') + else: + # Using largest hidden state for predicting a new image layer. + enc7 = slim.layers.conv2d_transpose( + enc6, color_channels, 1, stride=1, scope='convt4') + # This allows the network to also generate one image from scratch, + # which is useful when regions of the image become unoccluded. 
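+ # The sigmoid of enc7 below is that scratch-generated candidate frame: it is placed first in `transformed` and + # blended with the motion-transformed candidates through the compositing masks predicted from enc6 + # (the STP branch overwrites it, since generating pixels from scratch is disabled there).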
+ transformed = [tf.nn.sigmoid(enc7)] + + if stp: + stp_input0 = tf.reshape(hidden5, [int(batch_size), -1]) + stp_input1 = slim.layers.fully_connected( + stp_input0, 100, scope='fc_stp') + + # disabling capability to generete pixels + reuse_stp = None + if reuse: + reuse_stp = reuse + transformed = stp_transformation(prev_image, stp_input1, num_masks, reuse_stp) + # transformed += stp_transformation(prev_image, stp_input1, num_masks) + + if pix_distributions is not None: + transf_distrib = stp_transformation(prev_pix_distrib, stp_input1, num_masks, reuse=True) + + elif cdna: + cdna_input = tf.reshape(hidden5, [int(batch_size), -1]) + + new_transformed, cdna_kerns = cdna_transformation(prev_image, + cdna_input, + num_masks, + int(color_channels), + kernel_size, + reuse_sc=reuse) + transformed += new_transformed + + if pix_distributions is not None: + if not dna: + transf_distrib = [prev_pix_distrib] + new_transf_distrib, _ = cdna_transformation(prev_pix_distrib, + cdna_input, + num_masks, + prev_pix_distrib.shape[-1].value, + kernel_size, + reuse_sc=True) + transf_distrib += new_transf_distrib + + elif dna: + # Only one mask is supported (more should be unnecessary). + if num_masks != 1: + raise ValueError('Only one mask is supported for DNA model.') + transformed = [dna_transformation(prev_image, enc7, DNA_KERN_SIZE)] + + masks = slim.layers.conv2d_transpose( + enc6, num_masks + 1, 1, stride=1, scope='convt7') + masks = tf.reshape( + tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])), + [int(batch_size), int(img_height), int(img_width), num_masks + 1]) + mask_list = tf.split(masks, num_masks + 1, axis=3) + output = mask_list[0] * prev_image + for layer, mask in zip(transformed, mask_list[1:]): + output += layer * mask + gen_images.append(output) + gen_masks.append(mask_list) + + if dna and pix_distributions is not None: + transf_distrib = [dna_transformation(prev_pix_distrib, enc7, DNA_KERN_SIZE)] + + if pix_distributions is not None: + pix_distrib_output = mask_list[0] * prev_pix_distrib + for layer, mask in zip(transf_distrib, mask_list[1:]): + pix_distrib_output += layer * mask + pix_distrib_output /= tf.reduce_sum(pix_distrib_output, axis=(1, 2), keepdims=True) + gen_pix_distrib.append(pix_distrib_output) + + if int(current_state.get_shape()[1]) == 0: + current_state = tf.zeros_like(state_action) + else: + current_state = slim.layers.fully_connected( + state_action, + int(current_state.get_shape()[1]), + scope='state_pred', + activation_fn=None) + gen_states.append(current_state) + + return gen_images, gen_states, gen_masks, gen_pix_distrib + + +## Utility functions +def stp_transformation(prev_image, stp_input, num_masks): + """Apply spatial transformer predictor (STP) to previous image. + + Args: + prev_image: previous image to be transformed. + stp_input: hidden layer to be used for computing STN parameters. + num_masks: number of masks and hence the number of STP transformations. + Returns: + List of images transformed by the predicted STP parameters. + """ + # Only import spatial transformer if needed. 
+ from spatial_transformer import transformer + + identity_params = tf.convert_to_tensor( + np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32)) + transformed = [] + for i in range(num_masks - 1): + params = slim.layers.fully_connected( + stp_input, 6, scope='stp_params' + str(i), + activation_fn=None) + identity_params + transformed.append(transformer(prev_image, params)) + + return transformed + + +def cdna_transformation(prev_image, cdna_input, num_masks, color_channels, kernel_size, reuse_sc=None): + """Apply convolutional dynamic neural advection to previous image. + + Args: + prev_image: previous image to be transformed. + cdna_input: hidden lyaer to be used for computing CDNA kernels. + num_masks: the number of masks and hence the number of CDNA transformations. + color_channels: the number of color channels in the images. + Returns: + List of images transformed by the predicted CDNA kernels. + """ + batch_size = int(cdna_input.get_shape()[0]) + height = int(prev_image.get_shape()[1]) + width = int(prev_image.get_shape()[2]) + + # Predict kernels using linear function of last hidden layer. + cdna_kerns = slim.layers.fully_connected( + cdna_input, + kernel_size[0] * kernel_size[1] * num_masks, + scope='cdna_params', + activation_fn=None, + reuse=reuse_sc) + + # Reshape and normalize. + cdna_kerns = tf.reshape( + cdna_kerns, [batch_size, kernel_size[0], kernel_size[1], 1, num_masks]) + cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT + norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True) + cdna_kerns /= norm_factor + + # Treat the color channel dimension as the batch dimension since the same + # transformation is applied to each color channel. + # Treat the batch dimension as the channel dimension so that + # depthwise_conv2d can apply a different transformation to each sample. + cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3]) + cdna_kerns = tf.reshape(cdna_kerns, [kernel_size[0], kernel_size[1], batch_size, num_masks]) + # Swap the batch and channel dimensions. + prev_image = tf.transpose(prev_image, [3, 1, 2, 0]) + + # Transform image. + transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME') + + # Transpose the dimensions to where they belong. + transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks]) + transformed = tf.transpose(transformed, [3, 1, 2, 0, 4]) + transformed = tf.unstack(transformed, axis=-1) + return transformed, cdna_kerns + + +def dna_transformation(prev_image, dna_input, kernel_size): + """Apply dynamic neural advection to previous image. + + Args: + prev_image: previous image to be transformed. + dna_input: hidden lyaer to be used for computing DNA transformation. + Returns: + List of images transformed by the predicted CDNA kernels. + """ + # Construct translated images. 
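+ # DNA: the previous frame is zero-padded and shifted over all kernel_size[0] * kernel_size[1] offsets; the + # per-pixel kernel (lower-bounded by RELU_SHIFT and normalized to sum to one over the offset axis) then forms + # the next frame as a weighted sum of these shifted copies.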
+ pad_along_height = (kernel_size[0] - 1) + pad_along_width = (kernel_size[1] - 1) + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + prev_image_pad = tf.pad(prev_image, [[0, 0], + [pad_top, pad_bottom], + [pad_left, pad_right], + [0, 0]]) + image_height = int(prev_image.get_shape()[1]) + image_width = int(prev_image.get_shape()[2]) + + inputs = [] + for xkern in range(kernel_size[0]): + for ykern in range(kernel_size[1]): + inputs.append( + tf.expand_dims( + tf.slice(prev_image_pad, [0, xkern, ykern, 0], + [-1, image_height, image_width, -1]), [3])) + inputs = tf.concat(axis=3, values=inputs) + + # Normalize channels to 1. + kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT + kernel = tf.expand_dims( + kernel / tf.reduce_sum( + kernel, [3], keepdims=True), [4]) + return tf.reduce_sum(kernel * inputs, [3], keepdims=False) + + +def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth): + """Sample batch with specified mix of ground truth and generated data points. + + Args: + ground_truth_x: tensor of ground-truth data points. + generated_x: tensor of generated data points. + batch_size: batch size + num_ground_truth: number of ground-truth examples to include in batch. + Returns: + New batch with num_ground_truth sampled from ground_truth_x and the rest + from generated_x. + """ + idx = tf.random_shuffle(tf.range(int(batch_size))) + ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth)) + generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size))) + + ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx) + generated_examps = tf.gather(generated_x, generated_idx) + return tf.dynamic_stitch([ground_truth_idx, generated_idx], + [ground_truth_examps, generated_examps]) + + +def generator_fn(inputs, hparams=None): + images = tf.unstack(inputs['images'], axis=0) + actions = tf.unstack(inputs['actions'], axis=0) + states = tf.unstack(inputs['states'], axis=0) + pix_distributions = tf.unstack(inputs['pix_distribs'], axis=0) if 'pix_distribs' in inputs else None + iter_num = tf.to_float(tf.train.get_or_create_global_step()) + + gen_images, gen_states, gen_masks, gen_pix_distrib = \ + construct_model(images, + actions, + states, + iter_num=iter_num, + kernel_size=hparams.kernel_size, + k=hparams.schedule_sampling_k, + num_masks=hparams.num_masks, + cdna=hparams.transformation == 'cdna', + dna=hparams.transformation == 'dna', + stp=hparams.transformation == 'stp', + context_frames=hparams.context_frames, + pix_distributions=pix_distributions) + outputs = { + 'gen_images': tf.stack(gen_images, axis=0), + 'gen_states': tf.stack(gen_states, axis=0), + 'masks': tf.stack([tf.stack(gen_mask_list, axis=-1) for gen_mask_list in gen_masks], axis=0), + } + if 'pix_distribs' in inputs: + outputs['gen_pix_distribs'] = tf.stack(gen_pix_distrib, axis=0) + gen_images = outputs['gen_images'][hparams.context_frames - 1:] + return gen_images, outputs + + +class DNAVideoPredictionModel(VideoPredictionModel): + def __init__(self, *args, **kwargs): + super(DNAVideoPredictionModel, self).__init__( + generator_fn, *args, **kwargs) + + def get_default_hparams_dict(self): + default_hparams = super(DNAVideoPredictionModel, self).get_default_hparams_dict() + hparams = dict( + batch_size=32, + l1_weight=0.0, + l2_weight=1.0, + transformation='cdna', + kernel_size=(9, 9), + num_masks=10, + schedule_sampling_k=900.0, + ) + return 
dict(itertools.chain(default_hparams.items(), hparams.items())) + + def parse_hparams(self, hparams_dict, hparams): + hparams = super(DNAVideoPredictionModel, self).parse_hparams(hparams_dict, hparams) + if self.mode == 'test': + def override_hparams_maybe(name, value): + orig_value = hparams.values()[name] + if orig_value != value: + print('Overriding hparams from %s=%r to %r for mode=%s.' % + (name, orig_value, value, self.mode)) + hparams.set_hparam(name, value) + override_hparams_maybe('schedule_sampling_k', -1) + return hparams diff --git a/video_prediction/models/mcnet_model.py b/video_prediction/models/mcnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..725ce4f46a301b6aa07f3d50ef811584d5b502db --- /dev/null +++ b/video_prediction/models/mcnet_model.py @@ -0,0 +1,467 @@ +import collections +import functools +import itertools +from collections import OrderedDict +import numpy as np +import tensorflow as tf +from tensorflow.python.util import nest +from video_prediction import ops, flow_ops +from video_prediction.models import BaseVideoPredictionModel +from video_prediction.models import networks +from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat +from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell +from video_prediction.utils import tf_utils +from datetime import datetime +from pathlib import Path +from video_prediction.layers import layer_def as ld +from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell +from video_prediction.layers.mcnet_ops import * +from video_prediction.utils.mcnet_utils import * +import os + +class McNetVideoPredictionModel(BaseVideoPredictionModel): + def __init__(self, mode='train', hparams_dict=None, + hparams=None, **kwargs): + super(McNetVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) + self.mode = mode + self.lr = self.hparams.lr + self.context_frames = self.hparams.context_frames + self.sequence_length = self.hparams.sequence_length + self.predict_frames = self.sequence_length - self.context_frames + self.df_dim = self.hparams.df_dim + self.gf_dim = self.hparams.gf_dim + self.alpha = self.hparams.alpha + self.beta = self.hparams.beta + self.gen_images_enc = None + self.recon_loss = None + self.latent_loss = None + self.total_loss = None + + def get_default_hparams_dict(self): + """ + The keys of this dict define valid hyperparameters for instances of + this class. A class inheriting from this one should override this + method if it has a different set of hyperparameters. + + Returns: + A dict with the following hyperparameters. + + batch_size: batch size for training. + lr: learning rate. if decay steps is non-zero, this is the + learning rate for steps <= decay_step. + + + + + max_steps: number of training steps. + + + context_frames: the number of ground-truth frames to pass in at + start. Must be specified during instantiation. + sequence_length: the number of frames in the video sequence, + including the context frames, so this model predicts + `sequence_length - context_frames` future frames. Must be + specified during instantiation. 
+ df_dim: specific parameters for mcnet + gf_dim: specific parameters for menet + alpha: specific parameters for mcnet + beta: specific paramters for mcnet + + """ + default_hparams = super(McNetVideoPredictionModel, self).get_default_hparams_dict() + hparams = dict( + batch_size=16, + lr=0.001, + max_steps=350000, + context_frames = 10, + sequence_length = 20, + nz = 16, + gf_dim = 64, + df_dim = 64, + alpha = 1, + beta = 0.0 + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + def build_graph(self, x): + + self.x = x["images"] + self.x_shape = self.x.get_shape().as_list() + self.batch_size = self.x_shape[0] + self.image_size = [self.x_shape[2],self.x_shape[3]] + self.c_dim = self.x_shape[4] + self.diff_shape = [self.batch_size, self.context_frames-1, self.image_size[0], + self.image_size[1], self.c_dim] + self.xt_shape = [self.batch_size, self.image_size[0], self.image_size[1],self.c_dim] + self.is_train = True + + + self.global_step = tf.Variable(0, name='global_step', trainable=False) + original_global_variables = tf.global_variables() + + # self.xt = tf.placeholder(tf.float32, self.xt_shape, name='xt') + self.xt = self.x[:, self.context_frames - 1, :, :, :] + + self.diff_in = tf.placeholder(tf.float32, self.diff_shape, name='diff_in') + diff_in_all = [] + for t in range(1, self.context_frames): + prev = self.x[:, t-1:t, :, :, :] + next = self.x[:, t:t+1, :, :, :] + #diff_in = tf.reshape(next - prev, [self.batch_size, 1, self.image_size[0], self.image_size[1], -1]) + print("prev:",prev) + print("next:",next) + diff_in = tf.subtract(next,prev) + print("diff_in:",diff_in) + diff_in_all.append(diff_in) + + self.diff_in = tf.concat(axis = 1, values = diff_in_all) + + cell = BasicConvLSTMCell([self.image_size[0] / 8, self.image_size[1] / 8], [3, 3], 256) + + pred = self.forward(self.diff_in, self.xt, cell) + + + self.G = tf.concat(axis=1, values=pred)#[batch_size,context_frames,image1,image2,channels] + print ("1:self.G:",self.G) + if self.is_train: + + true_sim = self.x[:, self.context_frames:, :, :, :] + + # Bing: the following make sure the channel is three dimension, if the channle is 3 then will be duplicated + if self.c_dim == 1: true_sim = tf.tile(true_sim, [1, 1, 1, 1, 3]) + + # Bing: the raw inputs shape is [batch_size, image_size[0],self.image_size[1], num_seq, channel]. 
tf.transpose will transpose the shape into + # [batch size*num_seq, image_size0, image_size1, channels], for our era5 case, we do not need transpose + # true_sim = tf.reshape(tf.transpose(true_sim,[0,3,1,2,4]), + # [-1, self.image_size[0], + # self.image_size[1], 3]) + true_sim = tf.reshape(true_sim, [-1, self.image_size[0], self.image_size[1], 3]) + + + + gen_sim = self.G + + # combine ground truth and predicted frames + self.x_hat = tf.concat([self.x[:, :self.context_frames, :, :, :], self.G], 1) + print ("self.x_hat:",self.x_hat) + if self.c_dim == 1: gen_sim = tf.tile(gen_sim, [1, 1, 1, 1, 3]) + # gen_sim = tf.reshape(tf.transpose(gen_sim,[0,3,1,2,4]), + # [-1, self.image_size[0], + # self.image_size[1], 3]) + + gen_sim = tf.reshape(gen_sim, [-1, self.image_size[0], self.image_size[1], 3]) + + + binput = tf.reshape(tf.transpose(self.x[:, :self.context_frames, :, :, :], [0, 1, 2, 3, 4]), + [self.batch_size, self.image_size[0], + self.image_size[1], -1]) + + btarget = tf.reshape(tf.transpose(self.x[:, self.context_frames:, :, :, :], [0, 1, 2, 3, 4]), + [self.batch_size, self.image_size[0], + self.image_size[1], -1]) + bgen = tf.reshape(self.G, [self.batch_size, + self.image_size[0], + self.image_size[1], -1]) + + print ("binput:",binput) + print("btarget:",btarget) + print("bgen:",bgen) + + good_data = tf.concat(axis=3, values=[binput, btarget]) + gen_data = tf.concat(axis=3, values=[binput, bgen]) + self.gen_data = gen_data + print ("2:self.gen_data:", self.gen_data) + with tf.variable_scope("DIS", reuse=False): + self.D, self.D_logits = self.discriminator(good_data) + + with tf.variable_scope("DIS", reuse=True): + self.D_, self.D_logits_ = self.discriminator(gen_data) + + self.L_p = tf.reduce_mean( + tf.square(self.G - self.x[:, self.context_frames:, :, :, :])) + + self.L_gdl = gdl(gen_sim, true_sim, 1.)
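+ # Image loss: L_p is the mean squared error over the predicted frames and L_gdl is the gradient difference + # loss; further down the generator objective combines them with the adversarial term as + # total_loss = alpha * L_img + beta * L_GAN.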
+ self.L_img = self.L_p + self.L_gdl + + self.d_loss_real = tf.reduce_mean( + tf.nn.sigmoid_cross_entropy_with_logits( + logits = self.D_logits, labels = tf.ones_like(self.D) + )) + self.d_loss_fake = tf.reduce_mean( + tf.nn.sigmoid_cross_entropy_with_logits( + logits = self.D_logits_, labels = tf.zeros_like(self.D_) + )) + self.d_loss = self.d_loss_real + self.d_loss_fake + self.L_GAN = tf.reduce_mean( + tf.nn.sigmoid_cross_entropy_with_logits( + logits = self.D_logits_, labels = tf.ones_like(self.D_) + )) + + self.loss_sum = tf.summary.scalar("L_img", self.L_img) + self.L_p_sum = tf.summary.scalar("L_p", self.L_p) + self.L_gdl_sum = tf.summary.scalar("L_gdl", self.L_gdl) + self.L_GAN_sum = tf.summary.scalar("L_GAN", self.L_GAN) + self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss) + self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real) + self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake) + + self.total_loss = self.alpha * self.L_img + self.beta * self.L_GAN + self._loss_sum = tf.summary.scalar("total_loss", self.total_loss) + self.g_sum = tf.summary.merge([self.L_p_sum, + self.L_gdl_sum, self.loss_sum, + self.L_GAN_sum]) + self.d_sum = tf.summary.merge([self.d_loss_real_sum, self.d_loss_sum, + self.d_loss_fake_sum]) + + + self.t_vars = tf.trainable_variables() + self.g_vars = [var for var in self.t_vars if 'DIS' not in var.name] + self.d_vars = [var for var in self.t_vars if 'DIS' in var.name] + num_param = 0.0 + for var in self.g_vars: + num_param += int(np.prod(var.get_shape())); + print("Number of parameters: %d" % num_param) + + # Training + self.d_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5).minimize( + self.d_loss, var_list = self.d_vars) + self.g_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5).minimize( + self.alpha * self.L_img + self.beta * self.L_GAN, var_list = self.g_vars, global_step=self.global_step) + + self.train_op = [self.d_optim,self.g_optim] + self.outputs = {} + self.outputs["gen_images"] = self.x_hat + + + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + return + + + def forward(self, diff_in, xt, cell): + # Initial state + state = tf.zeros([self.batch_size, self.image_size[0] / 8, + self.image_size[1] / 8, 512]) + reuse = False + # Encoder + for t in range(self.context_frames - 1): + enc_h, res_m = self.motion_enc(diff_in[:, t, :, :, :], reuse = reuse) + h_dyn, state = cell(enc_h, state, scope = 'lstm', reuse = reuse) + reuse = True + pred = [] + # Decoder + for t in range(self.predict_frames): + if t == 0: + h_cont, res_c = self.content_enc(xt, reuse = False) + h_tp1 = self.comb_layers(h_dyn, h_cont, reuse = False) + res_connect = self.residual(res_m, res_c, reuse = False) + x_hat = self.dec_cnn(h_tp1, res_connect, reuse = False) + + else: + + enc_h, res_m = self.motion_enc(diff_in, reuse = True) + h_dyn, state = cell(enc_h, state, scope = 'lstm', reuse = True) + h_cont, res_c = self.content_enc(xt, reuse = reuse) + h_tp1 = self.comb_layers(h_dyn, h_cont, reuse = True) + res_connect = self.residual(res_m, res_c, reuse = True) + x_hat = self.dec_cnn(h_tp1, res_connect, reuse = True) + print ("x_hat :",x_hat) + if self.c_dim == 3: + # Network outputs are BGR so they need to be reversed to use + # rgb_to_grayscale + #x_hat_gray = tf.concat(axis=3,values=[x_hat[:,:,:,2:3], x_hat[:,:,:,1:2],x_hat[:,:,:,0:1]]) + #xt_gray = 
tf.concat(axis=3,values=[xt[:,:,:,2:3], xt[:,:,:,1:2],xt[:,:,:,0:1]]) + + # x_hat_gray = 1./255.*tf.image.rgb_to_grayscale( + # inverse_transform(x_hat_rgb)*255. + # ) + # xt_gray = 1./255.*tf.image.rgb_to_grayscale( + # inverse_transform(xt_rgb)*255. + # ) + + x_hat_gray = x_hat + xt_gray = xt + else: + x_hat_gray = inverse_transform(x_hat) + xt_gray = inverse_transform(xt) + + diff_in = x_hat_gray - xt_gray + xt = x_hat + + + pred.append(tf.reshape(x_hat, [self.batch_size, 1, self.image_size[0], + self.image_size[1], self.c_dim])) + + return pred + + def motion_enc(self, diff_in, reuse): + res_in = [] + + conv1 = relu(conv2d(diff_in, output_dim = self.gf_dim, k_h = 5, k_w = 5, + d_h = 1, d_w = 1, name = 'dyn1_conv1', reuse = reuse)) + res_in.append(conv1) + pool1 = MaxPooling(conv1, [2, 2]) + + conv2 = relu(conv2d(pool1, output_dim = self.gf_dim * 2, k_h = 5, k_w = 5, + d_h = 1, d_w = 1, name = 'dyn_conv2', reuse = reuse)) + res_in.append(conv2) + pool2 = MaxPooling(conv2, [2, 2]) + + conv3 = relu(conv2d(pool2, output_dim = self.gf_dim * 4, k_h = 7, k_w = 7, + d_h = 1, d_w = 1, name = 'dyn_conv3', reuse = reuse)) + res_in.append(conv3) + pool3 = MaxPooling(conv3, [2, 2]) + return pool3, res_in + + def content_enc(self, xt, reuse): + res_in = [] + conv1_1 = relu(conv2d(xt, output_dim = self.gf_dim, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv1_1', reuse = reuse)) + conv1_2 = relu(conv2d(conv1_1, output_dim = self.gf_dim, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv1_2', reuse = reuse)) + res_in.append(conv1_2) + pool1 = MaxPooling(conv1_2, [2, 2]) + + conv2_1 = relu(conv2d(pool1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv2_1', reuse = reuse)) + conv2_2 = relu(conv2d(conv2_1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv2_2', reuse = reuse)) + res_in.append(conv2_2) + pool2 = MaxPooling(conv2_2, [2, 2]) + + conv3_1 = relu(conv2d(pool2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv3_1', reuse = reuse)) + conv3_2 = relu(conv2d(conv3_1, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv3_2', reuse = reuse)) + conv3_3 = relu(conv2d(conv3_2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv3_3', reuse = reuse)) + res_in.append(conv3_3) + pool3 = MaxPooling(conv3_3, [2, 2]) + return pool3, res_in + + def comb_layers(self, h_dyn, h_cont, reuse=False): + comb1 = relu(conv2d(tf.concat(axis = 3, values = [h_dyn, h_cont]), + output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'comb1', reuse = reuse)) + comb2 = relu(conv2d(comb1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'comb2', reuse = reuse)) + h_comb = relu(conv2d(comb2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'h_comb', reuse = reuse)) + return h_comb + + def residual(self, input_dyn, input_cont, reuse=False): + n_layers = len(input_dyn) + res_out = [] + for l in range(n_layers): + input_ = tf.concat(axis = 3, values = [input_dyn[l], input_cont[l]]) + out_dim = input_cont[l].get_shape()[3] + res1 = relu(conv2d(input_, output_dim = out_dim, + k_h = 3, k_w = 3, d_h = 1, d_w = 1, + name = 'res' + str(l) + '_1', reuse = reuse)) + res2 = conv2d(res1, output_dim = out_dim, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'res' + str(l) + '_2', reuse = reuse) + res_out.append(res2) + return res_out + + def dec_cnn(self, h_comb, res_connect, reuse=False): + 
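+ # Decoder: three FixedUnPooling + deconv2d stages upsample the combined motion/content features back to the + # input resolution, adding the encoder residuals res_connect[2], res_connect[1] and res_connect[0] at the + # matching scales; the final deconvolution is squashed with tanh.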
+ shapel3 = [self.batch_size, int(self.image_size[0] / 4), + int(self.image_size[1] / 4), self.gf_dim * 4] + shapeout3 = [self.batch_size, int(self.image_size[0] / 4), + int(self.image_size[1] / 4), self.gf_dim * 2] + depool3 = FixedUnPooling(h_comb, [2, 2]) + deconv3_3 = relu(deconv2d(relu(tf.add(depool3, res_connect[2])), + output_shape = shapel3, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'dec_deconv3_3', reuse = reuse)) + deconv3_2 = relu(deconv2d(deconv3_3, output_shape = shapel3, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'dec_deconv3_2', reuse = reuse)) + deconv3_1 = relu(deconv2d(deconv3_2, output_shape = shapeout3, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'dec_deconv3_1', reuse = reuse)) + + shapel2 = [self.batch_size, int(self.image_size[0] / 2), + int(self.image_size[1] / 2), self.gf_dim * 2] + shapeout3 = [self.batch_size, int(self.image_size[0] / 2), + int(self.image_size[1] / 2), self.gf_dim] + depool2 = FixedUnPooling(deconv3_1, [2, 2]) + deconv2_2 = relu(deconv2d(relu(tf.add(depool2, res_connect[1])), + output_shape = shapel2, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'dec_deconv2_2', reuse = reuse)) + deconv2_1 = relu(deconv2d(deconv2_2, output_shape = shapeout3, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'dec_deconv2_1', reuse = reuse)) + + shapel1 = [self.batch_size, self.image_size[0], + self.image_size[1], self.gf_dim] + shapeout1 = [self.batch_size, self.image_size[0], + self.image_size[1], self.c_dim] + depool1 = FixedUnPooling(deconv2_1, [2, 2]) + deconv1_2 = relu(deconv2d(relu(tf.add(depool1, res_connect[0])), + output_shape = shapel1, k_h = 3, k_w = 3, d_h = 1, d_w = 1, + name = 'dec_deconv1_2', reuse = reuse)) + xtp1 = tanh(deconv2d(deconv1_2, output_shape = shapeout1, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'dec_deconv1_1', reuse = reuse)) + return xtp1 + + def discriminator(self, image): + h0 = lrelu(conv2d(image, self.df_dim, name = 'dis_h0_conv')) + h1 = lrelu(batch_norm(conv2d(h0, self.df_dim * 2, name = 'dis_h1_conv'), + "bn1")) + h2 = lrelu(batch_norm(conv2d(h1, self.df_dim * 4, name = 'dis_h2_conv'), + "bn2")) + h3 = lrelu(batch_norm(conv2d(h2, self.df_dim * 8, name = 'dis_h3_conv'), + "bn3")) + h = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'dis_h3_lin') + + return tf.nn.sigmoid(h), h + + def save(self, sess, checkpoint_dir, step): + model_name = "MCNET.model" + + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) + + self.saver.save(sess, + os.path.join(checkpoint_dir, model_name), + global_step = step) + + def load(self, sess, checkpoint_dir, model_name=None): + print(" [*] Reading checkpoints...") + ckpt = tf.train.get_checkpoint_state(checkpoint_dir) + if ckpt and ckpt.model_checkpoint_path: + ckpt_name = os.path.basename(ckpt.model_checkpoint_path) + if model_name is None: model_name = ckpt_name + self.saver.restore(sess, os.path.join(checkpoint_dir, model_name)) + print(" Loaded model: " + str(model_name)) + return True, model_name + else: + return False, None + + # Execute the forward and the backward pass + + def run_single_step(self, global_step): + print("global_step:", global_step) + try: + train_batch = self.sess.run(self.train_iterator.get_next()) + # z=np.random.uniform(-1,1,size=(self.batch_size,self.nz)) + x = self.sess.run([self.x], feed_dict = {self.x: train_batch["images"]}) + _, g_sum = self.sess.run([self.g_optim, self.g_sum], feed_dict = {self.x: train_batch["images"]}) + _, d_sum = self.sess.run([self.d_optim, self.d_sum], feed_dict = {self.x: train_batch["images"]}) + + gen_data, 
train_loss = self.sess.run([self.gen_data, self.total_loss], + feed_dict = {self.x: train_batch["images"]}) + + except tf.errors.OutOfRangeError: + print("train out of range error") + + try: + val_batch = self.sess.run(self.val_iterator.get_next()) + val_loss = self.sess.run([self.total_loss], feed_dict = {self.x: val_batch["images"]}) + # self.val_writer.add_summary(val_summary, global_step) + except tf.errors.OutOfRangeError: + print("validation data out of range error") + + return train_loss, val_loss + + + diff --git a/video_prediction/models/networks.py b/video_prediction/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..844c28295cfea3597bf6a1ce52c9b0f3891370a9 --- /dev/null +++ b/video_prediction/models/networks.py @@ -0,0 +1,110 @@ +import tensorflow as tf +from tensorflow.python.util import nest + +from video_prediction import ops +from video_prediction.ops import conv2d +from video_prediction.ops import dense +from video_prediction.ops import lrelu +from video_prediction.ops import pool2d +from video_prediction.utils import tf_utils + + +def encoder(inputs, nef=64, n_layers=3, norm_layer='instance'): + print("********inputs*******",inputs.get_shape()) + print("*********nef********", nef) + norm_layer = ops.get_norm_layer(norm_layer) + layers = [] + paddings = [[0, 0], [1, 1], [1, 1], [0, 0]] + + with tf.variable_scope("layer_1"): + convolved = conv2d(tf.pad(inputs, paddings), nef, kernel_size=4, strides=2, padding='VALID') + rectified = lrelu(convolved, 0.2) + layers.append(rectified) + + for i in range(1, n_layers): + with tf.variable_scope("layer_%d" % (len(layers) + 1)): + out_channels = nef * min(2**i, 4) + convolved = conv2d(tf.pad(layers[-1], paddings), out_channels, kernel_size=4, strides=2, padding='VALID') + normalized = norm_layer(convolved) + rectified = lrelu(normalized, 0.2) + layers.append(rectified) + + pooled = pool2d(rectified, rectified.shape.as_list()[1:3], padding='VALID', pool_mode='avg') + squeezed = tf.squeeze(pooled, [1, 2]) + return squeezed + + +def image_sn_discriminator(images, ndf=64): + batch_size = images.shape[0].value + layers = [] + paddings = [[0, 0], [1, 1], [1, 1], [0, 0]] + + def conv2d(inputs, *args, **kwargs): + kwargs.setdefault('padding', 'VALID') + kwargs.setdefault('use_spectral_norm', True) + return ops.conv2d(tf.pad(inputs, paddings), *args, **kwargs) + + with tf.variable_scope("sn_conv0_0"): + layers.append(lrelu(conv2d(images, ndf, kernel_size=3, strides=1), 0.1)) + + with tf.variable_scope("sn_conv0_1"): + layers.append(lrelu(conv2d(layers[-1], ndf * 2, kernel_size=4, strides=2), 0.1)) + + with tf.variable_scope("sn_conv1_0"): + layers.append(lrelu(conv2d(layers[-1], ndf * 2, kernel_size=3, strides=1), 0.1)) + + with tf.variable_scope("sn_conv1_1"): + layers.append(lrelu(conv2d(layers[-1], ndf * 4, kernel_size=4, strides=2), 0.1)) + + with tf.variable_scope("sn_conv2_0"): + layers.append(lrelu(conv2d(layers[-1], ndf * 4, kernel_size=3, strides=1), 0.1)) + + with tf.variable_scope("sn_conv2_1"): + layers.append(lrelu(conv2d(layers[-1], ndf * 8, kernel_size=4, strides=2), 0.1)) + + with tf.variable_scope("sn_conv3_0"): + layers.append(lrelu(conv2d(layers[-1], ndf * 8, kernel_size=3, strides=1), 0.1)) + + with tf.variable_scope("sn_fc4"): + logits = dense(tf.reshape(layers[-1], [batch_size, -1]), 1, use_spectral_norm=True) + layers.append(logits) + return layers + + +def video_sn_discriminator(clips, ndf=64): + clips = tf_utils.transpose_batch_time(clips) + batch_size = clips.shape[0].value + layers =
[] + paddings = [[0, 0], [1, 1], [1, 1], [1, 1], [0, 0]] + + def conv3d(inputs, *args, **kwargs): + kwargs.setdefault('padding', 'VALID') + kwargs.setdefault('use_spectral_norm', True) + return ops.conv3d(tf.pad(inputs, paddings), *args, **kwargs) + + with tf.variable_scope("sn_conv0_0"): + layers.append(lrelu(conv3d(clips, ndf, kernel_size=3, strides=1), 0.1)) + + with tf.variable_scope("sn_conv0_1"): + layers.append(lrelu(conv3d(layers[-1], ndf * 2, kernel_size=4, strides=(1, 2, 2)), 0.1)) + + with tf.variable_scope("sn_conv1_0"): + layers.append(lrelu(conv3d(layers[-1], ndf * 2, kernel_size=3, strides=1), 0.1)) + + with tf.variable_scope("sn_conv1_1"): + layers.append(lrelu(conv3d(layers[-1], ndf * 4, kernel_size=4, strides=(1, 2, 2)), 0.1)) + + with tf.variable_scope("sn_conv2_0"): + layers.append(lrelu(conv3d(layers[-1], ndf * 4, kernel_size=3, strides=1), 0.1)) + + with tf.variable_scope("sn_conv2_1"): + layers.append(lrelu(conv3d(layers[-1], ndf * 8, kernel_size=4, strides=2), 0.1)) + + with tf.variable_scope("sn_conv3_0"): + layers.append(lrelu(conv3d(layers[-1], ndf * 8, kernel_size=3, strides=1), 0.1)) + + with tf.variable_scope("sn_fc4"): + logits = dense(tf.reshape(layers[-1], [batch_size, -1]), 1, use_spectral_norm=True) + layers.append(logits) + layers = nest.map_structure(tf_utils.transpose_batch_time, layers) + return layers diff --git a/video_prediction/models/non_trainable_model.py b/video_prediction/models/non_trainable_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a5082b5c0357d37262b3236185be89e146d92a84 --- /dev/null +++ b/video_prediction/models/non_trainable_model.py @@ -0,0 +1,54 @@ +from collections import OrderedDict +from tensorflow.python.util import nest +from video_prediction.utils.tf_utils import transpose_batch_time + +import tensorflow as tf + +from .base_model import BaseVideoPredictionModel + + +class NonTrainableVideoPredictionModel(BaseVideoPredictionModel): + pass + + +class GroundTruthVideoPredictionModel(NonTrainableVideoPredictionModel): + def build_graph(self, inputs): + super(GroundTruthVideoPredictionModel, self).build_graph(inputs) + + self.outputs = OrderedDict() + self.outputs['gen_images'] = self.inputs['images'][:, 1:] + if 'pix_distribs' in self.inputs: + self.outputs['gen_pix_distribs'] = self.inputs['pix_distribs'][:, 1:] + + inputs, outputs = nest.map_structure(transpose_batch_time, (self.inputs, self.outputs)) + with tf.name_scope("metrics"): + metrics = self.metrics_fn(inputs, outputs) + with tf.name_scope("eval_outputs_and_metrics"): + eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs) + self.metrics, self.eval_outputs, self.eval_metrics = nest.map_structure( + transpose_batch_time, (metrics, eval_outputs, eval_metrics)) + + +class RepeatVideoPredictionModel(NonTrainableVideoPredictionModel): + def build_graph(self, inputs): + super(RepeatVideoPredictionModel, self).build_graph(inputs) + + self.outputs = OrderedDict() + tile_pattern = [1, self.hparams.sequence_length - self.hparams.context_frames, 1, 1, 1] + last_context_images = self.inputs['images'][:, self.hparams.context_frames - 1] + self.outputs['gen_images'] = tf.concat([ + self.inputs['images'][:, 1:self.hparams.context_frames - 1], + tf.tile(last_context_images[:, None], tile_pattern)], axis=-1) + if 'pix_distribs' in self.inputs: + last_context_pix_distrib = self.inputs['pix_distribs'][:, self.hparams.context_frames - 1] + self.outputs['gen_pix_distribs'] = tf.concat([ + self.inputs['pix_distribs'][:, 
1:self.hparams.context_frames - 1], + tf.tile(last_context_pix_distrib[:, None], tile_pattern)], axis=-1) + + inputs, outputs = nest.map_structure(transpose_batch_time, (self.inputs, self.outputs)) + with tf.name_scope("metrics"): + metrics = self.metrics_fn(inputs, outputs) + with tf.name_scope("eval_outputs_and_metrics"): + eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs) + self.metrics, self.eval_outputs, self.eval_metrics = nest.map_structure( + transpose_batch_time, (metrics, eval_outputs, eval_metrics)) diff --git a/video_prediction/models/savp_model.py b/video_prediction/models/savp_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ca8acd3f32a5ea1772c9fbf36003149acfdcb950 --- /dev/null +++ b/video_prediction/models/savp_model.py @@ -0,0 +1,993 @@ +import collections +import functools +import itertools +from collections import OrderedDict + +import numpy as np +import tensorflow as tf +from tensorflow.python.util import nest + +from video_prediction import ops, flow_ops +from video_prediction.models import VideoPredictionModel +from video_prediction.models import networks +from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat +from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell +from video_prediction.utils import tf_utils + +# Amount to use when lower bounding tensors +RELU_SHIFT = 1e-12 + + +def posterior_fn(inputs, hparams): + images = inputs['images'] + image_pairs = tf.concat([images[:-1], images[1:]], axis=-1) + if 'actions' in inputs: + image_pairs = tile_concat( + [image_pairs, inputs['actions'][..., None, None, :]], axis=-1) + + h = tf_utils.with_flat_batch(networks.encoder)( + image_pairs, nef=hparams.nef, n_layers=hparams.n_layers, norm_layer=hparams.norm_layer) + + if hparams.use_e_rnn: + with tf.variable_scope('layer_%d' % (hparams.n_layers + 1)): + h = tf_utils.with_flat_batch(dense, 2)(h, hparams.nef * 4) + + if hparams.rnn == 'lstm': + RNNCell = tf.contrib.rnn.BasicLSTMCell + elif hparams.rnn == 'gru': + RNNCell = tf.contrib.rnn.GRUCell + else: + raise NotImplementedError + with tf.variable_scope('%s' % hparams.rnn): + rnn_cell = RNNCell(hparams.nef * 4) + h, _ = tf_utils.unroll_rnn(rnn_cell, h) + + with tf.variable_scope('z_mu'): + z_mu = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) + with tf.variable_scope('z_log_sigma_sq'): + z_log_sigma_sq = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) + z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10) + outputs = {'zs_mu': z_mu, 'zs_log_sigma_sq': z_log_sigma_sq} + return outputs + + +def prior_fn(inputs, hparams): + images = inputs['images'] + image_pairs = tf.concat([images[:hparams.context_frames - 1], images[1:hparams.context_frames]], axis=-1) + if 'actions' in inputs: + image_pairs = tile_concat( + [image_pairs, inputs['actions'][..., None, None, :]], axis=-1) + + h = tf_utils.with_flat_batch(networks.encoder)( + image_pairs, nef=hparams.nef, n_layers=hparams.n_layers, norm_layer=hparams.norm_layer) + h_zeros = tf.zeros(tf.concat([[hparams.sequence_length - hparams.context_frames], tf.shape(h)[1:]], axis=0)) + h = tf.concat([h, h_zeros], axis=0) + + with tf.variable_scope('layer_%d' % (hparams.n_layers + 1)): + h = tf_utils.with_flat_batch(dense, 2)(h, hparams.nef * 4) + + if hparams.rnn == 'lstm': + RNNCell = tf.contrib.rnn.BasicLSTMCell + elif hparams.rnn == 'gru': + RNNCell = tf.contrib.rnn.GRUCell + else: + raise NotImplementedError + with tf.variable_scope('%s' % hparams.rnn): + rnn_cell = 
RNNCell(hparams.nef * 4) + h, _ = tf_utils.unroll_rnn(rnn_cell, h) + + with tf.variable_scope('z_mu'): + z_mu = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) + with tf.variable_scope('z_log_sigma_sq'): + z_log_sigma_sq = tf_utils.with_flat_batch(dense, 2)(h, hparams.nz) + z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10) + outputs = {'zs_mu': z_mu, 'zs_log_sigma_sq': z_log_sigma_sq} + return outputs + + +def discriminator_given_video_fn(targets, hparams): + sequence_length, batch_size = targets.shape.as_list()[:2] + clip_length = hparams.clip_length + + # sample an image and apply the image distriminator on that frame + t_sample = tf.random_uniform([batch_size], minval=0, maxval=sequence_length, dtype=tf.int32) + image_sample = tf.gather_nd(targets, tf.stack([t_sample, tf.range(batch_size)], axis=1)) + + # sample a subsequence of length clip_length and apply the images/video discriminators on those frames + t_start = tf.random_uniform([batch_size], minval=0, maxval=sequence_length - clip_length + 1, dtype=tf.int32) + t_start_indices = tf.stack([t_start, tf.range(batch_size)], axis=1) + t_offset_indices = tf.stack([tf.range(clip_length), tf.zeros(clip_length, dtype=tf.int32)], axis=1) + indices = t_start_indices[None] + t_offset_indices[:, None] + clip_sample = tf.gather_nd(targets, flatten(indices, 0, 1)) + clip_sample = tf.reshape(clip_sample, [clip_length] + targets.shape.as_list()[1:]) + + outputs = {} + if hparams.image_sn_gan_weight or hparams.image_sn_vae_gan_weight: + with tf.variable_scope('image'): + image_features = networks.image_sn_discriminator(image_sample, ndf=hparams.ndf) + image_features, image_logits = image_features[:-1], image_features[-1] + outputs['discrim_image_sn_logits'] = image_logits + for i, image_feature in enumerate(image_features): + outputs['discrim_image_sn_feature%d' % i] = image_feature + if hparams.video_sn_gan_weight or hparams.video_sn_vae_gan_weight: + with tf.variable_scope('video'): + video_features = networks.video_sn_discriminator(clip_sample, ndf=hparams.ndf) + video_features, video_logits = video_features[:-1], video_features[-1] + outputs['discrim_video_sn_logits'] = video_logits + for i, video_feature in enumerate(video_features): + outputs['discrim_video_sn_feature%d' % i] = video_feature + if hparams.images_sn_gan_weight or hparams.images_sn_vae_gan_weight: + with tf.variable_scope('images'): + images_features = tf_utils.with_flat_batch(networks.image_sn_discriminator)(clip_sample, ndf=hparams.ndf) + images_features, images_logits = images_features[:-1], images_features[-1] + outputs['discrim_images_sn_logits'] = images_logits + for i, images_feature in enumerate(images_features): + outputs['discrim_images_sn_feature%d' % i] = images_feature + return outputs + + +def discriminator_fn(inputs, outputs, mode, hparams): + # do the encoder version first so that it isn't affected by the reuse_variables() call + if hparams.nz == 0: + discrim_outputs_enc_real = collections.OrderedDict() + discrim_outputs_enc_fake = collections.OrderedDict() + else: + images_enc_real = inputs['images'][1:] + images_enc_fake = outputs['gen_images_enc'] + if hparams.use_same_discriminator: + with tf.name_scope("real"): + discrim_outputs_enc_real = discriminator_given_video_fn(images_enc_real, hparams) + tf.get_variable_scope().reuse_variables() + with tf.name_scope("fake"): + discrim_outputs_enc_fake = discriminator_given_video_fn(images_enc_fake, hparams) + else: + with tf.variable_scope('encoder'), tf.name_scope("real"): + discrim_outputs_enc_real = 
discriminator_given_video_fn(images_enc_real, hparams) + with tf.variable_scope('encoder', reuse=True), tf.name_scope("fake"): + discrim_outputs_enc_fake = discriminator_given_video_fn(images_enc_fake, hparams) + + images_real = inputs['images'][1:] + images_fake = outputs['gen_images'] + with tf.name_scope("real"): + discrim_outputs_real = discriminator_given_video_fn(images_real, hparams) + tf.get_variable_scope().reuse_variables() + with tf.name_scope("fake"): + discrim_outputs_fake = discriminator_given_video_fn(images_fake, hparams) + + discrim_outputs_real = OrderedDict([(k + '_real', v) for k, v in discrim_outputs_real.items()]) + discrim_outputs_fake = OrderedDict([(k + '_fake', v) for k, v in discrim_outputs_fake.items()]) + discrim_outputs_enc_real = OrderedDict([(k + '_enc_real', v) for k, v in discrim_outputs_enc_real.items()]) + discrim_outputs_enc_fake = OrderedDict([(k + '_enc_fake', v) for k, v in discrim_outputs_enc_fake.items()]) + outputs = [discrim_outputs_real, discrim_outputs_fake, + discrim_outputs_enc_real, discrim_outputs_enc_fake] + total_num_outputs = sum([len(output) for output in outputs]) + outputs = collections.OrderedDict(itertools.chain(*[output.items() for output in outputs])) + assert len(outputs) == total_num_outputs # ensure no output is lost because of repeated keys + return outputs + + +class SAVPCell(tf.nn.rnn_cell.RNNCell): + def __init__(self, inputs, mode, hparams, reuse=None): + super(SAVPCell, self).__init__(_reuse=reuse) + self.inputs = inputs + self.mode = mode + self.hparams = hparams + + if self.hparams.where_add not in ('input', 'all', 'middle'): + raise ValueError('Invalid where_add %s' % self.hparams.where_add) + + batch_size = inputs['images'].shape[1].value + image_shape = inputs['images'].shape.as_list()[2:] + height, width, _ = image_shape + scale_size = min(height, width) + if scale_size >= 256: + self.encoder_layer_specs = [ + (self.hparams.ngf, False), + (self.hparams.ngf * 2, False), + (self.hparams.ngf * 4, True), + (self.hparams.ngf * 8, True), + (self.hparams.ngf * 8, True), + ] + self.decoder_layer_specs = [ + (self.hparams.ngf * 8, True), + (self.hparams.ngf * 4, True), + (self.hparams.ngf * 2, False), + (self.hparams.ngf, False), + (self.hparams.ngf, False), + ] + elif scale_size >= 128: + self.encoder_layer_specs = [ + (self.hparams.ngf, False), + (self.hparams.ngf * 2, True), + (self.hparams.ngf * 4, True), + (self.hparams.ngf * 8, True), + ] + self.decoder_layer_specs = [ + (self.hparams.ngf * 8, True), + (self.hparams.ngf * 4, True), + (self.hparams.ngf * 2, False), + (self.hparams.ngf, False), + ] + elif scale_size >= 64: + self.encoder_layer_specs = [ + (self.hparams.ngf, True), + (self.hparams.ngf * 2, True), + (self.hparams.ngf * 4, True), + ] + self.decoder_layer_specs = [ + (self.hparams.ngf * 2, True), + (self.hparams.ngf, True), + (self.hparams.ngf, False), + ] + elif scale_size >= 32: + self.encoder_layer_specs = [ + (self.hparams.ngf, True), + (self.hparams.ngf * 2, True), + ] + self.decoder_layer_specs = [ + (self.hparams.ngf, True), + (self.hparams.ngf, False), + ] + else: + raise NotImplementedError + assert len(self.encoder_layer_specs) == len(self.decoder_layer_specs) + total_stride = 2 ** len(self.encoder_layer_specs) + if (height % total_stride) or (width % total_stride): + raise ValueError("The image has dimension (%d, %d), but it should be divisible " + "by the total stride, which is %d." 
% (height, width, total_stride)) + + # output_size + num_masks = self.hparams.last_frames * self.hparams.num_transformed_images + \ + int(bool(self.hparams.prev_image_background)) + \ + int(bool(self.hparams.first_image_background and not self.hparams.context_images_background)) + \ + int(bool(self.hparams.last_image_background and not self.hparams.context_images_background)) + \ + int(bool(self.hparams.last_context_image_background and not self.hparams.context_images_background)) + \ + (self.hparams.context_frames if self.hparams.context_images_background else 0) + \ + int(bool(self.hparams.generate_scratch_image)) + output_size = { + 'gen_images': tf.TensorShape(image_shape), + 'transformed_images': tf.TensorShape(image_shape + [num_masks]), + 'masks': tf.TensorShape([height, width, 1, num_masks]), + } + if 'pix_distribs' in inputs: + num_motions = inputs['pix_distribs'].shape[-1].value + output_size['gen_pix_distribs'] = tf.TensorShape([height, width, num_motions]) + output_size['transformed_pix_distribs'] = tf.TensorShape([height, width, num_motions, num_masks]) + if 'states' in inputs: + output_size['gen_states'] = inputs['states'].shape[2:] + if self.hparams.transformation == 'flow': + output_size['gen_flows'] = tf.TensorShape([height, width, 2, self.hparams.last_frames * self.hparams.num_transformed_images]) + output_size['gen_flows_rgb'] = tf.TensorShape([height, width, 3, self.hparams.last_frames * self.hparams.num_transformed_images]) + self._output_size = output_size + + # state_size + conv_rnn_state_sizes = [] + conv_rnn_height, conv_rnn_width = height, width + for out_channels, use_conv_rnn in self.encoder_layer_specs: + conv_rnn_height //= 2 + conv_rnn_width //= 2 + if use_conv_rnn and not self.hparams.ablation_rnn: + conv_rnn_state_sizes.append(tf.TensorShape([conv_rnn_height, conv_rnn_width, out_channels])) + for out_channels, use_conv_rnn in self.decoder_layer_specs: + conv_rnn_height *= 2 + conv_rnn_width *= 2 + if use_conv_rnn and not self.hparams.ablation_rnn: + conv_rnn_state_sizes.append(tf.TensorShape([conv_rnn_height, conv_rnn_width, out_channels])) + if self.hparams.conv_rnn == 'lstm': + conv_rnn_state_sizes = [tf.nn.rnn_cell.LSTMStateTuple(conv_rnn_state_size, conv_rnn_state_size) + for conv_rnn_state_size in conv_rnn_state_sizes] + state_size = {'time': tf.TensorShape([]), + 'gen_image': tf.TensorShape(image_shape), + 'last_images': [tf.TensorShape(image_shape)] * self.hparams.last_frames, + 'conv_rnn_states': conv_rnn_state_sizes} + if 'zs' in inputs and self.hparams.use_rnn_z and not self.hparams.ablation_rnn: + rnn_z_state_size = tf.TensorShape([self.hparams.nz]) + if self.hparams.rnn == 'lstm': + rnn_z_state_size = tf.nn.rnn_cell.LSTMStateTuple(rnn_z_state_size, rnn_z_state_size) + state_size['rnn_z_state'] = rnn_z_state_size + if 'pix_distribs' in inputs: + state_size['gen_pix_distrib'] = tf.TensorShape([height, width, num_motions]) + state_size['last_pix_distribs'] = [tf.TensorShape([height, width, num_motions])] * self.hparams.last_frames + if 'states' in inputs: + state_size['gen_state'] = inputs['states'].shape[2:] + self._state_size = state_size + + if self.hparams.learn_initial_state: + learnable_initial_state_size = {k: v for k, v in state_size.items() + if k in ('conv_rnn_states', 'rnn_z_state')} + else: + learnable_initial_state_size = {} + learnable_initial_state_flat = [] + for i, size in enumerate(nest.flatten(learnable_initial_state_size)): + with tf.variable_scope('initial_state_%d' % i): + state = tf.get_variable('initial_state', size, + 
dtype=tf.float32, initializer=tf.zeros_initializer()) + learnable_initial_state_flat.append(state) + self._learnable_initial_state = nest.pack_sequence_as( + learnable_initial_state_size, learnable_initial_state_flat) + + ground_truth_sampling_shape = [self.hparams.sequence_length - 1 - self.hparams.context_frames, batch_size] + if self.hparams.schedule_sampling == 'none' or self.mode != 'train': + ground_truth_sampling = tf.constant(False, dtype=tf.bool, shape=ground_truth_sampling_shape) + elif self.hparams.schedule_sampling in ('inverse_sigmoid', 'linear'): + if self.hparams.schedule_sampling == 'inverse_sigmoid': + k = self.hparams.schedule_sampling_k + start_step = self.hparams.schedule_sampling_steps[0] + iter_num = tf.to_float(tf.train.get_or_create_global_step()) + prob = (k / (k + tf.exp((iter_num - start_step) / k))) + prob = tf.cond(tf.less(iter_num, start_step), lambda: 1.0, lambda: prob) + elif self.hparams.schedule_sampling == 'linear': + start_step, end_step = self.hparams.schedule_sampling_steps + step = tf.clip_by_value(tf.train.get_or_create_global_step(), start_step, end_step) + prob = 1.0 - tf.to_float(step - start_step) / tf.to_float(end_step - start_step) + log_probs = tf.log([1 - prob, prob]) + ground_truth_sampling = tf.multinomial([log_probs] * batch_size, ground_truth_sampling_shape[0]) + ground_truth_sampling = tf.cast(tf.transpose(ground_truth_sampling, [1, 0]), dtype=tf.bool) + # Ensure that eventually, the model is deterministically + # autoregressive (as opposed to autoregressive with very high probability). + ground_truth_sampling = tf.cond(tf.less(prob, 0.001), + lambda: tf.constant(False, dtype=tf.bool, shape=ground_truth_sampling_shape), + lambda: ground_truth_sampling) + else: + raise NotImplementedError + ground_truth_context = tf.constant(True, dtype=tf.bool, shape=[self.hparams.context_frames, batch_size]) + self.ground_truth = tf.concat([ground_truth_context, ground_truth_sampling], axis=0) + + @property + def output_size(self): + return self._output_size + + @property + def state_size(self): + return self._state_size + + def zero_state(self, batch_size, dtype): + init_state = super(SAVPCell, self).zero_state(batch_size, dtype) + learnable_init_state = nest.map_structure( + lambda x: tf.tile(x[None], [batch_size] + [1] * x.shape.ndims), self._learnable_initial_state) + init_state.update(learnable_init_state) + init_state['last_images'] = [self.inputs['images'][0]] * self.hparams.last_frames + if 'pix_distribs' in self.inputs: + init_state['last_pix_distribs'] = [self.inputs['pix_distribs'][0]] * self.hparams.last_frames + return init_state + + def _rnn_func(self, inputs, state, num_units): + if self.hparams.rnn == 'lstm': + RNNCell = functools.partial(tf.nn.rnn_cell.LSTMCell, name='basic_lstm_cell') + elif self.hparams.rnn == 'gru': + RNNCell = tf.contrib.rnn.GRUCell + else: + raise NotImplementedError + rnn_cell = RNNCell(num_units, reuse=tf.get_variable_scope().reuse) + return rnn_cell(inputs, state) + + def _conv_rnn_func(self, inputs, state, filters): + if isinstance(inputs, (list, tuple)): + inputs_shape = inputs[0].shape.as_list() + else: + inputs_shape = inputs.shape.as_list() + input_shape = inputs_shape[1:] + if self.hparams.conv_rnn_norm_layer == 'none': + normalizer_fn = None + else: + normalizer_fn = ops.get_norm_layer(self.hparams.conv_rnn_norm_layer) + if self.hparams.conv_rnn == 'lstm': + Conv2DRNNCell = BasicConv2DLSTMCell + elif self.hparams.conv_rnn == 'gru': + Conv2DRNNCell = Conv2DGRUCell + else: + raise NotImplementedError + if 
self.hparams.ablation_conv_rnn_norm: + conv_rnn_cell = Conv2DRNNCell(input_shape, filters, kernel_size=(5, 5), + reuse=tf.get_variable_scope().reuse) + h, state = conv_rnn_cell(inputs, state) + outputs = (normalizer_fn(h), state) + else: + conv_rnn_cell = Conv2DRNNCell(input_shape, filters, kernel_size=(5, 5), + normalizer_fn=normalizer_fn, + separate_norms=self.hparams.conv_rnn_norm_layer == 'layer', + reuse=tf.get_variable_scope().reuse) + outputs = conv_rnn_cell(inputs, state) + return outputs + + def call(self, inputs, states): + norm_layer = ops.get_norm_layer(self.hparams.norm_layer) + downsample_layer = ops.get_downsample_layer(self.hparams.downsample_layer) + upsample_layer = ops.get_upsample_layer(self.hparams.upsample_layer) + activation_layer = ops.get_activation_layer(self.hparams.activation_layer) + image_shape = inputs['images'].get_shape().as_list() + batch_size, height, width, color_channels = image_shape + conv_rnn_states = states['conv_rnn_states'] + + time = states['time'] + with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]): + t = tf.to_int32(tf.identity(time[0])) + + image = tf.where(self.ground_truth[t], inputs['images'], states['gen_image']) # schedule sampling (if any) + last_images = states['last_images'][1:] + [image] + if 'pix_distribs' in inputs: + pix_distrib = tf.where(self.ground_truth[t], inputs['pix_distribs'], states['gen_pix_distrib']) + last_pix_distribs = states['last_pix_distribs'][1:] + [pix_distrib] + if 'states' in inputs: + state = tf.where(self.ground_truth[t], inputs['states'], states['gen_state']) + + state_action = [] + state_action_z = [] + if 'actions' in inputs: + state_action.append(inputs['actions']) + state_action_z.append(inputs['actions']) + if 'states' in inputs: + state_action.append(state) + # don't backpropagate the convnet through the state dynamics + state_action_z.append(tf.stop_gradient(state)) + + if 'zs' in inputs: + if self.hparams.use_rnn_z: + with tf.variable_scope('%s_z' % ('fc' if self.hparams.ablation_rnn else self.hparams.rnn)): + if self.hparams.ablation_rnn: + rnn_z = dense(inputs['zs'], self.hparams.nz) + rnn_z = tf.nn.tanh(rnn_z) + else: + rnn_z, rnn_z_state = self._rnn_func(inputs['zs'], states['rnn_z_state'], self.hparams.nz) + state_action_z.append(rnn_z) + else: + state_action_z.append(inputs['zs']) + + def concat(tensors, axis): + if len(tensors) == 0: + return tf.zeros([batch_size, 0]) + elif len(tensors) == 1: + return tensors[0] + else: + return tf.concat(tensors, axis=axis) + state_action = concat(state_action, axis=-1) + state_action_z = concat(state_action_z, axis=-1) + + layers = [] + new_conv_rnn_states = [] + for i, (out_channels, use_conv_rnn) in enumerate(self.encoder_layer_specs): + with tf.variable_scope('h%d' % i): + if i == 0: + h = tf.concat([image, self.inputs['images'][0]], axis=-1) + kernel_size = (5, 5) + else: + h = layers[-1][-1] + kernel_size = (3, 3) + if self.hparams.where_add == 'all' or (self.hparams.where_add == 'input' and i == 0): + if self.hparams.use_tile_concat: + h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1) + else: + h = [h, state_action_z] + h = _maybe_tile_concat_layer(downsample_layer)( + h, out_channels, kernel_size=kernel_size, strides=(2, 2)) + h = norm_layer(h) + h = activation_layer(h) + if use_conv_rnn: + with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, i)): + if self.hparams.where_add == 'all': + if self.hparams.use_tile_concat: + conv_rnn_h = tile_concat([h, state_action_z[:, None, 
None, :]], axis=-1) + else: + conv_rnn_h = [h, state_action_z] + else: + conv_rnn_h = h + if self.hparams.ablation_rnn: + conv_rnn_h = _maybe_tile_concat_layer(conv2d)( + conv_rnn_h, out_channels, kernel_size=(5, 5)) + conv_rnn_h = norm_layer(conv_rnn_h) + conv_rnn_h = activation_layer(conv_rnn_h) + else: + conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)] + conv_rnn_h, conv_rnn_state = self._conv_rnn_func(conv_rnn_h, conv_rnn_state, out_channels) + new_conv_rnn_states.append(conv_rnn_state) + layers.append((h, conv_rnn_h) if use_conv_rnn else (h,)) + + num_encoder_layers = len(layers) + for i, (out_channels, use_conv_rnn) in enumerate(self.decoder_layer_specs): + with tf.variable_scope('h%d' % len(layers)): + if i == 0: + h = layers[-1][-1] + else: + h = tf.concat([layers[-1][-1], layers[num_encoder_layers - i - 1][-1]], axis=-1) + if self.hparams.where_add == 'all' or (self.hparams.where_add == 'middle' and i == 0): + if self.hparams.use_tile_concat: + h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1) + else: + h = [h, state_action_z] + h = _maybe_tile_concat_layer(upsample_layer)( + h, out_channels, kernel_size=(3, 3), strides=(2, 2)) + h = norm_layer(h) + h = activation_layer(h) + if use_conv_rnn: + with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, len(layers))): + if self.hparams.where_add == 'all': + if self.hparams.use_tile_concat: + conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1) + else: + conv_rnn_h = [h, state_action_z] + else: + conv_rnn_h = h + if self.hparams.ablation_rnn: + conv_rnn_h = _maybe_tile_concat_layer(conv2d)(conv_rnn_h, out_channels, kernel_size=(5, 5)) + conv_rnn_h = norm_layer(conv_rnn_h) + conv_rnn_h = activation_layer(conv_rnn_h) + else: + conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)] + conv_rnn_h, conv_rnn_state = self._conv_rnn_func(conv_rnn_h, conv_rnn_state, out_channels) + new_conv_rnn_states.append(conv_rnn_state) + layers.append((h, conv_rnn_h) if use_conv_rnn else (h,)) + assert len(new_conv_rnn_states) == len(conv_rnn_states) + + if self.hparams.last_frames and self.hparams.num_transformed_images: + if self.hparams.transformation == 'flow': + with tf.variable_scope('h%d_flow' % len(layers)): + h_flow = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) + h_flow = norm_layer(h_flow) + h_flow = activation_layer(h_flow) + + with tf.variable_scope('flows'): + flows = conv2d(h_flow, 2 * self.hparams.last_frames * self.hparams.num_transformed_images, kernel_size=(3, 3), strides=(1, 1)) + flows = tf.reshape(flows, [batch_size, height, width, 2, self.hparams.last_frames * self.hparams.num_transformed_images]) + else: + assert len(self.hparams.kernel_size) == 2 + kernel_shape = list(self.hparams.kernel_size) + [self.hparams.last_frames * self.hparams.num_transformed_images] + if self.hparams.transformation == 'dna': + with tf.variable_scope('h%d_dna_kernel' % len(layers)): + h_dna_kernel = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) + h_dna_kernel = norm_layer(h_dna_kernel) + h_dna_kernel = activation_layer(h_dna_kernel) + + # Using largest hidden state for predicting untied conv kernels. 
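+                    # DNA ("dynamic neural advection") predicts an untied kernel per pixel: the conv head
+                    # below emits prod(kernel_shape) channels that are reshaped to
+                    # [batch, height, width, kernel_h, kernel_w, last_frames * num_transformed_images],
+                    # and an identity kernel is added so that, before training, each transformed image
+                    # defaults to a copy of the corresponding previous frame.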
+ with tf.variable_scope('dna_kernels'): + kernels = conv2d(h_dna_kernel, np.prod(kernel_shape), kernel_size=(3, 3), strides=(1, 1)) + kernels = tf.reshape(kernels, [batch_size, height, width] + kernel_shape) + kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, None, None, :, :, None] + kernel_spatial_axes = [3, 4] + elif self.hparams.transformation == 'cdna': + with tf.variable_scope('cdna_kernels'): + smallest_layer = layers[num_encoder_layers - 1][-1] + kernels = dense(flatten(smallest_layer), np.prod(kernel_shape)) + kernels = tf.reshape(kernels, [batch_size] + kernel_shape) + kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, :, :, None] + kernel_spatial_axes = [1, 2] + else: + raise ValueError('Invalid transformation %s' % self.hparams.transformation) + + if self.hparams.transformation != 'flow': + with tf.name_scope('kernel_normalization'): + kernels = tf.nn.relu(kernels - RELU_SHIFT) + RELU_SHIFT + kernels /= tf.reduce_sum(kernels, axis=kernel_spatial_axes, keepdims=True) + + if self.hparams.generate_scratch_image: + with tf.variable_scope('h%d_scratch' % len(layers)): + h_scratch = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) + h_scratch = norm_layer(h_scratch) + h_scratch = activation_layer(h_scratch) + + # Using largest hidden state for predicting a new image layer. + # This allows the network to also generate one image from scratch, + # which is useful when regions of the image become unoccluded. + with tf.variable_scope('scratch_image'): + scratch_image = conv2d(h_scratch, color_channels, kernel_size=(3, 3), strides=(1, 1)) + scratch_image = tf.nn.sigmoid(scratch_image) + + with tf.name_scope('transformed_images'): + transformed_images = [] + if self.hparams.last_frames and self.hparams.num_transformed_images: + if self.hparams.transformation == 'flow': + transformed_images.extend(apply_flows(last_images, flows)) + else: + transformed_images.extend(apply_kernels(last_images, kernels, self.hparams.dilation_rate)) + if self.hparams.prev_image_background: + transformed_images.append(image) + if self.hparams.first_image_background and not self.hparams.context_images_background: + transformed_images.append(self.inputs['images'][0]) + if self.hparams.last_image_background and not self.hparams.context_images_background: + transformed_images.append(self.inputs['images'][self.hparams.context_frames - 1]) + if self.hparams.last_context_image_background and not self.hparams.context_images_background: + last_context_image = tf.cond( + tf.less(t, self.hparams.context_frames), + lambda: self.inputs['images'][t], + lambda: self.inputs['images'][self.hparams.context_frames - 1]) + transformed_images.append(last_context_image) + if self.hparams.context_images_background: + transformed_images.extend(tf.unstack(self.inputs['images'][:self.hparams.context_frames])) + if self.hparams.generate_scratch_image: + transformed_images.append(scratch_image) + + if 'pix_distribs' in inputs: + with tf.name_scope('transformed_pix_distribs'): + transformed_pix_distribs = [] + if self.hparams.last_frames and self.hparams.num_transformed_images: + if self.hparams.transformation == 'flow': + transformed_pix_distribs.extend(apply_flows(last_pix_distribs, flows)) + else: + transformed_pix_distribs.extend(apply_kernels(last_pix_distribs, kernels, self.hparams.dilation_rate)) + if self.hparams.prev_image_background: + transformed_pix_distribs.append(pix_distrib) + if self.hparams.first_image_background and not self.hparams.context_images_background: + 
transformed_pix_distribs.append(self.inputs['pix_distribs'][0]) + if self.hparams.last_image_background and not self.hparams.context_images_background: + transformed_pix_distribs.append(self.inputs['pix_distribs'][self.hparams.context_frames - 1]) + if self.hparams.last_context_image_background and not self.hparams.context_images_background: + last_context_pix_distrib = tf.cond( + tf.less(t, self.hparams.context_frames), + lambda: self.inputs['pix_distribs'][t], + lambda: self.inputs['pix_distribs'][self.hparams.context_frames - 1]) + transformed_pix_distribs.append(last_context_pix_distrib) + if self.hparams.context_images_background: + transformed_pix_distribs.extend(tf.unstack(self.inputs['pix_distribs'][:self.hparams.context_frames])) + if self.hparams.generate_scratch_image: + transformed_pix_distribs.append(pix_distrib) + + with tf.name_scope('masks'): + if len(transformed_images) > 1: + with tf.variable_scope('h%d_masks' % len(layers)): + h_masks = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1)) + h_masks = norm_layer(h_masks) + h_masks = activation_layer(h_masks) + + with tf.variable_scope('masks'): + if self.hparams.dependent_mask: + h_masks = tf.concat([h_masks] + transformed_images, axis=-1) + masks = conv2d(h_masks, len(transformed_images), kernel_size=(3, 3), strides=(1, 1)) + masks = tf.nn.softmax(masks) + masks = tf.split(masks, len(transformed_images), axis=-1) + elif len(transformed_images) == 1: + masks = [tf.ones([batch_size, height, width, 1])] + else: + raise ValueError("Either one of the following should be true: " + "last_frames and num_transformed_images, first_image_background, " + "prev_image_background, generate_scratch_image") + + with tf.name_scope('gen_images'): + assert len(transformed_images) == len(masks) + gen_image = tf.add_n([transformed_image * mask + for transformed_image, mask in zip(transformed_images, masks)]) + + if 'pix_distribs' in inputs: + with tf.name_scope('gen_pix_distribs'): + assert len(transformed_pix_distribs) == len(masks) + gen_pix_distrib = tf.add_n([transformed_pix_distrib * mask + for transformed_pix_distrib, mask in zip(transformed_pix_distribs, masks)]) + gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib, axis=(1, 2), keepdims=True) + + if 'states' in inputs: + with tf.name_scope('gen_states'): + with tf.variable_scope('state_pred'): + gen_state = dense(state_action, inputs['states'].shape[-1].value) + + outputs = {'gen_images': gen_image, + 'transformed_images': tf.stack(transformed_images, axis=-1), + 'masks': tf.stack(masks, axis=-1)} + if 'pix_distribs' in inputs: + outputs['gen_pix_distribs'] = gen_pix_distrib + outputs['transformed_pix_distribs'] = tf.stack(transformed_pix_distribs, axis=-1) + if 'states' in inputs: + outputs['gen_states'] = gen_state + if self.hparams.transformation == 'flow': + outputs['gen_flows'] = flows + flows_transposed = tf.transpose(flows, [0, 1, 2, 4, 3]) + flows_rgb_transposed = tf_utils.flow_to_rgb(flows_transposed) + flows_rgb = tf.transpose(flows_rgb_transposed, [0, 1, 2, 4, 3]) + outputs['gen_flows_rgb'] = flows_rgb + + new_states = {'time': time + 1, + 'gen_image': gen_image, + 'last_images': last_images, + 'conv_rnn_states': new_conv_rnn_states} + if 'zs' in inputs and self.hparams.use_rnn_z and not self.hparams.ablation_rnn: + new_states['rnn_z_state'] = rnn_z_state + if 'pix_distribs' in inputs: + new_states['gen_pix_distrib'] = gen_pix_distrib + new_states['last_pix_distribs'] = last_pix_distribs + if 'states' in inputs: + new_states['gen_state'] = gen_state + 
return outputs, new_states + + +def generator_given_z_fn(inputs, mode, hparams): + # all the inputs needs to have the same length for unrolling the rnn + inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1) + for name, input in inputs.items()} + cell = SAVPCell(inputs, mode, hparams) + outputs, _ = tf_utils.unroll_rnn(cell, inputs) + outputs['ground_truth_sampling_mean'] = tf.reduce_mean(tf.to_float(cell.ground_truth[hparams.context_frames:])) + return outputs + + +def generator_fn(inputs, mode, hparams): + batch_size = tf.shape(inputs['images'])[1] + print("2********mode*****",mode) + print("3*******hparams",hparams) + + + if hparams.nz == 0: + # no zs is given in inputs + outputs = generator_given_z_fn(inputs, mode, hparams) + else: + zs_shape = [hparams.sequence_length - 1, batch_size, hparams.nz] + + # posterior + with tf.variable_scope('encoder'): + outputs_posterior = posterior_fn(inputs, hparams) + eps = tf.random_normal(zs_shape, 0, 1) + zs_posterior = outputs_posterior['zs_mu'] + tf.sqrt(tf.exp(outputs_posterior['zs_log_sigma_sq'])) * eps + inputs_posterior = dict(inputs) + inputs_posterior['zs'] = zs_posterior + + # prior + if hparams.learn_prior: + with tf.variable_scope('prior'): + outputs_prior = prior_fn(inputs, hparams) + eps = tf.random_normal(zs_shape, 0, 1) + zs_prior = outputs_prior['zs_mu'] + tf.sqrt(tf.exp(outputs_prior['zs_log_sigma_sq'])) * eps + else: + outputs_prior = {} + zs_prior = tf.random_normal([hparams.sequence_length - hparams.context_frames] + zs_shape[1:], 0, 1) + zs_prior = tf.concat([zs_posterior[:hparams.context_frames - 1], zs_prior], axis=0) + inputs_prior = dict(inputs) + inputs_prior['zs'] = zs_prior + + # generate + gen_outputs_posterior = generator_given_z_fn(inputs_posterior, mode, hparams) + tf.get_variable_scope().reuse_variables() + gen_outputs = generator_given_z_fn(inputs_prior, mode, hparams) + + # rename tensors to avoid name collisions + output_prior = collections.OrderedDict([(k + '_prior', v) for k, v in outputs_prior.items()]) + outputs_posterior = collections.OrderedDict([(k + '_enc', v) for k, v in outputs_posterior.items()]) + gen_outputs_posterior = collections.OrderedDict([(k + '_enc', v) for k, v in gen_outputs_posterior.items()]) + + outputs = [output_prior, gen_outputs, outputs_posterior, gen_outputs_posterior] + total_num_outputs = sum([len(output) for output in outputs]) + outputs = collections.OrderedDict(itertools.chain(*[output.items() for output in outputs])) + assert len(outputs) == total_num_outputs # ensure no output is lost because of repeated keys + + # generate multiple samples from the prior for visualization purposes + inputs_samples = { + name: tf.tile(input[:, None], [1, hparams.num_samples] + [1] * (input.shape.ndims - 1)) + for name, input in inputs.items()} + zs_samples_shape = [hparams.sequence_length - 1, hparams.num_samples, batch_size, hparams.nz] + if hparams.learn_prior: + eps = tf.random_normal(zs_samples_shape, 0, 1) + zs_prior_samples = (outputs_prior['zs_mu'][:, None] + + tf.sqrt(tf.exp(outputs_prior['zs_log_sigma_sq']))[:, None] * eps) + else: + zs_prior_samples = tf.random_normal( + [hparams.sequence_length - hparams.context_frames] + zs_samples_shape[1:], 0, 1) + zs_prior_samples = tf.concat( + [tf.tile(zs_posterior[:hparams.context_frames - 1][:, None], [1, hparams.num_samples, 1, 1]), + zs_prior_samples], axis=0) + inputs_prior_samples = dict(inputs_samples) + inputs_prior_samples['zs'] = zs_prior_samples + inputs_prior_samples = {name: flatten(input, 1, 2) for name, 
input in inputs_prior_samples.items()} + gen_outputs_samples = generator_given_z_fn(inputs_prior_samples, mode, hparams) + gen_images_samples = gen_outputs_samples['gen_images'] + gen_images_samples = tf.stack(tf.split(gen_images_samples, hparams.num_samples, axis=1), axis=-1) + gen_images_samples_avg = tf.reduce_mean(gen_images_samples, axis=-1) + outputs['gen_images_samples'] = gen_images_samples + outputs['gen_images_samples_avg'] = gen_images_samples_avg + return outputs + + +class SAVPVideoPredictionModel(VideoPredictionModel): + def __init__(self, *args, **kwargs): + super(SAVPVideoPredictionModel, self).__init__( + generator_fn, discriminator_fn, *args, **kwargs) + if self.mode != 'train': + self.discriminator_fn = None + self.deterministic = not self.hparams.nz + + def get_default_hparams_dict(self): + default_hparams = super(SAVPVideoPredictionModel, self).get_default_hparams_dict() + hparams = dict( + l1_weight=1.0, + l2_weight=0.0, + n_layers=3, + ndf=32, + norm_layer='instance', + use_same_discriminator=False, + ngf=32, + downsample_layer='conv_pool2d', + upsample_layer='upsample_conv2d', + activation_layer='relu', # for generator only + transformation='cdna', + kernel_size=(5, 5), + dilation_rate=(1, 1), + where_add='all', + use_tile_concat=True, + learn_initial_state=False, + rnn='lstm', + conv_rnn='lstm', + conv_rnn_norm_layer='instance', + num_transformed_images=4, + last_frames=1, + prev_image_background=True, + first_image_background=True, + last_image_background=False, + last_context_image_background=False, + context_images_background=False, + generate_scratch_image=True, + dependent_mask=True, + schedule_sampling='inverse_sigmoid', + schedule_sampling_k=900.0, + schedule_sampling_steps=(0, 100000), + use_e_rnn=False, + learn_prior=False, + nz=8, + num_samples=8, + nef=64, + use_rnn_z=True, + ablation_conv_rnn_norm=False, + ablation_rnn=False, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + def parse_hparams(self, hparams_dict, hparams): + # backwards compatibility + deprecated_hparams_keys = [ + 'num_gpus', + 'e_net', + 'd_conditional', + 'd_downsample_layer', + 'd_net', + 'd_use_gt_inputs', + 'acvideo_gan_weight', + 'acvideo_vae_gan_weight', + 'image_gan_weight', + 'image_vae_gan_weight', + 'tuple_gan_weight', + 'tuple_vae_gan_weight', + 'gan_weight', + 'vae_gan_weight', + 'video_gan_weight', + 'video_vae_gan_weight', + ] + for deprecated_hparams_key in deprecated_hparams_keys: + hparams_dict.pop(deprecated_hparams_key, None) + return super(SAVPVideoPredictionModel, self).parse_hparams(hparams_dict, hparams) + + def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): + def restore_to_checkpoint_mapping(restore_name, checkpoint_var_names): + restore_name = restore_name.split(':')[0] + if restore_name not in checkpoint_var_names: + restore_name = restore_name.replace('savp_cell', 'dna_cell') + return restore_name + + super(SAVPVideoPredictionModel, self).restore(sess, checkpoints, restore_to_checkpoint_mapping) + + +def apply_dna_kernels(image, kernels, dilation_rate=(1, 1)): + """ + Args: + image: A 4-D tensor of shape + `[batch, in_height, in_width, in_channels]`. + kernels: A 6-D of shape + `[batch, in_height, in_width, kernel_size[0], kernel_size[1], num_transformed_images]`. + Returns: + A list of `num_transformed_images` 4-D tensors, each of shape + `[batch, in_height, in_width, in_channels]`. 
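+        Example (shapes only; the batch size, image size, kernel size and number of
+        transformed images below are illustrative, not fixed by this function):
+            image:   [8, 64, 64, 3]
+            kernels: [8, 64, 64, 5, 5, 4]
+            returns: a list of 4 tensors, each of shape [8, 64, 64, 3]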
+ """ + dilation_rate = list(dilation_rate) if isinstance(dilation_rate, (tuple, list)) else [dilation_rate] * 2 + batch_size, height, width, color_channels = image.get_shape().as_list() + batch_size, height, width, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list() + kernel_size = [kernel_height, kernel_width] + + # Flatten the spatial dimensions. + kernels_reshaped = tf.reshape(kernels, [batch_size, height, width, + kernel_size[0] * kernel_size[1], num_transformed_images]) + image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC') + # Combine channel and batch dimensions into the first dimension. + image_transposed = tf.transpose(image_padded, [3, 0, 1, 2]) + image_reshaped = flatten(image_transposed, 0, 1)[..., None] + patches_reshaped = tf.extract_image_patches(image_reshaped, ksizes=[1] + kernel_size + [1], + strides=[1] * 4, rates=[1] + dilation_rate + [1], padding='VALID') + # Separate channel and batch dimensions, and move channel dimension. + patches_transposed = tf.reshape(patches_reshaped, [color_channels, batch_size, height, width, kernel_size[0] * kernel_size[1]]) + patches = tf.transpose(patches_transposed, [1, 2, 3, 0, 4]) + # Reduce along the spatial dimensions of the kernel. + outputs = tf.matmul(patches, kernels_reshaped) + outputs = tf.unstack(outputs, axis=-1) + return outputs + + +def apply_cdna_kernels(image, kernels, dilation_rate=(1, 1)): + """ + Args: + image: A 4-D tensor of shape + `[batch, in_height, in_width, in_channels]`. + kernels: A 4-D of shape + `[batch, kernel_size[0], kernel_size[1], num_transformed_images]`. + Returns: + A list of `num_transformed_images` 4-D tensors, each of shape + `[batch, in_height, in_width, in_channels]`. + """ + batch_size, height, width, color_channels = image.get_shape().as_list() + batch_size, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list() + kernel_size = [kernel_height, kernel_width] + image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC') + # Treat the color channel dimension as the batch dimension since the same + # transformation is applied to each color channel. + # Treat the batch dimension as the channel dimension so that + # depthwise_conv2d can apply a different transformation to each sample. + kernels = tf.transpose(kernels, [1, 2, 0, 3]) + kernels = tf.reshape(kernels, [kernel_size[0], kernel_size[1], batch_size, num_transformed_images]) + # Swap the batch and channel dimensions. + image_transposed = tf.transpose(image_padded, [3, 1, 2, 0]) + # Transform image. + outputs = tf.nn.depthwise_conv2d(image_transposed, kernels, [1, 1, 1, 1], padding='VALID', rate=dilation_rate) + # Transpose the dimensions to where they belong. + outputs = tf.reshape(outputs, [color_channels, height, width, batch_size, num_transformed_images]) + outputs = tf.transpose(outputs, [4, 3, 1, 2, 0]) + outputs = tf.unstack(outputs, axis=0) + return outputs + + +def apply_kernels(image, kernels, dilation_rate=(1, 1)): + """ + Args: + image: A 4-D tensor of shape + `[batch, in_height, in_width, in_channels]`. + kernels: A 4-D or 6-D tensor of shape + `[batch, kernel_size[0], kernel_size[1], num_transformed_images]` or + `[batch, in_height, in_width, kernel_size[0], kernel_size[1], num_transformed_images]`. + Returns: + A list of `num_transformed_images` 4-D tensors, each of shape + `[batch, in_height, in_width, in_channels]`. 
+ """ + if isinstance(image, list): + image_list = image + kernels_list = tf.split(kernels, len(image_list), axis=-1) + outputs = [] + for image, kernels in zip(image_list, kernels_list): + outputs.extend(apply_kernels(image, kernels)) + else: + if len(kernels.get_shape()) == 4: + outputs = apply_cdna_kernels(image, kernels, dilation_rate=dilation_rate) + elif len(kernels.get_shape()) == 6: + outputs = apply_dna_kernels(image, kernels, dilation_rate=dilation_rate) + else: + raise ValueError + return outputs + + +def apply_flows(image, flows): + if isinstance(image, list): + image_list = image + flows_list = tf.split(flows, len(image_list), axis=-1) + outputs = [] + for image, flows in zip(image_list, flows_list): + outputs.extend(apply_flows(image, flows)) + else: + flows = tf.unstack(flows, axis=-1) + outputs = [flow_ops.image_warp(image, flow) for flow in flows] + return outputs + + +def identity_kernel(kernel_size): + kh, kw = kernel_size + kernel = np.zeros(kernel_size) + + def center_slice(k): + if k % 2 == 0: + return slice(k // 2 - 1, k // 2 + 1) + else: + return slice(k // 2, k // 2 + 1) + + kernel[center_slice(kh), center_slice(kw)] = 1.0 + kernel /= np.sum(kernel) + return kernel + + +def _maybe_tile_concat_layer(conv2d_layer): + def layer(inputs, out_channels, *args, **kwargs): + if isinstance(inputs, (list, tuple)): + inputs_spatial, inputs_non_spatial = inputs + outputs = (conv2d_layer(inputs_spatial, out_channels, *args, **kwargs) + + dense(inputs_non_spatial, out_channels, use_bias=False)[:, None, None, :]) + else: + outputs = conv2d_layer(inputs, out_channels, *args, **kwargs) + return outputs + + return layer diff --git a/video_prediction/models/sna_model.py b/video_prediction/models/sna_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ddb04deafc73f49d0466acd74ac4a43d94ac72f0 --- /dev/null +++ b/video_prediction/models/sna_model.py @@ -0,0 +1,667 @@ +# Copyright 2016 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Model architecture for predictive model, including CDNA, DNA, and STP.""" + +import itertools + +import numpy as np +import tensorflow as tf +import tensorflow.contrib.slim as slim +from tensorflow.contrib.layers.python import layers as tf_layers +from tensorflow.contrib.slim import add_arg_scope +from tensorflow.contrib.slim import layers + +from video_prediction.models import VideoPredictionModel + + +# Amount to use when lower bounding tensors +RELU_SHIFT = 1e-12 + + +@add_arg_scope +def basic_conv_lstm_cell(inputs, + state, + num_channels, + filter_size=5, + forget_bias=1.0, + scope=None, + reuse=None, + ): + """Basic LSTM recurrent network cell, with 2D convolution connctions. + We add forget_bias (default: 1) to the biases of the forget gate in order to + reduce the scale of forgetting in the beginning of the training. 
+ It does not allow cell clipping, a projection layer, and does not + use peep-hole connections: it is the basic baseline. + Args: + inputs: input Tensor, 4D, batch x height x width x channels. + state: state Tensor, 4D, batch x height x width x channels. + num_channels: the number of output channels in the layer. + filter_size: the shape of the each convolution filter. + forget_bias: the initial value of the forget biases. + scope: Optional scope for variable_scope. + reuse: whether or not the layer and the variables should be reused. + Returns: + a tuple of tensors representing output and the new state. + """ + if state is None: + state = tf.zeros(inputs.get_shape().as_list()[:3] + [2 * num_channels], name='init_state') + + with tf.variable_scope(scope, + 'BasicConvLstmCell', + [inputs, state], + reuse=reuse): + + inputs.get_shape().assert_has_rank(4) + state.get_shape().assert_has_rank(4) + c, h = tf.split(axis=3, num_or_size_splits=2, value=state) + inputs_h = tf.concat(values=[inputs, h], axis=3) + # Parameters of gates are concatenated into one conv for efficiency. + i_j_f_o = layers.conv2d(inputs_h, + 4 * num_channels, [filter_size, filter_size], + stride=1, + activation_fn=None, + scope='Gates', + ) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = tf.split(value=i_j_f_o, num_or_size_splits=4, axis=3) + + new_c = c * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.tanh(j) + new_h = tf.tanh(new_c) * tf.sigmoid(o) + + return new_h, tf.concat(values=[new_c, new_h], axis=3) + + +class Prediction_Model(object): + + def __init__(self, + images, + actions=None, + states=None, + iter_num=-1.0, + pix_distributions1=None, + pix_distributions2=None, + conf=None): + + self.pix_distributions1 = pix_distributions1 + self.pix_distributions2 = pix_distributions2 + self.actions = actions + self.iter_num = iter_num + self.conf = conf + self.images = images + + self.cdna, self.stp, self.dna = False, False, False + if self.conf['model'] == 'CDNA': + self.cdna = True + elif self.conf['model'] == 'DNA': + self.dna = True + elif self.conf['model'] == 'STP': + self.stp = True + if self.stp + self.cdna + self.dna != 1: + raise ValueError("More than one option selected!") + + self.k = conf['schedsamp_k'] + self.use_state = conf['use_state'] + self.num_masks = conf['num_masks'] + self.context_frames = conf['context_frames'] + + self.batch_size, self.img_height, self.img_width, self.color_channels = [int(i) for i in + images[0].get_shape()[0:4]] + self.lstm_func = basic_conv_lstm_cell + + # Generated robot states and images. + self.gen_states = [] + self.gen_images = [] + self.gen_masks = [] + + self.moved_images = [] + + self.moved_pix_distrib1 = [] + self.moved_pix_distrib2 = [] + + self.states = states + self.gen_distrib1 = [] + self.gen_distrib2 = [] + + self.trafos = [] + + def build(self): + + if 'kern_size' in self.conf.keys(): + KERN_SIZE = self.conf['kern_size'] + else: + KERN_SIZE = 5 + + batch_size, img_height, img_width, color_channels = self.images[0].get_shape()[0:4] + lstm_func = basic_conv_lstm_cell + + + if self.states != None: + current_state = self.states[0] + else: + current_state = None + + if self.actions == None: + self.actions = [None for _ in self.images] + + if self.k == -1: + feedself = True + else: + # Scheduled sampling: + # Calculate number of ground-truth frames to pass in. 
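+            # Inverse-sigmoid schedule: the fraction of ground-truth frames in the batch is
+            # k / (k + exp(iter_num / k)), which starts near 1 and decays towards 0, so the
+            # model is gradually forced to condition on its own predictions. For example,
+            # with k = 900 the fraction crosses 0.5 around iteration k * ln(k) ~= 6.1e3.
+            # (k == -1 skips the schedule and always feeds back predictions after the
+            # context frames.)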
+ num_ground_truth = tf.to_int32( + tf.round(tf.to_float(batch_size) * (self.k / (self.k + tf.exp(self.iter_num / self.k))))) + feedself = False + + # LSTM state sizes and states. + + if 'lstm_size' in self.conf: + lstm_size = self.conf['lstm_size'] + print('using lstm size', lstm_size) + else: + ngf = self.conf['ngf'] + lstm_size = np.int32(np.array([ngf, ngf * 2, ngf * 4, ngf * 2, ngf])) + + + lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None + lstm_state5, lstm_state6, lstm_state7 = None, None, None + + for t, action in enumerate(self.actions): + print(t) + # Reuse variables after the first timestep. + reuse = bool(self.gen_images) + + done_warm_start = len(self.gen_images) > self.context_frames - 1 + with slim.arg_scope( + [lstm_func, slim.layers.conv2d, slim.layers.fully_connected, + tf_layers.layer_norm, slim.layers.conv2d_transpose], + reuse=reuse): + + if feedself and done_warm_start: + # Feed in generated image. + prev_image = self.gen_images[-1] # 64x64x6 + if self.pix_distributions1 != None: + prev_pix_distrib1 = self.gen_distrib1[-1] + if 'ndesig' in self.conf: + prev_pix_distrib2 = self.gen_distrib2[-1] + elif done_warm_start: + # Scheduled sampling + prev_image = scheduled_sample(self.images[t], self.gen_images[-1], batch_size, + num_ground_truth) + else: + # Always feed in ground_truth + prev_image = self.images[t] + if self.pix_distributions1 != None: + prev_pix_distrib1 = self.pix_distributions1[t] + if 'ndesig' in self.conf: + prev_pix_distrib2 = self.pix_distributions2[t] + if len(prev_pix_distrib1.get_shape()) == 3: + prev_pix_distrib1 = tf.expand_dims(prev_pix_distrib1, -1) + if 'ndesig' in self.conf: + prev_pix_distrib2 = tf.expand_dims(prev_pix_distrib2, -1) + + if 'refeed_firstimage' in self.conf: + assert self.conf['model']=='STP' + if t > 1: + input_image = self.images[1] + print('refeed with image 1') + else: + input_image = prev_image + else: + input_image = prev_image + + # Predicted state is always fed back in + if not 'ignore_state_action' in self.conf: + state_action = tf.concat(axis=1, values=[action, current_state]) + + enc0 = slim.layers.conv2d( #32x32x32 + input_image, + 32, [5, 5], + stride=2, + scope='scale1_conv1', + normalizer_fn=tf_layers.layer_norm, + normalizer_params={'scope': 'layer_norm1'}) + + hidden1, lstm_state1 = lstm_func( # 32x32x16 + enc0, lstm_state1, lstm_size[0], scope='state1') + hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2') + + enc1 = slim.layers.conv2d( # 16x16x16 + hidden1, hidden1.get_shape()[3], [3, 3], stride=2, scope='conv2') + + hidden3, lstm_state3 = lstm_func( #16x16x32 + enc1, lstm_state3, lstm_size[1], scope='state3') + hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4') + + enc2 = slim.layers.conv2d( # 8x8x32 + hidden3, hidden3.get_shape()[3], [3, 3], stride=2, scope='conv3') + + if not 'ignore_state_action' in self.conf: + # Pass in state and action. 
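+                    # The low-dimensional action/state vector is reshaped to a 1x1 spatial map,
+                    # tiled ("smeared") over the 8x8 bottleneck feature map and concatenated
+                    # channel-wise with enc2, so every spatial location sees the conditioning.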
+ if 'ignore_state' in self.conf: + lowdim = action + print('ignoring state') + else: + lowdim = state_action + + smear = tf.reshape( + lowdim, + [int(batch_size), 1, 1, int(lowdim.get_shape()[1])]) + smear = tf.tile( + smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1]) + + enc2 = tf.concat(axis=3, values=[enc2, smear]) + else: + print('ignoring states and actions') + + enc3 = slim.layers.conv2d( #8x8x32 + enc2, hidden3.get_shape()[3], [1, 1], stride=1, scope='conv4') + + hidden5, lstm_state5 = lstm_func( #8x8x64 + enc3, lstm_state5, lstm_size[2], scope='state5') + hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6') + enc4 = slim.layers.conv2d_transpose( #16x16x64 + hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1') + + hidden6, lstm_state6 = lstm_func( #16x16x32 + enc4, lstm_state6, lstm_size[3], scope='state6') + hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7') + + if 'noskip' not in self.conf: + # Skip connection. + hidden6 = tf.concat(axis=3, values=[hidden6, enc1]) # both 16x16 + + enc5 = slim.layers.conv2d_transpose( #32x32x32 + hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2') + hidden7, lstm_state7 = lstm_func( # 32x32x16 + enc5, lstm_state7, lstm_size[4], scope='state7') + hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8') + + if not 'noskip' in self.conf: + # Skip connection. + hidden7 = tf.concat(axis=3, values=[hidden7, enc0]) # both 32x32 + + enc6 = slim.layers.conv2d_transpose( # 64x64x16 + hidden7, + hidden7.get_shape()[3], 3, stride=2, scope='convt3', + normalizer_fn=tf_layers.layer_norm, + normalizer_params={'scope': 'layer_norm9'}) + + if 'transform_from_firstimage' in self.conf: + prev_image = self.images[1] + if self.pix_distributions1 != None: + prev_pix_distrib1 = self.pix_distributions1[1] + prev_pix_distrib1 = tf.expand_dims(prev_pix_distrib1, -1) + print('transform from image 1') + + if self.conf['model'] == 'DNA': + # Using largest hidden state for predicting untied conv kernels. + trafo_input = slim.layers.conv2d_transpose( + enc6, KERN_SIZE ** 2, 1, stride=1, scope='convt4_cam2') + + transformed_l = [self.dna_transformation(prev_image, trafo_input, self.conf['kern_size'])] + if self.pix_distributions1 != None: + transf_distrib_ndesig1 = [self.dna_transformation(prev_pix_distrib1, trafo_input, KERN_SIZE)] + if 'ndesig' in self.conf: + transf_distrib_ndesig2 = [ + self.dna_transformation(prev_pix_distrib2, trafo_input, KERN_SIZE)] + + + extra_masks = 1 ## extra_masks = 2 is needed for running singleview_shifted!! + # print('using extra masks 2 because of single view shifted!!') + # extra_masks = 2 + + if self.conf['model'] == 'CDNA': + if 'gen_pix' in self.conf: + # Using largest hidden state for predicting a new image layer. + enc7 = slim.layers.conv2d_transpose( + enc6, color_channels, 1, stride=1, scope='convt4', activation_fn=None) + # This allows the network to also generate one image from scratch, + # which is useful when regions of the image become unoccluded. 
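+                        # Squash the scratch image with a sigmoid so it lies in [0, 1] like the
+                        # input frames, and reserve a second extra mask for it
+                        # (background mask + scratch-image mask).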
+ transformed_l = [tf.nn.sigmoid(enc7)] + extra_masks = 2 + else: + transformed_l = [] + extra_masks = 1 + + cdna_input = tf.reshape(hidden5, [int(batch_size), -1]) + new_transformed, _ = self.cdna_transformation(prev_image, + cdna_input, + reuse_sc=reuse) + transformed_l += new_transformed + self.moved_images.append(transformed_l) + + if self.pix_distributions1 != None: + transf_distrib_ndesig1, _ = self.cdna_transformation(prev_pix_distrib1, + cdna_input, + reuse_sc=True) + self.moved_pix_distrib1.append(transf_distrib_ndesig1) + if 'ndesig' in self.conf: + transf_distrib_ndesig2, _ = self.cdna_transformation( + prev_pix_distrib2, + cdna_input, + reuse_sc=True) + + self.moved_pix_distrib2.append(transf_distrib_ndesig2) + + if self.conf['model'] == 'STP': + enc7 = slim.layers.conv2d_transpose(enc6, color_channels, 1, stride=1, scope='convt5', activation_fn= None) + # This allows the network to also generate one image from scratch, + # which is useful when regions of the image become unoccluded. + if 'gen_pix' in self.conf: + transformed_l = [tf.nn.sigmoid(enc7)] + extra_masks = 2 + else: + transformed_l = [] + extra_masks = 1 + + enc_stp = tf.reshape(hidden5, [int(batch_size), -1]) + stp_input = slim.layers.fully_connected( + enc_stp, 200, scope='fc_stp_cam2') + + # disabling capability to generete pixels + reuse_stp = None + if reuse: + reuse_stp = reuse + + # enable the generation of pixels: + transformed, trafo = self.stp_transformation(prev_image, stp_input, self.num_masks, reuse_stp, suffix='cam2') + transformed_l += transformed + + self.trafos.append(trafo) + self.moved_images.append(transformed_l) + + if self.pix_distributions1 != None: + transf_distrib_ndesig1, _ = self.stp_transformation(prev_pix_distrib1, stp_input, suffix='cam2', reuse=True) + self.moved_pix_distrib1.append(transf_distrib_ndesig1) + + if '1stimg_bckgd' in self.conf: + background = self.images[0] + print('using background from first image..') + else: background = prev_image + output, mask_list = self.fuse_trafos(enc6, background, + transformed_l, + scope='convt7_cam2', + extra_masks= extra_masks) + self.gen_images.append(output) + self.gen_masks.append(mask_list) + + if self.pix_distributions1!=None: + pix_distrib_output = self.fuse_pix_distrib(extra_masks, + mask_list, + self.pix_distributions1, + prev_pix_distrib1, + transf_distrib_ndesig1) + + self.gen_distrib1.append(pix_distrib_output) + if 'ndesig' in self.conf: + pix_distrib_output = self.fuse_pix_distrib(extra_masks, + mask_list, + self.pix_distributions2, + prev_pix_distrib2, + transf_distrib_ndesig2) + + self.gen_distrib2.append(pix_distrib_output) + + if int(current_state.get_shape()[1]) == 0: + current_state = tf.zeros_like(state_action) + else: + current_state = slim.layers.fully_connected( + state_action, + int(current_state.get_shape()[1]), + scope='state_pred', + activation_fn=None) + + self.gen_states.append(current_state) + + def fuse_trafos(self, enc6, background_image, transformed, scope, extra_masks): + masks = slim.layers.conv2d_transpose( + enc6, (self.conf['num_masks']+ extra_masks), 1, stride=1, activation_fn=None, scope=scope) + + img_height = 64 + img_width = 64 + num_masks = self.conf['num_masks'] + + if self.conf['model']=='DNA': + if num_masks != 1: + raise ValueError('Only one mask is supported for DNA model.') + + # the total number of masks is num_masks +extra_masks because of background and generated pixels! 
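+        # Per-pixel softmax over the num_masks + extra_masks channels: the compositing weights
+        # at every location sum to one, mask 0 weights the background image and the remaining
+        # masks weight the transformed images.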
+ masks = tf.reshape( + tf.nn.softmax(tf.reshape(masks, [-1, num_masks +extra_masks])), + [int(self.batch_size), int(img_height), int(img_width), num_masks +extra_masks]) + mask_list = tf.split(axis=3, num_or_size_splits=num_masks +extra_masks, value=masks) + output = mask_list[0] * background_image + + assert len(transformed) == len(mask_list[1:]) + for layer, mask in zip(transformed, mask_list[1:]): + output += layer * mask + + return output, mask_list + + def fuse_pix_distrib(self, extra_masks, mask_list, pix_distributions, prev_pix_distrib, + transf_distrib): + + if '1stimg_bckgd' in self.conf: + background_pix = pix_distributions[0] + if len(background_pix.get_shape()) == 3: + background_pix = tf.expand_dims(background_pix, -1) + print('using pix_distrib-background from first image..') + else: + background_pix = prev_pix_distrib + pix_distrib_output = mask_list[0] * background_pix + if 'gen_pix' in self.conf: + pix_distrib_output += mask_list[1] * prev_pix_distrib # assume pixels don't when image is generated from scratch + for i in range(self.num_masks): + pix_distrib_output += transf_distrib[i] * mask_list[i + extra_masks] + pix_distrib_output /= tf.reduce_sum(pix_distrib_output, axis=(1, 2), keepdims=True) + return pix_distrib_output + + ## Utility functions + def stp_transformation(self, prev_image, stp_input, num_masks, reuse= None, suffix = None): + """Apply spatial transformer predictor (STP) to previous image. + + Args: + prev_image: previous image to be transformed. + stp_input: hidden layer to be used for computing STN parameters. + num_masks: number of masks and hence the number of STP transformations. + Returns: + List of images transformed by the predicted STP parameters. + """ + # Only import spatial transformer if needed. + from spatial_transformer import transformer + + identity_params = tf.convert_to_tensor( + np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32)) + transformed = [] + trafos = [] + for i in range(num_masks): + params = slim.layers.fully_connected( + stp_input, 6, scope='stp_params' + str(i) + suffix, + activation_fn=None, + reuse= reuse) + identity_params + outsize = (prev_image.get_shape()[1], prev_image.get_shape()[2]) + transformed.append(transformer(prev_image, params, outsize)) + trafos.append(params) + + return transformed, trafos + + def dna_transformation(self, prev_image, dna_input, DNA_KERN_SIZE): + """Apply dynamic neural advection to previous image. + + Args: + prev_image: previous image to be transformed. + dna_input: hidden lyaer to be used for computing DNA transformation. + Returns: + List of images transformed by the predicted CDNA kernels. + """ + # Construct translated images. + pad_len = int(np.floor(DNA_KERN_SIZE / 2)) + prev_image_pad = tf.pad(prev_image, [[0, 0], [pad_len, pad_len], [pad_len, pad_len], [0, 0]]) + image_height = int(prev_image.get_shape()[1]) + image_width = int(prev_image.get_shape()[2]) + + inputs = [] + for xkern in range(DNA_KERN_SIZE): + for ykern in range(DNA_KERN_SIZE): + inputs.append( + tf.expand_dims( + tf.slice(prev_image_pad, [0, xkern, ykern, 0], + [-1, image_height, image_width, -1]), [3])) + inputs = tf.concat(axis=3, values=inputs) + + # Normalize channels to 1. 
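+        # (Clarifying note: `inputs` above stacks DNA_KERN_SIZE**2 shifted copies of the
+        # previous image along axis 3, e.g. 25 copies for a 5x5 kernel.  The relu-shift
+        # below keeps every kernel entry strictly positive and the division makes each
+        # per-pixel kernel sum to 1, so every output pixel is a convex combination of
+        # its neighbourhood in the previous frame.)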
+        kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
+        kernel = tf.expand_dims(
+            kernel / tf.reduce_sum(
+                kernel, [3], keepdims=True), [4])
+
+        return tf.reduce_sum(kernel * inputs, [3], keepdims=False)
+
+    def cdna_transformation(self, prev_image, cdna_input, reuse_sc=None):
+        """Apply convolutional dynamic neural advection to previous image.
+
+        The kernel size and the number of masks are taken from self.conf.
+
+        Args:
+            prev_image: previous image to be transformed.
+            cdna_input: hidden layer to be used for computing CDNA kernels.
+            reuse_sc: whether to reuse the variables of the kernel-predicting layer.
+        Returns:
+            List of images transformed by the predicted CDNA kernels.
+        """
+        batch_size = int(cdna_input.get_shape()[0])
+        height = int(prev_image.get_shape()[1])
+        width = int(prev_image.get_shape()[2])
+
+        DNA_KERN_SIZE = self.conf['kern_size']
+        num_masks = self.conf['num_masks']
+        color_channels = int(prev_image.get_shape()[3])
+
+        # Predict kernels using linear function of last hidden layer.
+        cdna_kerns = slim.layers.fully_connected(
+            cdna_input,
+            DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks,
+            scope='cdna_params',
+            activation_fn=None,
+            reuse=reuse_sc)
+
+        # Reshape and normalize.
+        cdna_kerns = tf.reshape(
+            cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks])
+        cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
+        norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True)
+        cdna_kerns /= norm_factor
+        cdna_kerns_summary = cdna_kerns
+
+        # Transpose and reshape.
+        cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3])
+        cdna_kerns = tf.reshape(cdna_kerns, [DNA_KERN_SIZE, DNA_KERN_SIZE, batch_size, num_masks])
+        prev_image = tf.transpose(prev_image, [3, 1, 2, 0])
+
+        transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME')
+
+        # Transpose and reshape.
+        transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks])
+        transformed = tf.transpose(transformed, [3, 1, 2, 0, 4])
+        transformed = tf.unstack(value=transformed, axis=-1)
+
+        return transformed, cdna_kerns_summary
+
+
+def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
+    """Sample batch with specified mix of ground truth and generated data points.
+
+    Args:
+        ground_truth_x: tensor of ground-truth data points.
+        generated_x: tensor of generated data points.
+        batch_size: batch size.
+        num_ground_truth: number of ground-truth examples to include in batch.
+    Returns:
+        New batch with num_ground_truth sampled from ground_truth_x and the rest
+        from generated_x.
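+
+    Example (illustrative only): with batch_size=8 and num_ground_truth=3, three
+    randomly chosen batch positions keep their ground-truth example and the
+    remaining five take the generated one; tf.dynamic_stitch then reassembles
+    them into a single batch in the original index order.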
+ """ + idx = tf.random_shuffle(tf.range(int(batch_size))) + ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth)) + generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size))) + + ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx) + generated_examps = tf.gather(generated_x, generated_idx) + return tf.dynamic_stitch([ground_truth_idx, generated_idx], + [ground_truth_examps, generated_examps]) + + +def generator_fn(inputs, mode, hparams): + images = tf.unstack(inputs['images'], axis=0) + actions = tf.unstack(inputs['actions'], axis=0) + states = tf.unstack(inputs['states'], axis=0) + pix_distributions1 = tf.unstack(inputs['pix_distribs'], axis=0) if 'pix_distribs' in inputs else None + iter_num = tf.to_float(tf.train.get_or_create_global_step()) + + if isinstance(hparams.kernel_size, (tuple, list)): + kernel_height, kernel_width = hparams.kernel_size + assert kernel_height == kernel_width + kern_size = kernel_height + else: + kern_size = hparams.kernel_size + + schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1 + conf = { + 'context_frames': hparams.context_frames, # of frames before predictions.' , + 'use_state': 1, # 'Whether or not to give the state+action to the model' , + 'ngf': hparams.ngf, + 'model': hparams.transformation.upper(), # 'model architecture to use - CDNA, DNA, or STP' , + 'num_masks': hparams.num_masks, # 'number of masks, usually 1 for DNA, 10 for CDNA, STN.' , + 'schedsamp_k': schedule_sampling_k, # 'The k hyperparameter for scheduled sampling -1 for no scheduled sampling.' , + 'kern_size': kern_size, # size of DNA kerns + } + if hparams.first_image_background: + conf['1stimg_bckgd'] = '' + if hparams.generate_scratch_image: + conf['gen_pix'] = '' + + m = Prediction_Model(images, actions, states, + pix_distributions1=pix_distributions1, + iter_num=iter_num, conf=conf) + m.build() + outputs = { + 'gen_images': tf.stack(m.gen_images, axis=0), + 'gen_states': tf.stack(m.gen_states, axis=0), + } + if 'pix_distribs' in inputs: + outputs['gen_pix_distribs'] = tf.stack(m.gen_distrib1, axis=0) + return outputs + + +class SNAVideoPredictionModel(VideoPredictionModel): + def __init__(self, *args, **kwargs): + super(SNAVideoPredictionModel, self).__init__( + generator_fn, *args, **kwargs) + + def get_default_hparams_dict(self): + default_hparams = super(SNAVideoPredictionModel, self).get_default_hparams_dict() + hparams = dict( + batch_size=32, + l1_weight=0.0, + l2_weight=1.0, + ngf=16, + transformation='cdna', + kernel_size=(5, 5), + num_masks=10, + first_image_background=True, + generate_scratch_image=True, + schedule_sampling_k=900.0, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) diff --git a/video_prediction/models/sv2p_model.py b/video_prediction/models/sv2p_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a06364178dc380456e47569787ab693ad121de --- /dev/null +++ b/video_prediction/models/sv2p_model.py @@ -0,0 +1,678 @@ +# Copyright 2016 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Model architecture for predictive model, including CDNA, DNA, and STP.""" + +import itertools +import numpy as np +import tensorflow as tf +import tensorflow.contrib.slim as slim +from tensorflow.contrib.layers.python import layers as tf_layers +from tensorflow.contrib.slim import add_arg_scope +from tensorflow.contrib.slim import layers + +from video_prediction.models import VideoPredictionModel + + +# Amount to use when lower bounding tensors +RELU_SHIFT = 1e-12 + +# kernel size for DNA and CDNA. +DNA_KERN_SIZE = 5 + + +def init_state(inputs, + state_shape, + state_initializer=tf.zeros_initializer(), + dtype=tf.float32): + """Helper function to create an initial state given inputs. + Args: + inputs: input Tensor, at least 2D, the first dimension being batch_size + state_shape: the shape of the state. + state_initializer: Initializer(shape, dtype) for state Tensor. + dtype: Optional dtype, needed when inputs is None. + Returns: + A tensors representing the initial state. + """ + if inputs is not None: + # Handle both the dynamic shape as well as the inferred shape. + inferred_batch_size = inputs.get_shape().with_rank_at_least(1)[0] + dtype = inputs.dtype + else: + inferred_batch_size = 0 + initial_state = state_initializer( + [inferred_batch_size] + state_shape, dtype=dtype) + return initial_state + + +@add_arg_scope +def basic_conv_lstm_cell(inputs, + state, + num_channels, + filter_size=5, + forget_bias=1.0, + scope=None, + reuse=None): + """Basic LSTM recurrent network cell, with 2D convolution connctions. + We add forget_bias (default: 1) to the biases of the forget gate in order to + reduce the scale of forgetting in the beginning of the training. + It does not allow cell clipping, a projection layer, and does not + use peep-hole connections: it is the basic baseline. + Args: + inputs: input Tensor, 4D, batch x height x width x channels. + state: state Tensor, 4D, batch x height x width x channels. + num_channels: the number of output channels in the layer. + filter_size: the shape of the each convolution filter. + forget_bias: the initial value of the forget biases. + scope: Optional scope for variable_scope. + reuse: whether or not the layer and the variables should be reused. + Returns: + a tuple of tensors representing output and the new state. + """ + spatial_size = inputs.get_shape()[1:3] + if state is None: + state = init_state(inputs, list(spatial_size) + [2 * num_channels]) + with tf.variable_scope(scope, + 'BasicConvLstmCell', + [inputs, state], + reuse=reuse): + inputs.get_shape().assert_has_rank(4) + state.get_shape().assert_has_rank(4) + c, h = tf.split(axis=3, num_or_size_splits=2, value=state) + inputs_h = tf.concat(axis=3, values=[inputs, h]) + # Parameters of gates are concatenated into one conv for efficiency. + i_j_f_o = layers.conv2d(inputs_h, + 4 * num_channels, [filter_size, filter_size], + stride=1, + activation_fn=None, + scope='Gates') + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = tf.split(axis=3, num_or_size_splits=4, value=i_j_f_o) + + new_c = c * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.tanh(j) + new_h = tf.tanh(new_c) * tf.sigmoid(o) + + return new_h, tf.concat(axis=3, values=[new_c, new_h]) + + +def kl_divergence(mu, log_sigma): + """KL divergence of diagonal gaussian N(mu,exp(log_sigma)) and N(0,1). 
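+
+    In closed form this equals
+        KL = -0.5 * sum_i (1 + log_sigma_i - mu_i**2 - exp(log_sigma_i)),
+    summed over the latent dimensions (axis=1), matching the implementation below.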
+ + Args: + mu: mu parameter of the distribution. + log_sigma: log(sigma) parameter of the distribution. + Returns: + the KL loss. + """ + + return -.5 * tf.reduce_sum(1. + log_sigma - tf.square(mu) - tf.exp(log_sigma), + axis=1) + + +def construct_latent_tower(images, hparams): + """Builds convolutional latent tower for stochastic model. + + At training time this tower generates a latent distribution (mean and std) + conditioned on the entire video. This latent variable will be fed to the + main tower as an extra variable to be used for future frames prediction. + At inference time, the tower is disabled and only returns latents sampled + from N(0,1). + If the multi_latent flag is on, a different latent for every timestep would + be generated. + + Args: + images: tensor of ground truth image sequences + Returns: + latent_mean: predicted latent mean + latent_std: predicted latent standard deviation + latent_loss: loss of the latent twoer + samples: random samples sampled from standard guassian + """ + + with slim.arg_scope([slim.conv2d], reuse=False): + stacked_images = tf.concat(images, 3) + + latent_enc1 = slim.conv2d( + stacked_images, + 32, [3, 3], + stride=2, + scope='latent_conv1', + normalizer_fn=tf_layers.layer_norm, + normalizer_params={'scope': 'latent_norm1'}) + + latent_enc2 = slim.conv2d( + latent_enc1, + 64, [3, 3], + stride=2, + scope='latent_conv2', + normalizer_fn=tf_layers.layer_norm, + normalizer_params={'scope': 'latent_norm2'}) + + latent_enc3 = slim.conv2d( + latent_enc2, + 64, [3, 3], + stride=1, + scope='latent_conv3', + normalizer_fn=tf_layers.layer_norm, + normalizer_params={'scope': 'latent_norm3'}) + + latent_mean = slim.conv2d( + latent_enc3, + hparams.latent_channels, [3, 3], + stride=2, + activation_fn=None, + scope='latent_mean', + normalizer_fn=tf_layers.layer_norm, + normalizer_params={'scope': 'latent_norm_mean'}) + + latent_std = slim.conv2d( + latent_enc3, + hparams.latent_channels, [3, 3], + stride=2, + scope='latent_std', + normalizer_fn=tf_layers.layer_norm, + normalizer_params={'scope': 'latent_std_norm'}) + + latent_std += hparams.latent_std_min + + return latent_mean, latent_std + + +def encoder_fn(inputs, hparams): + images = tf.unstack(inputs['images'], axis=0) + latent_mean, latent_std = construct_latent_tower(images, hparams) + outputs = {'zs_mu_enc': latent_mean, 'zs_log_sigma_sq_enc': latent_std} + return outputs + + +def construct_model(images, + actions=None, + states=None, + outputs_enc=None, + iter_num=-1.0, + k=-1, + use_state=True, + num_masks=10, + stp=False, + cdna=True, + dna=False, + context_frames=2, + hparams=None): + """Build convolutional lstm video predictor using STP, CDNA, or DNA. + + Args: + images: tensor of ground truth image sequences + actions: tensor of action sequences + states: tensor of ground truth state sequences + iter_num: tensor of the current training iteration (for sched. sampling) + k: constant used for scheduled sampling. -1 to feed in own prediction. 
+ use_state: True to include state and action in prediction + num_masks: the number of different pixel motion predictions (and + the number of masks for each of those predictions) + stp: True to use Spatial Transformer Predictor (STP) + cdna: True to use Convoluational Dynamic Neural Advection (CDNA) + dna: True to use Dynamic Neural Advection (DNA) + context_frames: number of ground truth frames to pass in before + feeding in own predictions + Returns: + gen_images: predicted future image frames + gen_states: predicted future states + + Raises: + ValueError: if more than one network option specified or more than 1 mask + specified for DNA model. + """ + # Each image is being used twice, in latent tower and main tower. + # This is to make sure we are using the *same* image for both, ... + # ... given how TF queues work. + images = [tf.identity(image) for image in images] + + if stp + cdna + dna != 1: + raise ValueError('More than one, or no network option specified.') + batch_size, img_height, img_width, color_channels = images[0].shape.as_list() + lstm_func = basic_conv_lstm_cell + + # Generated robot states and images. + gen_states, gen_images = [], [] + current_state = states[0] + + if k == -1: + feedself = True + else: + # Scheduled sampling: + # Calculate number of ground-truth frames to pass in. + num_ground_truth = tf.to_int32( + tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k))))) + feedself = False + + # LSTM state sizes and states. + lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32])) + lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None + lstm_state5, lstm_state6, lstm_state7 = None, None, None + + # Latent tower + if hparams.stochastic_model: + latent_shape = [batch_size, img_height // 8, img_width // 8, hparams.latent_channels] + if outputs_enc is None: # equivalent to inference_time + latent_mean, latent_std = None, None + else: + latent_mean, latent_std = outputs_enc['zs_mu_enc'], outputs_enc['zs_log_sigma_sq_enc'] + assert latent_mean.shape.as_list() == latent_shape + + if hparams.multi_latent: + # timestep x batch_size x latent_size + samples = tf.random_normal( + [hparams.sequence_length - 1] + latent_shape, 0, 1, + dtype=tf.float32) + else: + # batch_size x latent_size + samples = tf.random_normal(latent_shape, 0, 1, dtype=tf.float32) + + # Main tower + for t in range(hparams.sequence_length - 1): + action = actions[t] + # Reuse variables after the first timestep. + reuse = bool(gen_images) + + done_warm_start = len(gen_images) > context_frames - 1 + with slim.arg_scope( + [lstm_func, slim.layers.conv2d, slim.layers.fully_connected, + tf_layers.layer_norm, slim.layers.conv2d_transpose], + reuse=reuse): + + if feedself and done_warm_start: + # Feed in generated image. 
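+                # (Scheduled-sampling note, illustrative: when k != -1 the
+                # num_ground_truth computed above follows an inverse-sigmoid schedule;
+                # with the default schedule_sampling_k=900 and batch_size=32 virtually
+                # the whole batch is ground truth at iter_num=0, and the ground-truth
+                # count decays towards 0 as iter_num grows.)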
+ prev_image = gen_images[-1] + elif done_warm_start: + # Scheduled sampling + prev_image = scheduled_sample(images[t], gen_images[-1], batch_size, + num_ground_truth) + else: + # Always feed in ground_truth + prev_image = images[t] + + # Predicted state is always fed back in + state_action = tf.concat(axis=1, values=[action, current_state]) + + enc0 = slim.layers.conv2d( + prev_image, + 32, [5, 5], + stride=2, + scope='scale1_conv1', + normalizer_fn=tf_layers.layer_norm, + normalizer_params={'scope': 'layer_norm1'}) + + hidden1, lstm_state1 = lstm_func( + enc0, lstm_state1, lstm_size[0], scope='state1') + hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2') + hidden2, lstm_state2 = lstm_func( + hidden1, lstm_state2, lstm_size[1], scope='state2') + hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3') + enc1 = slim.layers.conv2d( + hidden2, hidden2.get_shape()[3], [3, 3], stride=2, scope='conv2') + + hidden3, lstm_state3 = lstm_func( + enc1, lstm_state3, lstm_size[2], scope='state3') + hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4') + hidden4, lstm_state4 = lstm_func( + hidden3, lstm_state4, lstm_size[3], scope='state4') + hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5') + enc2 = slim.layers.conv2d( + hidden4, hidden4.get_shape()[3], [3, 3], stride=2, scope='conv3') + + # Pass in state and action. + smear = tf.reshape( + state_action, + [int(batch_size), 1, 1, int(state_action.get_shape()[1])]) + smear = tf.tile( + smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1]) + if use_state: + enc2 = tf.concat(axis=3, values=[enc2, smear]) + # Setup latent + if hparams.stochastic_model: + latent = samples + if hparams.multi_latent: + latent = samples[t] + if outputs_enc is not None: # equivalent to not inference_time + latent = tf.cond(iter_num < hparams.num_iterations_1st_stage, + lambda: tf.identity(latent), + lambda: latent_mean + tf.exp(latent_std / 2.0) * latent) + with tf.control_dependencies([latent]): + enc2 = tf.concat([enc2, latent], 3) + + enc3 = slim.layers.conv2d( + enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4') + + hidden5, lstm_state5 = lstm_func( + enc3, lstm_state5, lstm_size[4], scope='state5') # last 8x8 + hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6') + enc4 = slim.layers.conv2d_transpose( + hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1') + + hidden6, lstm_state6 = lstm_func( + enc4, lstm_state6, lstm_size[5], scope='state6') # 16x16 + hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7') + # Skip connection. + hidden6 = tf.concat(axis=3, values=[hidden6, enc1]) # both 16x16 + + enc5 = slim.layers.conv2d_transpose( + hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2') + hidden7, lstm_state7 = lstm_func( + enc5, lstm_state7, lstm_size[6], scope='state7') # 32x32 + hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8') + + # Skip connection. + hidden7 = tf.concat(axis=3, values=[hidden7, enc0]) # both 32x32 + + enc6 = slim.layers.conv2d_transpose( + hidden7, + hidden7.get_shape()[3], 3, stride=2, scope='convt3', activation_fn=None, + normalizer_fn=tf_layers.layer_norm, + normalizer_params={'scope': 'layer_norm9'}) + + if dna: + # Using largest hidden state for predicting untied conv kernels. + enc7 = slim.layers.conv2d_transpose( + enc6, DNA_KERN_SIZE ** 2, 1, stride=1, scope='convt4', activation_fn=None) + else: + # Using largest hidden state for predicting a new image layer. 
+ enc7 = slim.layers.conv2d_transpose( + enc6, color_channels, 1, stride=1, scope='convt4', activation_fn=None) + # This allows the network to also generate one image from scratch, + # which is useful when regions of the image become unoccluded. + transformed = [tf.nn.sigmoid(enc7)] + + if stp: + stp_input0 = tf.reshape(hidden5, [int(batch_size), -1]) + stp_input1 = slim.layers.fully_connected( + stp_input0, 100, scope='fc_stp') + transformed += stp_transformation(prev_image, stp_input1, num_masks) + elif cdna: + cdna_input = tf.reshape(hidden5, [int(batch_size), -1]) + transformed += cdna_transformation(prev_image, cdna_input, num_masks, + int(color_channels)) + elif dna: + # Only one mask is supported (more should be unnecessary). + if num_masks != 1: + raise ValueError('Only one mask is supported for DNA model.') + transformed = [dna_transformation(prev_image, enc7)] + + masks = slim.layers.conv2d_transpose( + enc6, num_masks + 1, 1, stride=1, scope='convt7', activation_fn=None) + masks = tf.reshape( + tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])), + [int(batch_size), int(img_height), int(img_width), num_masks + 1]) + mask_list = tf.split(axis=3, num_or_size_splits=num_masks + 1, value=masks) + output = mask_list[0] * prev_image + for layer, mask in zip(transformed, mask_list[1:]): + output += layer * mask + gen_images.append(output) + + current_state = slim.layers.fully_connected( + state_action, + int(current_state.get_shape()[1]), + scope='state_pred', + activation_fn=None) + gen_states.append(current_state) + + return gen_images, gen_states + + +## Utility functions +def stp_transformation(prev_image, stp_input, num_masks): + """Apply spatial transformer predictor (STP) to previous image. + + Args: + prev_image: previous image to be transformed. + stp_input: hidden layer to be used for computing STN parameters. + num_masks: number of masks and hence the number of STP transformations. + Returns: + List of images transformed by the predicted STP parameters. + """ + # Only import spatial transformer if needed. + from spatial_transformer import transformer + + identity_params = tf.convert_to_tensor( + np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32)) + transformed = [] + for i in range(num_masks - 1): + params = slim.layers.fully_connected( + stp_input, 6, scope='stp_params' + str(i), + activation_fn=None) + identity_params + transformed.append(transformer(prev_image, params)) + + return transformed + + +def cdna_transformation(prev_image, cdna_input, num_masks, color_channels): + """Apply convolutional dynamic neural advection to previous image. + + Args: + prev_image: previous image to be transformed. + cdna_input: hidden lyaer to be used for computing CDNA kernels. + num_masks: the number of masks and hence the number of CDNA transformations. + color_channels: the number of color channels in the images. + Returns: + List of images transformed by the predicted CDNA kernels. + """ + batch_size = int(cdna_input.get_shape()[0]) + height = int(prev_image.get_shape()[1]) + width = int(prev_image.get_shape()[2]) + + # Predict kernels using linear function of last hidden layer. + cdna_kerns = slim.layers.fully_connected( + cdna_input, + DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks, + scope='cdna_params', + activation_fn=None) + + # Reshape and normalize. 
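+    # (Clarifying note: after the reshape below the kernels have shape
+    # [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks]; the relu-shift keeps
+    # all entries positive and dividing by norm_factor makes each of the num_masks
+    # kernels sum to 1 over its spatial footprint, so every CDNA kernel acts as a
+    # weighted-averaging (advection) kernel.)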
+    cdna_kerns = tf.reshape(
+        cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks])
+    cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
+    norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True)
+    cdna_kerns /= norm_factor
+
+    # Treat the color channel dimension as the batch dimension since the same
+    # transformation is applied to each color channel.
+    # Treat the batch dimension as the channel dimension so that
+    # depthwise_conv2d can apply a different transformation to each sample.
+    cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3])
+    cdna_kerns = tf.reshape(cdna_kerns, [DNA_KERN_SIZE, DNA_KERN_SIZE, batch_size, num_masks])
+    # Swap the batch and channel dimensions.
+    prev_image = tf.transpose(prev_image, [3, 1, 2, 0])
+
+    # Transform image.
+    transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME')
+
+    # Transpose the dimensions to where they belong.
+    transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks])
+    transformed = tf.transpose(transformed, [3, 1, 2, 0, 4])
+    transformed = tf.unstack(transformed, axis=-1)
+    return transformed
+
+
+def dna_transformation(prev_image, dna_input):
+    """Apply dynamic neural advection to previous image.
+
+    Args:
+        prev_image: previous image to be transformed.
+        dna_input: hidden layer to be used for computing DNA transformation.
+    Returns:
+        Image transformed by the predicted per-pixel DNA kernels.
+    """
+    # Construct translated images.
+    prev_image_pad = tf.pad(prev_image, [[0, 0], [2, 2], [2, 2], [0, 0]])
+    image_height = int(prev_image.get_shape()[1])
+    image_width = int(prev_image.get_shape()[2])
+
+    inputs = []
+    for xkern in range(DNA_KERN_SIZE):
+        for ykern in range(DNA_KERN_SIZE):
+            inputs.append(
+                tf.expand_dims(
+                    tf.slice(prev_image_pad, [0, xkern, ykern, 0],
+                             [-1, image_height, image_width, -1]), [3]))
+    inputs = tf.concat(axis=3, values=inputs)
+
+    # Normalize channels to 1.
+    kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
+    kernel = tf.expand_dims(
+        kernel / tf.reduce_sum(
+            kernel, [3], keepdims=True), [4])
+    return tf.reduce_sum(kernel * inputs, [3], keepdims=False)
+
+
+def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
+    """Sample batch with specified mix of ground truth and generated data points.
+
+    Args:
+        ground_truth_x: tensor of ground-truth data points.
+        generated_x: tensor of generated data points.
+        batch_size: batch size.
+        num_ground_truth: number of ground-truth examples to include in batch.
+    Returns:
+        New batch with num_ground_truth sampled from ground_truth_x and the rest
+        from generated_x.
+    """
+    idx = tf.random_shuffle(tf.range(int(batch_size)))
+    ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
+    generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size)))
+
+    ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
+    generated_examps = tf.gather(generated_x, generated_idx)
+    return tf.dynamic_stitch([ground_truth_idx, generated_idx],
+                             [ground_truth_examps, generated_examps])
+
+
+def generator_fn(inputs, mode, hparams):
+    images = tf.unstack(inputs['images'], axis=0)
+    batch_size = images[0].shape[0].value
+    action_dim, state_dim = 4, 3
+
+    # if not use_state, use zero actions and states to match reference implementation.
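+    # (Shape note, illustrative: with e.g. sequence_length=20 and batch_size=32, the
+    # zero fallbacks below have shapes [19, 32, 4] for actions and [20, 32, 3] for
+    # states, matching the time-major unstacking used throughout this function.)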
+ actions = inputs.get('actions', tf.zeros([hparams.sequence_length - 1, batch_size, action_dim])) + actions = tf.unstack(actions, axis=0) + states = inputs.get('states', tf.zeros([hparams.sequence_length, batch_size, state_dim])) + states = tf.unstack(states, axis=0) + iter_num = tf.to_float(tf.train.get_or_create_global_step()) + + schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1 + gen_images, gen_states = \ + construct_model(images, + actions, + states, + outputs_enc=None, + iter_num=iter_num, + k=schedule_sampling_k, + use_state='actions' in inputs, + num_masks=hparams.num_masks, + cdna=hparams.transformation == 'cdna', + dna=hparams.transformation == 'dna', + stp=hparams.transformation == 'stp', + context_frames=hparams.context_frames, + hparams=hparams) + outputs = { + 'gen_images': tf.stack(gen_images, axis=0), + 'gen_states': tf.stack(gen_states, axis=0), + } + + if mode == 'train': + outputs_enc = encoder_fn(inputs, hparams) + tf.get_variable_scope().reuse_variables() + gen_images_enc, gen_states_enc = \ + construct_model(images, + actions, + states, + outputs_enc=outputs_enc, + iter_num=iter_num, + k=schedule_sampling_k, + use_state='actions' in inputs, + num_masks=hparams.num_masks, + cdna=hparams.transformation == 'cdna', + dna=hparams.transformation == 'dna', + stp=hparams.transformation == 'stp', + context_frames=hparams.context_frames, + hparams=hparams) + outputs.update({ + 'gen_images_enc': tf.stack(gen_images_enc, axis=0), + 'gen_states_enc': tf.stack(gen_states_enc, axis=0), + 'zs_mu_enc': outputs_enc['zs_mu_enc'], + 'zs_log_sigma_sq_enc': outputs_enc['zs_log_sigma_sq_enc'], + }) + return outputs + + +class SV2PVideoPredictionModel(VideoPredictionModel): + """ + Stochastic Variational Video Prediction + https://arxiv.org/abs/1710.11252 + + Reference implementation: + https://github.com/mbz/models/tree/master/research/video_prediction + """ + def __init__(self, *args, **kwargs): + super(SV2PVideoPredictionModel, self).__init__( + generator_fn, *args, ** kwargs) + self.deterministic = not self.hparams.stochastic_model + + def get_default_hparams_dict(self): + default_hparams = super(SV2PVideoPredictionModel, self).get_default_hparams_dict() + hparams = dict( + batch_size=32, + l1_weight=0.0, + l2_weight=1.0, + kl_weight=1e-3 * 10 * 8, # equivalent to latent_loss_multiplier up to a factor (see below) + transformation='cdna', + num_masks=10, + schedule_sampling_k=900.0, + stochastic_model=True, + multi_latent=False, + latent_std_min=-5.0, + latent_channels=1, + num_iterations_1st_stage=50000, + kl_anneal_steps=(100000, 120000), + max_steps=200000, + decay_steps=(0, 0), # do not decay the learning rate (doing so produces blurrier images) + ) + # Notes on equivalence with reference implementation: + # kl_weight is equivalent to latent_loss_multiplier * time_factor * factor, where + # time_factor = (sequence_length - context_frames) since the reference implementation + # doesn't normalize the kl divergence over time, and factor = (width // 8) / latent_channels + # since the reference implementation's kl_divergence sums over axis=1 instead of axis=-1. + # The paper and the reference implementation differs in the annealing of the kl_weight. + # Based on Figure 4 and the Appendix, it seems that in the 3rd stage, the kl_weight is + # linearly increased for the first 20k iterations of this stage. 
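+        # (Worked example of the equivalence above, under the presumable defaults of
+        # 64-pixel-wide frames, latent_channels=1 and sequence_length - context_frames = 10:
+        # time_factor = 10 and factor = (64 // 8) / 1 = 8, so
+        # kl_weight = 1e-3 * 10 * 8 corresponds to latent_loss_multiplier = 1e-3.)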
+ return dict(itertools.chain(default_hparams.items(), hparams.items())) + + def parse_hparams(self, hparams_dict, hparams): + # backwards compatibility + deprecated_hparams_keys = [ + 'num_gpus', + 'acvideo_gan_weight', + 'acvideo_vae_gan_weight', + 'image_gan_weight', + 'image_vae_gan_weight', + 'tuple_gan_weight', + 'tuple_vae_gan_weight', + 'gan_weight', + 'vae_gan_weight', + 'video_gan_weight', + 'video_vae_gan_weight', + ] + for deprecated_hparams_key in deprecated_hparams_keys: + hparams_dict.pop(deprecated_hparams_key, None) + return super(SV2PVideoPredictionModel, self).parse_hparams(hparams_dict, hparams) diff --git a/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction/models/vanilla_convLSTM_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e7753004348ae0ae60057a469de1e2d1421c3869 --- /dev/null +++ b/video_prediction/models/vanilla_convLSTM_model.py @@ -0,0 +1,162 @@ +import collections +import functools +import itertools +from collections import OrderedDict +import numpy as np +import tensorflow as tf +from tensorflow.python.util import nest +from video_prediction import ops, flow_ops +from video_prediction.models import BaseVideoPredictionModel +from video_prediction.models import networks +from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat +from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell +from video_prediction.utils import tf_utils +from datetime import datetime +from pathlib import Path +from video_prediction.layers import layer_def as ld +from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell + +class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): + def __init__(self, mode='train',aggregate_nccl=None, hparams_dict=None, + hparams=None, **kwargs): + super(VanillaConvLstmVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) + print ("Hparams_dict",self.hparams) + self.mode = mode + self.learning_rate = self.hparams.lr + self.gen_images_enc = None + self.recon_loss = None + self.latent_loss = None + self.total_loss = None + self.context_frames = 10 + self.sequence_length = 20 + self.predict_frames = self.sequence_length - self.context_frames + self.aggregate_nccl=aggregate_nccl + + def get_default_hparams_dict(self): + """ + The keys of this dict define valid hyperparameters for instances of + this class. A class inheriting from this one should override this + method if it has a different set of hyperparameters. + + Returns: + A dict with the following hyperparameters. + + batch_size: batch size for training. + lr: learning rate. if decay steps is non-zero, this is the + learning rate for steps <= decay_step. + + + + max_steps: number of training steps. + + + context_frames: the number of ground-truth frames to pass in at + start. Must be specified during instantiation. + sequence_length: the number of frames in the video sequence, + including the context frames, so this model predicts + `sequence_length - context_frames` future frames. Must be + specified during instantiation. 
+ """ + default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict() + print ("default hparams",default_hparams) + hparams = dict( + batch_size=16, + lr=0.001, + end_lr=0.0, + nz=16, + decay_steps=(200000, 300000), + max_steps=350000, + ) + + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + def build_graph(self, x): + self.x = x["images"] + + self.global_step = tf.Variable(0, name = 'global_step', trainable = False) + original_global_variables = tf.global_variables() + # ARCHITECTURE + self.x_hat_context_frames, self.x_hat_predict_frames = self.convLSTM_network() + self.x_hat = tf.concat([self.x_hat_context_frames, self.x_hat_predict_frames], 1) + + + self.context_frames_loss = tf.reduce_mean( + tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0])) + self.predict_frames_loss = tf.reduce_mean( + tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_predict_frames[:, :, :, :, 0])) + self.total_loss = self.context_frames_loss + self.predict_frames_loss + + self.train_op = tf.train.AdamOptimizer( + learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) + self.outputs = {} + self.outputs["gen_images"] = self.x_hat + # Summary op + self.loss_summary = tf.summary.scalar("recon_loss", self.context_frames_loss) + self.loss_summary = tf.summary.scalar("latent_loss", self.predict_frames_loss) + self.loss_summary = tf.summary.scalar("total_loss", self.total_loss) + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + return + + + @staticmethod + def convLSTM_cell(inputs, hidden, nz=16): + + conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1", activate = "leaky_relu") + + conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2", activate = "leaky_relu") + + conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3", activate = "leaky_relu") + + y_0 = conv3 + # conv lstm cell + cell_shape = y_0.get_shape().as_list() + with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)): + cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size = [3, 3], num_features = 8) + if hidden is None: + hidden = cell.zero_state(y_0, tf.float32) + + output, hidden = cell(y_0, hidden) + + + output_shape = output.get_shape().as_list() + + + z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) + + conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, "decode_5", activate = "leaky_relu") + + + conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, "decode_6", activate = "leaky_relu") + + + x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, "decode_7", activate = "sigmoid") # set activation to linear + + return x_hat, hidden + + def convLSTM_network(self): + network_template = tf.make_template('network', + VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables + # create network + x_hat_context = [] + x_hat_predict = [] + seq_start = 1 + hidden = None + for i in range(self.context_frames): + if i < seq_start: + x_1, hidden = network_template(self.x[:, i, :, :, :], hidden) + else: + x_1, hidden = network_template(x_1, hidden) + x_hat_context.append(x_1) + + for i in range(self.predict_frames): + x_1, hidden = network_template(x_1, hidden) + x_hat_predict.append(x_1) + + # pack them all together + x_hat_context = tf.stack(x_hat_context) + 
x_hat_predict = tf.stack(x_hat_predict) + self.x_hat_context = tf.transpose(x_hat_context, [1, 0, 2, 3, 4]) # change first dim with sec dim + self.x_hat_predict = tf.transpose(x_hat_predict, [1, 0, 2, 3, 4]) # change first dim with sec dim + return self.x_hat_context, self.x_hat_predict diff --git a/video_prediction/models/vanilla_vae_model.py b/video_prediction/models/vanilla_vae_model.py new file mode 100644 index 0000000000000000000000000000000000000000..eec5598305044226280080d630313487c7d847a4 --- /dev/null +++ b/video_prediction/models/vanilla_vae_model.py @@ -0,0 +1,191 @@ +import collections +import functools +import itertools +from collections import OrderedDict +import numpy as np +import tensorflow as tf +from tensorflow.python.util import nest +from video_prediction import ops, flow_ops +from video_prediction.models import BaseVideoPredictionModel +from video_prediction.models import networks +from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat +from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell +from video_prediction.utils import tf_utils +from datetime import datetime +from pathlib import Path +from video_prediction.layers import layer_def as ld + +class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): + def __init__(self, mode='train', aggregate_nccl=None,hparams_dict=None, + hparams=None,**kwargs): + super(VanillaVAEVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) + self.mode = mode + self.learning_rate = self.hparams.lr + self.nz = self.hparams.nz + self.aggregate_nccl=aggregate_nccl + self.gen_images_enc = None + self.train_op = None + self.summary_op = None + self.recon_loss = None + self.latent_loss = None + self.total_loss = None + + def get_default_hparams_dict(self): + """ + The keys of this dict define valid hyperparameters for instances of + this class. A class inheriting from this one should override this + method if it has a different set of hyperparameters. + + Returns: + A dict with the following hyperparameters. + + batch_size: batch size for training. + lr: learning rate. if decay steps is non-zero, this is the + learning rate for steps <= decay_step. + end_lr: learning rate for steps >= end_decay_step if decay_steps + is non-zero, ignored otherwise. + decay_steps: (decay_step, end_decay_step) tuple. + max_steps: number of training steps. + beta1: momentum term of Adam. + beta2: momentum term of Adam. + context_frames: the number of ground-truth frames to pass in at + start. Must be specified during instantiation. + sequence_length: the number of frames in the video sequence, + including the context frames, so this model predicts + `sequence_length - context_frames` future frames. Must be + specified during instantiation. 
+        """
+        default_hparams = super(VanillaVAEVideoPredictionModel, self).get_default_hparams_dict()
+        print("default hparams", default_hparams)
+        hparams = dict(
+            batch_size=16,
+            lr=0.001,
+            end_lr=0.0,
+            decay_steps=(200000, 300000),
+            lr_boundaries=(0,),
+            max_steps=350000,
+            nz=10,
+            context_frames=-1,
+            sequence_length=-1,
+            clip_length=10,  # TODO: clarify what clip_length controls (reference value: 10)
+        )
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
+
+    def build_graph(self, x):
+        tf.set_random_seed(12345)
+        self.x = x["images"]
+
+        self.global_step = tf.Variable(0, name='global_step', trainable=False)
+        original_global_variables = tf.global_variables()
+        self.increment_global_step = tf.assign_add(self.global_step, 1, name='increment_global_step')
+
+        self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all()
+
+        self.recon_loss = tf.reduce_mean(tf.square(self.x[:, 1:, :, :, 0] - self.x_hat[:, :-1, :, :, 0]))
+
+        latent_loss = -0.5 * tf.reduce_sum(
+            1 + self.z_log_sigma_sq - tf.square(self.z_mu) -
+            tf.exp(self.z_log_sigma_sq), axis=1)
+        self.latent_loss = tf.reduce_mean(latent_loss)
+        self.total_loss = self.recon_loss + self.latent_loss
+        self.train_op = tf.train.AdamOptimizer(
+            learning_rate=self.learning_rate).minimize(self.total_loss, global_step=self.global_step)
+
+        self.losses = {
+            'recon_loss': self.recon_loss,
+            'latent_loss': self.latent_loss,
+            'total_loss': self.total_loss,
+        }
+
+        # Summary op
+        self.loss_summary = tf.summary.scalar("recon_loss", self.recon_loss)
+        self.loss_summary = tf.summary.scalar("latent_loss", self.latent_loss)
+        self.loss_summary = tf.summary.scalar("total_loss", self.total_loss)
+        self.summary_op = tf.summary.merge_all()
+
+        self.outputs = {}
+        self.outputs["gen_images"] = self.x_hat
+        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+        self.saveable_variables = [self.global_step] + global_variables
+        return
+
+    @staticmethod
+    def vae_arc3(x, l_name=0, nz=16):
+        seq_name = "sq_" + str(l_name) + "_"
+
+        conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
+        conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2")
+        conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3")
+
+        conv4 = tf.layers.Flatten()(conv3)
+        conv3_shape = conv3.get_shape().as_list()
+
+        z_mu = ld.fc_layer(conv4, hiddens=nz, idx=seq_name + "enc_fc4_m")
+        z_log_sigma_sq = ld.fc_layer(conv4, hiddens=nz, idx=seq_name + "enc_fc4_sigma")
+        eps = tf.random_normal(shape=tf.shape(z_log_sigma_sq), mean=0, stddev=1, dtype=tf.float32)
+        z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps
+
+        z2 = ld.fc_layer(z, hiddens=conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx=seq_name + "decode_fc1")
+        z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])
+
+        conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, seq_name + "decode_5")
+        conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, seq_name + "decode_6")
+        x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8")
+
+        return x_hat, z_mu, z_log_sigma_sq, z
+
+    def vae_arc_all(self):
+        X = []
+        z_log_sigma_sq_all = []
+        z_mu_all = []
+        for i in range(20):
+            q, z_mu, z_log_sigma_sq, z = VanillaVAEVideoPredictionModel.vae_arc3(self.x[:, i, :, :, :], l_name=i, nz=self.nz)
+            X.append(q)
+            z_log_sigma_sq_all.append(z_log_sigma_sq)
+            z_mu_all.append(z_mu)
+        x_hat = tf.stack(X, axis=1)
+        z_log_sigma_sq_all = 
tf.stack(z_log_sigma_sq_all, axis = 1) + z_mu_all = tf.stack(z_mu_all, axis = 1) + + + return x_hat, z_log_sigma_sq_all, z_mu_all diff --git a/video_prediction/ops.py b/video_prediction/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1d2e9f2eac608d2a9ec61e24a354882b3acce2de --- /dev/null +++ b/video_prediction/ops.py @@ -0,0 +1,1098 @@ +import numpy as np +import tensorflow as tf + + +def dense(inputs, units, use_spectral_norm=False, use_bias=True): + with tf.variable_scope('dense'): + input_shape = inputs.get_shape().as_list() + kernel_shape = [input_shape[1], units] + kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) + if use_spectral_norm: + kernel = spectral_normed_weight(kernel) + outputs = tf.matmul(inputs, kernel) + if use_bias: + bias = tf.get_variable('bias', [units], dtype=tf.float32, initializer=tf.zeros_initializer()) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def pad1d(inputs, size, strides=(1,), padding='SAME', mode='CONSTANT'): + size = list(size) if isinstance(size, (tuple, list)) else [size] + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] + input_shape = inputs.get_shape().as_list() + assert len(input_shape) == 3 + in_width = input_shape[1] + if padding in ('SAME', 'FULL'): + if in_width % strides[0] == 0: + pad_along_width = max(size[0] - strides[0], 0) + else: + pad_along_width = max(size[0] - (in_width % strides[0]), 0) + if padding == 'SAME': + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + else: + pad_left = pad_along_width + pad_right = pad_along_width + padding_pattern = [[0, 0], + [pad_left, pad_right], + [0, 0]] + outputs = tf.pad(inputs, padding_pattern, mode=mode) + elif padding == 'VALID': + outputs = inputs + else: + raise ValueError("Invalid padding scheme %s" % padding) + return outputs + + +def conv1d(inputs, filters, kernel_size, strides=(1,), padding='SAME', kernel=None, use_bias=True): + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] + input_shape = inputs.get_shape().as_list() + kernel_shape = list(kernel_size) + [input_shape[-1], filters] + if kernel is None: + with tf.variable_scope('conv1d'): + kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, + initializer=tf.truncated_normal_initializer(stddev=0.02)) + else: + if kernel_shape != kernel.get_shape().as_list(): + raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) + if padding == 'FULL': + inputs = pad1d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') + padding = 'VALID' + stride, = strides + outputs = tf.nn.conv1d(inputs, kernel, stride, padding=padding) + if use_bias: + with tf.variable_scope('conv1d'): + bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def pad2d_paddings(inputs, size, strides=(1, 1), rate=(1, 1), padding='SAME'): + """ + Computes the paddings for a 4-D tensor according to the convolution padding algorithm. + + See pad2d. 
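+
+    For example (illustrative), with size=3, strides=1, rate=1 and 'SAME' padding,
+    pad_along = max(3 - 1, 0) = 2 on each spatial dimension, which is split below
+    into pad_start = 1 and pad_end = 1.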
+ + Reference: + https://www.tensorflow.org/api_guides/python/nn#convolution + https://www.tensorflow.org/api_docs/python/tf/nn/with_space_to_batch + """ + size = np.array(size) if isinstance(size, (tuple, list)) else np.array([size] * 2) + strides = np.array(strides) if isinstance(strides, (tuple, list)) else np.array([strides] * 2) + rate = np.array(rate) if isinstance(rate, (tuple, list)) else np.array([rate] * 2) + if np.any(strides > 1) and np.any(rate > 1): + raise ValueError("strides > 1 not supported in conjunction with rate > 1") + input_shape = inputs.get_shape().as_list() + assert len(input_shape) == 4 + input_size = np.array(input_shape[1:3]) + if padding in ('SAME', 'FULL'): + if np.any(rate > 1): + # We have two padding contributions. The first is used for converting "SAME" + # to "VALID". The second is required so that the height and width of the + # zero-padded value tensor are multiples of rate. + + # Spatial dimensions of the filters and the upsampled filters in which we + # introduce (rate - 1) zeros between consecutive filter values. + dilated_size = size + (size - 1) * (rate - 1) + pad = dilated_size - 1 + else: + pad = np.where(input_size % strides == 0, + np.maximum(size - strides, 0), + np.maximum(size - (input_size % strides), 0)) + if padding == 'SAME': + # When full_padding_shape is odd, we pad more at end, following the same + # convention as conv2d. + pad_start = pad // 2 + pad_end = pad - pad_start + else: + pad_start = pad + pad_end = pad + if np.any(rate > 1): + # More padding so that rate divides the height and width of the input. + # TODO: not sure if this is correct when padding == 'FULL' + orig_pad_end = pad_end + full_input_size = input_size + pad_start + orig_pad_end + pad_end_extra = (rate - full_input_size % rate) % rate + pad_end = orig_pad_end + pad_end_extra + paddings = [[0, 0], + [pad_start[0], pad_end[0]], + [pad_start[1], pad_end[1]], + [0, 0]] + elif padding == 'VALID': + paddings = [[0, 0]] * 4 + else: + raise ValueError("Invalid padding scheme %s" % padding) + return paddings + + +def pad2d(inputs, size, strides=(1, 1), rate=(1, 1), padding='SAME', mode='CONSTANT'): + """ + Pads a 4-D tensor according to the convolution padding algorithm. + + Convolution with a padding scheme + conv2d(..., padding=padding) + is equivalent to zero-padding of the input with such scheme, followed by + convolution with 'VALID' padding + padded = pad2d(..., padding=padding, mode='CONSTANT') + conv2d(padded, ..., padding='VALID') + + Args: + inputs: A 4-D tensor of shape + `[batch, in_height, in_width, in_channels]`. + padding: A string, either 'VALID', 'SAME', or 'FULL'. The padding algorithm. + mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive). + + Returns: + A 4-D tensor. + + Reference: + https://www.tensorflow.org/api_guides/python/nn#convolution + """ + paddings = pad2d_paddings(inputs, size, strides=strides, rate=rate, padding=padding) + if paddings == [[0, 0]] * 4: + outputs = inputs + else: + outputs = tf.pad(inputs, paddings, mode=mode) + return outputs + + +def local2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', + kernel=None, flip_filters=False, + use_bias=True, channelwise=False): + """ + 2-D locally connected operation. + + Works similarly to 2-D convolution except that the weights are unshared, that is, a different set of filters is + applied at each different patch of the input. + + Args: + inputs: A 4-D tensor of shape + `[batch, in_height, in_width, in_channels]`. 
+ kernel: A 6-D or 7-D tensor of shape + `[in_height, in_width, kernel_size[0], kernel_size[1], in_channels, filters]` or + `[batch, in_height, in_width, kernel_size[0], kernel_size[1], in_channels, filters]`. + + Returns: + A 4-D tensor. + """ + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + if strides != [1, 1]: + raise NotImplementedError + if padding == 'FULL': + inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') + padding = 'VALID' + input_shape = inputs.get_shape().as_list() + if padding == 'SAME': + output_shape = input_shape[:3] + [filters] + elif padding == 'VALID': + output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters] + else: + raise ValueError("Invalid padding scheme %s" % padding) + + if channelwise: + if filters not in (input_shape[-1], 1): + raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise " + "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1])) + kernel_shape = output_shape[1:3] + kernel_size + [filters] + else: + kernel_shape = output_shape[1:3] + kernel_size + [input_shape[-1], filters] + if kernel is None: + with tf.variable_scope('local2d'): + kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, + initializer=tf.truncated_normal_initializer(stddev=0.02)) + else: + if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape): + raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s" + % (tuple(kernel_shape), tuple([input_shape[0]] + kernel_shape), tuple(kernel.get_shape().as_list()))) + + outputs = [] + for i in range(kernel_size[0]): + filter_h_ind = -i-1 if flip_filters else i + if padding == 'VALID': + ii = i + else: + ii = i - (kernel_size[0] // 2) + input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1])) + output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii) + assert 0 <= output_h_slice.start < output_shape[1] + assert 0 < output_h_slice.stop <= output_shape[1] + + for j in range(kernel_size[1]): + filter_w_ind = -j-1 if flip_filters else j + if padding == 'VALID': + jj = j + else: + jj = j - (kernel_size[1] // 2) + input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2])) + output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj) + assert 0 <= output_w_slice.start < output_shape[2] + assert 0 < output_w_slice.stop <= output_shape[2] + if channelwise: + inc = inputs[:, input_h_slice, input_w_slice, :] * \ + kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :] + else: + inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] * + kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :, :], axis=-2) + # equivalent to this + # outputs[:, output_h_slice, output_w_slice, :] += inc + paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop], + [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]] + outputs.append(tf.pad(inc, paddings)) + outputs = tf.add_n(outputs) + if use_bias: + with tf.variable_scope('local2d'): + bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, initializer=tf.zeros_initializer()) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def separable_local2d(inputs, 
filters, kernel_size, strides=(1, 1), padding='SAME', + vertical_kernel=None, horizontal_kernel=None, flip_filters=False, + use_bias=True, channelwise=False): + """ + 2-D locally connected operation with separable filters. + + Note that, unlike tf.nn.separable_conv2d, this is spatial separability between dimensions 1 and 2. + + Args: + inputs: A 4-D tensor of shape + `[batch, in_height, in_width, in_channels]`. + vertical_kernel: A 5-D or 6-D tensor of shape + `[in_height, in_width, kernel_size[0], in_channels, filters]` or + `[batch, in_height, in_width, kernel_size[0], in_channels, filters]`. + horizontal_kernel: A 5-D or 6-D tensor of shape + `[in_height, in_width, kernel_size[1], in_channels, filters]` or + `[batch, in_height, in_width, kernel_size[1], in_channels, filters]`. + + Returns: + A 4-D tensor. + """ + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + if strides != [1, 1]: + raise NotImplementedError + if padding == 'FULL': + inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') + padding = 'VALID' + input_shape = inputs.get_shape().as_list() + if padding == 'SAME': + output_shape = input_shape[:3] + [filters] + elif padding == 'VALID': + output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters] + else: + raise ValueError("Invalid padding scheme %s" % padding) + + kernels = [vertical_kernel, horizontal_kernel] + for i, (kernel_type, kernel_length, kernel) in enumerate(zip(['vertical', 'horizontal'], kernel_size, kernels)): + if channelwise: + if filters not in (input_shape[-1], 1): + raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise " + "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1])) + kernel_shape = output_shape[1:3] + [kernel_length, filters] + else: + kernel_shape = output_shape[1:3] + [kernel_length, input_shape[-1], filters] + if kernel is None: + with tf.variable_scope('separable_local2d'): + kernel = tf.get_variable('%s_kernel' % kernel_type, kernel_shape, dtype=tf.float32, + initializer=tf.truncated_normal_initializer(stddev=0.02)) + kernels[i] = kernel + else: + if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape): + raise ValueError("Expecting %s kernel with shape %s or %s but instead got kernel with shape %s" + % (kernel_type, + tuple(kernel_shape), tuple([input_shape[0]] +kernel_shape), + tuple(kernel.get_shape().as_list()))) + + outputs = [] + for i in range(kernel_size[0]): + filter_h_ind = -i-1 if flip_filters else i + if padding == 'VALID': + ii = i + else: + ii = i - (kernel_size[0] // 2) + input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1])) + output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii) + assert 0 <= output_h_slice.start < output_shape[1] + assert 0 < output_h_slice.stop <= output_shape[1] + + for j in range(kernel_size[1]): + filter_w_ind = -j-1 if flip_filters else j + if padding == 'VALID': + jj = j + else: + jj = j - (kernel_size[1] // 2) + input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2])) + output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj) + assert 0 <= output_w_slice.start < output_shape[2] + assert 0 < output_w_slice.stop <= output_shape[2] + if channelwise: + inc = inputs[:, input_h_slice, input_w_slice, :] * \ 
+ kernels[0][..., output_h_slice, output_w_slice, filter_h_ind, :] * \ + kernels[1][..., output_h_slice, output_w_slice, filter_w_ind, :] + else: + inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] * + kernels[0][..., output_h_slice, output_w_slice, filter_h_ind, :, :] * + kernels[1][..., output_h_slice, output_w_slice, filter_w_ind, :, :], + axis=-2) + # equivalent to this + # outputs[:, output_h_slice, output_w_slice, :] += inc + paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop], + [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]] + outputs.append(tf.pad(inc, paddings)) + outputs = tf.add_n(outputs) + if use_bias: + with tf.variable_scope('separable_local2d'): + bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, initializer=tf.zeros_initializer()) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def kronecker_local2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', + kernels=None, flip_filters=False, use_bias=True, channelwise=False): + """ + 2-D locally connected operation with filters represented as a kronecker product of smaller filters + + Args: + inputs: A 4-D tensor of shape + `[batch, in_height, in_width, in_channels]`. + kernel: A list of 6-D or 7-D tensors of shape + `[in_height, in_width, kernel_size[i][0], kernel_size[i][1], in_channels, filters]` or + `[batch, in_height, in_width, kernel_size[i][0], kernel_size[i][1], in_channels, filters]`. + + Returns: + A 4-D tensor. + """ + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + if strides != [1, 1]: + raise NotImplementedError + if padding == 'FULL': + inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') + padding = 'VALID' + input_shape = inputs.get_shape().as_list() + if padding == 'SAME': + output_shape = input_shape[:3] + [filters] + elif padding == 'VALID': + output_shape = [input_shape[0], input_shape[1] - kernel_size[0] + 1, input_shape[2] - kernel_size[1] + 1, filters] + else: + raise ValueError("Invalid padding scheme %s" % padding) + + if channelwise: + if filters not in (input_shape[-1], 1): + raise ValueError("Number of filters should match the number of input channels or be 1 when channelwise " + "is true, but got filters=%r and %d input channels" % (filters, input_shape[-1])) + kernel_shape = output_shape[1:3] + kernel_size + [filters] + factor_kernel_shape = output_shape[1:3] + [None, None, filters] + else: + kernel_shape = output_shape[1:3] + kernel_size + [input_shape[-1], filters] + factor_kernel_shape = output_shape[1:3] + [None, None, input_shape[-1], filters] + if kernels is None: + with tf.variable_scope('kronecker_local2d'): + kernels = [tf.get_variable('kernel', kernel_shape, dtype=tf.float32, + initializer=tf.truncated_normal_initializer(stddev=0.02))] + filter_h_lengths = [kernel_size[0]] + filter_w_lengths = [kernel_size[1]] + else: + for kernel in kernels: + if not ((len(kernel.shape) == len(factor_kernel_shape) and + all(((k == f) or f is None) for k, f in zip(kernel.get_shape().as_list(), factor_kernel_shape))) or + (len(kernel.shape) == (len(factor_kernel_shape) + 1) and + all(((k == f) or f is None) for k, f in zip(kernel.get_shape().as_list(), [input_shape[0]] +factor_kernel_shape)))): + raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s" + % (tuple(factor_kernel_shape), 
tuple([input_shape[0]] + factor_kernel_shape), + tuple(kernel.get_shape().as_list()))) + if channelwise: + filter_h_lengths, filter_w_lengths = zip(*[kernel.get_shape().as_list()[-3:-1] for kernel in kernels]) + else: + filter_h_lengths, filter_w_lengths = zip(*[kernel.get_shape().as_list()[-4:-2] for kernel in kernels]) + if [np.prod(filter_h_lengths), np.prod(filter_w_lengths)] != kernel_size: + raise ValueError("Expecting kernel size %s but instead got kernel size %s" + % (tuple(kernel_size), tuple([np.prod(filter_h_lengths), np.prod(filter_w_lengths)]))) + + def get_inds(ind, lengths): + inds = [] + for i in range(len(lengths)): + curr_ind = int(ind) + for j in range(len(lengths) - 1, i, -1): + curr_ind //= lengths[j] + curr_ind %= lengths[i] + inds.append(curr_ind) + return inds + + outputs = [] + for i in range(kernel_size[0]): + if padding == 'VALID': + ii = i + else: + ii = i - (kernel_size[0] // 2) + input_h_slice = slice(max(ii, 0), min(ii + output_shape[1], input_shape[1])) + output_h_slice = slice(input_h_slice.start - ii, input_h_slice.stop - ii) + assert 0 <= output_h_slice.start < output_shape[1] + assert 0 < output_h_slice.stop <= output_shape[1] + + for j in range(kernel_size[1]): + if padding == 'VALID': + jj = j + else: + jj = j - (kernel_size[1] // 2) + input_w_slice = slice(max(jj, 0), min(jj + output_shape[2], input_shape[2])) + output_w_slice = slice(input_w_slice.start - jj, input_w_slice.stop - jj) + assert 0 <= output_w_slice.start < output_shape[2] + assert 0 < output_w_slice.stop <= output_shape[2] + kernel_slice = 1.0 + for filter_h_ind, filter_w_ind, kernel in zip(get_inds(i, filter_h_lengths), get_inds(j, filter_w_lengths), kernels): + if flip_filters: + filter_h_ind = -filter_h_ind-1 + filter_w_ind = -filter_w_ind-1 + if channelwise: + kernel_slice *= kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :] + else: + kernel_slice *= kernel[..., output_h_slice, output_w_slice, filter_h_ind, filter_w_ind, :, :] + if channelwise: + inc = inputs[:, input_h_slice, input_w_slice, :] * kernel_slice + else: + inc = tf.reduce_sum(inputs[:, input_h_slice, input_w_slice, :, None] * kernel_slice, axis=-2) + # equivalent to this + # outputs[:, output_h_slice, output_w_slice, :] += inc + paddings = [[0, 0], [output_h_slice.start, output_shape[1] - output_h_slice.stop], + [output_w_slice.start, output_shape[2] - output_w_slice.stop], [0, 0]] + outputs.append(tf.pad(inc, paddings)) + outputs = tf.add_n(outputs) + if use_bias: + with tf.variable_scope('kronecker_local2d'): + bias = tf.get_variable('bias', output_shape[1:], dtype=tf.float32, initializer=tf.zeros_initializer()) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def depthwise_conv2d(inputs, channel_multiplier, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True): + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + input_shape = inputs.get_shape().as_list() + kernel_shape = kernel_size + [input_shape[-1], channel_multiplier] + if kernel is None: + with tf.variable_scope('depthwise_conv2d'): + kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, + initializer=tf.truncated_normal_initializer(stddev=0.02)) + else: + if kernel_shape != kernel.get_shape().as_list(): + raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" + % (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) + if 
padding == 'FULL': + inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') + padding = 'VALID' + outputs = tf.nn.depthwise_conv2d(inputs, kernel, [1] + strides + [1], padding=padding) + if use_bias: + with tf.variable_scope('depthwise_conv2d'): + bias = tf.get_variable('bias', [input_shape[-1] * channel_multiplier], dtype=tf.float32, initializer=tf.zeros_initializer()) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def conv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, use_spectral_norm=False): + """ + 2-D convolution. + + Args: + inputs: A 4-D tensor of shape + `[batch, in_height, in_width, in_channels]`. + kernel: A 4-D or 5-D tensor of shape + `[kernel_size[0], kernel_size[1], in_channels, filters]` or + `[batch, kernel_size[0], kernel_size[1], in_channels, filters]`. + bias: A 1-D or 2-D tensor of shape + `[filters]` or `[batch, filters]`. + + Returns: + A 4-D tensor. + """ + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + input_shape = inputs.get_shape().as_list() + kernel_shape = list(kernel_size) + [input_shape[-1], filters] + if kernel is None: + with tf.variable_scope('conv2d'): + kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, + initializer=tf.truncated_normal_initializer(stddev=0.02)) + if use_spectral_norm: + kernel = spectral_normed_weight(kernel) + else: + if kernel.get_shape().as_list() not in (kernel_shape, [input_shape[0]] + kernel_shape): + raise ValueError("Expecting kernel with shape %s or %s but instead got kernel with shape %s" + % (tuple(kernel_shape), tuple([input_shape[0]] + kernel_shape), tuple(kernel.get_shape().as_list()))) + if padding == 'FULL': + inputs = pad2d(inputs, kernel_size, strides=strides, padding=padding, mode='CONSTANT') + padding = 'VALID' + if kernel.get_shape().ndims == 4: + outputs = tf.nn.conv2d(inputs, kernel, [1] + strides + [1], padding=padding) + else: + def conv2d_single_fn(args): + input_, kernel_ = args + input_ = tf.expand_dims(input_, axis=0) + output = tf.nn.conv2d(input_, kernel_, [1] + strides + [1], padding=padding) + output = tf.squeeze(output, axis=0) + return output + outputs = tf.map_fn(conv2d_single_fn, [inputs, kernel], dtype=tf.float32) + if use_bias: + bias_shape = [filters] + if bias is None: + with tf.variable_scope('conv2d'): + bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) + else: + if bias.get_shape().as_list() not in (bias_shape, [input_shape[0]] + bias_shape): + raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" + % (tuple(bias_shape), tuple(bias.get_shape().as_list()))) + if bias.get_shape().ndims == 1: + outputs = tf.nn.bias_add(outputs, bias) + else: + outputs = tf.add(outputs, bias[:, None, None, :]) + return outputs + + +def deconv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True): + """ + 2-D transposed convolution. + + Notes on padding: + The equivalent of transposed convolution with full padding is a convolution with valid padding, and + the equivalent of transposed convolution with valid padding is a convolution with full padding. 
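+
+    For stride s, input size i and kernel size k, the spatial output size computed
+    below is s * i for 'SAME', s * (i - 1) + k for 'VALID' and s * (i + 1) - k for
+    'FULL' padding. For example, i = 8, k = 3, s = 2 gives outputs of size 16, 17
+    and 15, respectively.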
+ + Reference: + http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html + """ + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + input_shape = inputs.get_shape().as_list() + kernel_shape = list(kernel_size) + [filters, input_shape[-1]] + if kernel is None: + with tf.variable_scope('deconv2d'): + kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, + initializer=tf.truncated_normal_initializer(stddev=0.02)) + else: + if kernel_shape != kernel.get_shape().as_list(): + raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) + if padding == 'FULL': + output_h, output_w = [s * (i + 1) - k for (i, k, s) in zip(input_shape[1:3], kernel_size, strides)] + elif padding == 'SAME': + output_h, output_w = [s * i for (i, s) in zip(input_shape[1:3], strides)] + elif padding == 'VALID': + output_h, output_w = [s * (i - 1) + k for (i, k, s) in zip(input_shape[1:3], kernel_size, strides)] + else: + raise ValueError("Invalid padding scheme %s" % padding) + output_shape = [input_shape[0], output_h, output_w, filters] + outputs = tf.nn.conv2d_transpose(inputs, kernel, output_shape, [1] + strides + [1], padding=padding) + if use_bias: + with tf.variable_scope('deconv2d'): + bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def get_bilinear_kernel(strides): + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + strides = np.array(strides) + kernel_size = 2 * strides - strides % 2 + center = strides - (kernel_size % 2 == 1) - 0.5 * (kernel_size % 2 != 1) + vertical_kernel = 1 - abs(np.arange(kernel_size[0]) - center[0]) / strides[0] + horizontal_kernel = 1 - abs(np.arange(kernel_size[1]) - center[1]) / strides[1] + kernel = vertical_kernel[:, None] * horizontal_kernel[None, :] + return kernel + + +def upsample2d(inputs, strides, padding='SAME', upsample_mode='bilinear'): + if upsample_mode == 'bilinear': + single_bilinear_kernel = get_bilinear_kernel(strides).astype(np.float32) + input_shape = inputs.get_shape().as_list() + bilinear_kernel = tf.matrix_diag(tf.tile(tf.constant(single_bilinear_kernel)[..., None], (1, 1, input_shape[-1]))) + outputs = deconv2d(inputs, input_shape[-1], kernel_size=single_bilinear_kernel.shape, + strides=strides, kernel=bilinear_kernel, padding=padding, use_bias=False) + elif upsample_mode == 'nearest': + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + input_shape = inputs.get_shape().as_list() + inputs_tiled = tf.tile(inputs[:, :, None, :, None, :], [1, 1, strides[0], 1, strides[1], 1]) + outputs = tf.reshape(inputs_tiled, [input_shape[0], input_shape[1] * strides[0], + input_shape[2] * strides[1], input_shape[3]]) + else: + raise ValueError("Unknown upsample mode %s" % upsample_mode) + return outputs + + +def upsample2d_v2(inputs, strides, padding='SAME', upsample_mode='bilinear'): + """ + Possibly less computationally efficient but more memory efficent than upsampled2d. 
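+
+    Rather than materializing the per-channel diagonal [k, k, in_channels, in_channels]
+    kernel that upsample2d builds, each channel is moved into the leading dimension and
+    deconvolved separately with a single [k, k, 1, 1] kernel via tf.map_fn.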
+ """ + if upsample_mode == 'bilinear': + single_kernel = get_bilinear_kernel(strides).astype(np.float32) + elif upsample_mode == 'nearest': + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + single_kernel = np.ones(strides, dtype=np.float32) + else: + raise ValueError("Unknown upsample mode %s" % upsample_mode) + input_shape = inputs.get_shape().as_list() + kernel = tf.constant(single_kernel)[:, :, None, None] + inputs = tf.transpose(inputs, [3, 0, 1, 2])[..., None] + outputs = tf.map_fn(lambda input: deconv2d(input, 1, kernel_size=single_kernel.shape, + strides=strides, kernel=kernel, + padding=padding, use_bias=False), + inputs, parallel_iterations=input_shape[-1]) + outputs = tf.transpose(tf.squeeze(outputs, axis=-1), [1, 2, 3, 0]) + return outputs + + +def upsample_conv2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', + kernel=None, use_bias=True, bias=None, upsample_mode='bilinear'): + """ + Upsamples the inputs by a factor using bilinear interpolation and the performs conv2d on the upsampled input. This + function is more computationally and memory efficient than a naive implementation. Unlike a naive implementation + that would upsample the input first, this implementation first convolves the bilinear kernel with the given kernel, + and then performs the convolution (actually a deconv2d) with the combined kernel. As opposed to just using deconv2d + directly, this function is less prone to checkerboard artifacts thanks to the implicit bilinear upsampling. + + Example: + >>> import numpy as np + >>> import tensorflow as tf + >>> from video_prediction.ops import upsample_conv2d, upsample2d, conv2d, pad2d_paddings + >>> inputs_shape = [4, 8, 8, 64] + >>> kernel_size = [3, 3] # for convolution + >>> filters = 32 # for convolution + >>> strides = [2, 2] # for upsampling + >>> inputs = tf.get_variable("inputs", inputs_shape) + >>> kernel = tf.get_variable("kernel", (kernel_size[0], kernel_size[1], inputs_shape[-1], filters)) + >>> bias = tf.get_variable("bias", (filters,)) + >>> outputs = upsample_conv2d(inputs, filters, kernel_size=kernel_size, strides=strides, \ + kernel=kernel, bias=bias) + >>> # upsample with bilinear interpolation + >>> inputs_up = upsample2d(inputs, strides=strides, padding='VALID') + >>> # convolve upsampled input with kernel + >>> outputs_up = conv2d(inputs_up, filters, kernel_size=kernel_size, strides=(1, 1), \ + kernel=kernel, bias=bias, padding='FULL') + >>> # crop appropriately + >>> same_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='SAME') + >>> full_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='FULL') + >>> crop_top = (strides[0] - strides[0] % 2) // 2 + full_paddings[1][1] - same_paddings[1][1] + >>> crop_left = (strides[1] - strides[1] % 2) // 2 + full_paddings[2][1] - same_paddings[2][1] + >>> outputs_up = outputs_up[:, crop_top:crop_top + strides[0] * inputs_shape[1], \ + crop_left:crop_left + strides[1] * inputs_shape[2], :] + >>> sess = tf.Session() + >>> sess.run(tf.global_variables_initializer()) + >>> assert np.allclose(*sess.run([outputs, outputs_up]), atol=1e-5) + + """ + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + if padding != 'SAME' or upsample_mode != 'bilinear': + raise NotImplementedError + input_shape = inputs.get_shape().as_list() + kernel_shape = list(kernel_size) + [input_shape[-1], filters] + if kernel 
is None: + with tf.variable_scope('upsample_conv2d'): + kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, + initializer=tf.truncated_normal_initializer(stddev=0.02)) + else: + if kernel_shape != kernel.get_shape().as_list(): + raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % + (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) + + # convolve bilinear kernel with kernel + single_bilinear_kernel = get_bilinear_kernel(strides).astype(np.float32) + kernel_transposed = tf.transpose(kernel, (0, 1, 3, 2)) + kernel_reshaped = tf.reshape(kernel_transposed, kernel_size + [1, input_shape[-1] * filters]) + kernel_up_reshaped = conv2d(tf.constant(single_bilinear_kernel)[None, :, :, None], input_shape[-1] * filters, + kernel_size=kernel_size, kernel=kernel_reshaped, padding='FULL', use_bias=False) + kernel_up = tf.reshape(kernel_up_reshaped, + kernel_up_reshaped.get_shape().as_list()[1:3] + [filters, input_shape[-1]]) + + # deconvolve with the bilinearly convolved kernel + outputs = deconv2d(inputs, filters, kernel_size=kernel_up.get_shape().as_list()[:2], strides=strides, + kernel=kernel_up, padding='SAME', use_bias=False) + if use_bias: + if bias is None: + with tf.variable_scope('upsample_conv2d'): + bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) + else: + bias_shape = [filters] + if bias_shape != bias.get_shape().as_list(): + raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % + (tuple(bias_shape), tuple(bias.get_shape().as_list()))) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def upsample_conv2d_v2(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', + kernel=None, use_bias=True, bias=None, upsample_mode='bilinear'): + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + if padding != 'SAME': + raise NotImplementedError + input_shape = inputs.get_shape().as_list() + kernel_shape = list(kernel_size) + [input_shape[-1], filters] + if kernel is None: + with tf.variable_scope('upsample_conv2d'): + kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, + initializer=tf.truncated_normal_initializer(stddev=0.02)) + else: + if kernel_shape != kernel.get_shape().as_list(): + raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % + (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) + + inputs_up = upsample2d_v2(inputs, strides=strides, padding='VALID', upsample_mode=upsample_mode) + # convolve upsampled input with kernel + outputs = conv2d(inputs_up, filters, kernel_size=kernel_size, strides=(1, 1), + kernel=kernel, bias=None, padding='FULL', use_bias=False) + # crop appropriately + same_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='SAME') + full_paddings = pad2d_paddings(inputs, kernel_size, strides=(1, 1), padding='FULL') + crop_top = (strides[0] - strides[0] % 2) // 2 + full_paddings[1][1] - same_paddings[1][1] + crop_left = (strides[1] - strides[1] % 2) // 2 + full_paddings[2][1] - same_paddings[2][1] + outputs = outputs[:, crop_top:crop_top + strides[0] * input_shape[1], + crop_left:crop_left + strides[1] * input_shape[2], :] + + if use_bias: + if bias is None: + with tf.variable_scope('upsample_conv2d'): + bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) + else: + bias_shape = 
[filters] + if bias_shape != bias.get_shape().as_list(): + raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % + (tuple(bias_shape), tuple(bias.get_shape().as_list()))) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def conv3d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', use_bias=True, use_spectral_norm=False): + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 3 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 3 + input_shape = inputs.get_shape().as_list() + kernel_shape = list(kernel_size) + [input_shape[-1], filters] + with tf.variable_scope('conv3d'): + kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) + if use_spectral_norm: + kernel = spectral_normed_weight(kernel) + outputs = tf.nn.conv3d(inputs, kernel, [1] + strides + [1], padding=padding) + if use_bias: + bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def pool2d(inputs, pool_size, strides=(1, 1), padding='SAME', pool_mode='avg'): + pool_size = list(pool_size) if isinstance(pool_size, (tuple, list)) else [pool_size] * 2 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + if padding == 'FULL': + inputs = pad2d(inputs, pool_size, strides=strides, padding=padding, mode='CONSTANT') + padding = 'VALID' + if pool_mode == 'max': + outputs = tf.nn.max_pool(inputs, [1] + pool_size + [1], [1] + strides + [1], padding=padding) + elif pool_mode == 'avg': + outputs = tf.nn.avg_pool(inputs, [1] + pool_size + [1], [1] + strides + [1], padding=padding) + else: + raise ValueError('Invalid pooling mode:', pool_mode) + return outputs + + +def conv_pool2d(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, pool_mode='avg'): + """ + Similar optimization as in upsample_conv2d + + Example: + >>> import numpy as np + >>> import tensorflow as tf + >>> from video_prediction.ops import conv_pool2d, conv2d, pool2d + >>> inputs_shape = [4, 16, 16, 32] + >>> kernel_size = [3, 3] # for convolution + >>> filters = 64 # for convolution + >>> strides = [2, 2] # for pooling + >>> inputs = tf.get_variable("inputs", inputs_shape) + >>> kernel = tf.get_variable("kernel", (kernel_size[0], kernel_size[1], inputs_shape[-1], filters)) + >>> bias = tf.get_variable("bias", (filters,)) + >>> outputs = conv_pool2d(inputs, filters, kernel_size=kernel_size, strides=strides, + kernel=kernel, bias=bias, pool_mode='avg') + >>> inputs_conv = conv2d(inputs, filters, kernel_size=kernel_size, strides=(1, 1), + kernel=kernel, bias=bias) + >>> outputs_pool = pool2d(inputs_conv, pool_size=strides, strides=strides, pool_mode='avg') + >>> sess = tf.Session() + >>> sess.run(tf.global_variables_initializer()) + >>> assert np.allclose(*sess.run([outputs, outputs_pool]), atol=1e-5) + + """ + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + if padding != 'SAME' or pool_mode != 'avg': + raise NotImplementedError + input_shape = inputs.get_shape().as_list() + if input_shape[1] % strides[0] or input_shape[2] % strides[1]: + raise NotImplementedError("The height and width of the input should be " + "an integer multiple of the respective stride.") + kernel_shape 
= list(kernel_size) + [input_shape[-1], filters] + if kernel is None: + with tf.variable_scope('conv_pool2d'): + kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, + initializer=tf.truncated_normal_initializer(stddev=0.02)) + else: + if kernel_shape != kernel.get_shape().as_list(): + raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % + (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) + + # pool kernel + kernel_reshaped = tf.reshape(kernel, [1] + kernel_size + [input_shape[-1] * filters]) + kernel_pool_reshaped = pool2d(kernel_reshaped, pool_size=strides, padding='FULL', pool_mode='avg') + kernel_pool = tf.reshape(kernel_pool_reshaped, + kernel_pool_reshaped.get_shape().as_list()[1:3] + [input_shape[-1], filters]) + + outputs = conv2d(inputs, filters, kernel_size=kernel_pool.get_shape().as_list()[:2], strides=strides, + kernel=kernel_pool, padding='SAME', use_bias=False) + if use_bias: + if bias is None: + with tf.variable_scope('conv_pool2d'): + bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) + else: + bias_shape = [filters] + if bias_shape != bias.get_shape().as_list(): + raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % + (tuple(bias_shape), tuple(bias.get_shape().as_list()))) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def conv_pool2d_v2(inputs, filters, kernel_size, strides=(1, 1), padding='SAME', kernel=None, use_bias=True, bias=None, pool_mode='avg'): + kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + strides = list(strides) if isinstance(strides, (tuple, list)) else [strides] * 2 + if padding != 'SAME' or pool_mode != 'avg': + raise NotImplementedError + input_shape = inputs.get_shape().as_list() + if input_shape[1] % strides[0] or input_shape[2] % strides[1]: + raise NotImplementedError("The height and width of the input should be " + "an integer multiple of the respective stride.") + kernel_shape = list(kernel_size) + [input_shape[-1], filters] + if kernel is None: + with tf.variable_scope('conv_pool2d'): + kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, + initializer=tf.truncated_normal_initializer(stddev=0.02)) + else: + if kernel_shape != kernel.get_shape().as_list(): + raise ValueError("Expecting kernel with shape %s but instead got kernel with shape %s" % + (tuple(kernel_shape), tuple(kernel.get_shape().as_list()))) + + inputs_conv = conv2d(inputs, filters, kernel_size=kernel_size, strides=(1, 1), + kernel=kernel, bias=None, use_bias=False) + outputs = pool2d(inputs_conv, pool_size=strides, strides=strides, pool_mode='avg') + + if use_bias: + if bias is None: + with tf.variable_scope('conv_pool2d'): + bias = tf.get_variable('bias', [filters], dtype=tf.float32, initializer=tf.zeros_initializer()) + else: + bias_shape = [filters] + if bias_shape != bias.get_shape().as_list(): + raise ValueError("Expecting bias with shape %s but instead got bias with shape %s" % + (tuple(bias_shape), tuple(bias.get_shape().as_list()))) + outputs = tf.nn.bias_add(outputs, bias) + return outputs + + +def lrelu(x, alpha): + """ + Leaky ReLU activation function + + Reference: + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_ops.py + """ + with tf.name_scope("lrelu"): + return tf.maximum(alpha * x, x) + + +def batchnorm(input): + with tf.variable_scope("batchnorm"): + # this block looks like it has 3 inputs on the graph unless we do this 
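+        # (wrapping the input in tf.identity keeps the TensorBoard graph tidy: the scope
+        # then shows a single incoming edge rather than one edge per op that consumes the input)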
+ input = tf.identity(input) + + channels = input.get_shape()[-1] + offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer()) + scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.truncated_normal_initializer(1.0, 0.02)) + mean, variance = tf.nn.moments(input, axes=list(range(len(input.get_shape()) - 1)), keepdims=False) + variance_epsilon = 1e-5 + normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon) + return normalized + + +def instancenorm(input): + with tf.variable_scope("instancenorm"): + # this block looks like it has 3 inputs on the graph unless we do this + input = tf.identity(input) + + channels = input.get_shape()[-1] + offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer()) + scale = tf.get_variable("scale", [channels], dtype=tf.float32, + initializer=tf.truncated_normal_initializer(1.0, 0.02)) + mean, variance = tf.nn.moments(input, axes=list(range(1, len(input.get_shape()) - 1)), keepdims=True) + variance_epsilon = 1e-5 + normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, + variance_epsilon=variance_epsilon) + return normalized + + +def flatten(input, axis=1, end_axis=-1): + """ + Caffe-style flatten. + + Args: + inputs: An N-D tensor. + axis: The first axis to flatten: all preceding axes are retained in the output. + May be negative to index from the end (e.g., -1 for the last axis). + end_axis: The last axis to flatten: all following axes are retained in the output. + May be negative to index from the end (e.g., the default -1 for the last + axis) + + Returns: + A M-D tensor where M = N - (end_axis - axis) + """ + input_shape = tf.shape(input) + input_rank = tf.shape(input_shape)[0] + if axis < 0: + axis = input_rank + axis + if end_axis < 0: + end_axis = input_rank + end_axis + output_shape = [] + if axis != 0: + output_shape.append(input_shape[:axis]) + output_shape.append([tf.reduce_prod(input_shape[axis:end_axis + 1])]) + if end_axis + 1 != input_rank: + output_shape.append(input_shape[end_axis + 1:]) + output_shape = tf.concat(output_shape, axis=0) + output = tf.reshape(input, output_shape) + return output + + +def tile_concat(values, axis): + """ + Like concat except that first tiles the broadcastable dimensions if necessary + """ + shapes = [value.get_shape() for value in values] + # convert axis to positive form + ndims = shapes[0].ndims + for shape in shapes[1:]: + assert ndims == shape.ndims + if -ndims < axis < 0: + axis += ndims + # remove axis dimension + shapes = [shape.as_list() for shape in shapes] + dims = [shape.pop(axis) for shape in shapes] + shapes = [tf.TensorShape(shape) for shape in shapes] + # compute broadcasted shape + b_shape = shapes[0] + for shape in shapes[1:]: + b_shape = tf.broadcast_static_shape(b_shape, shape) + # add back axis dimension + b_shapes = [b_shape.as_list() for _ in dims] + for b_shape, dim in zip(b_shapes, dims): + b_shape.insert(axis, dim) + # tile values to match broadcasted shape, if necessary + b_values = [] + for value, b_shape in zip(values, b_shapes): + multiples = [] + for dim, b_dim in zip(value.get_shape().as_list(), b_shape): + if dim == b_dim: + multiples.append(1) + else: + assert dim == 1 + multiples.append(b_dim) + if any(multiple != 1 for multiple in multiples): + b_value = tf.tile(value, multiples) + else: + b_value = value + b_values.append(b_value) + return tf.concat(b_values, axis=axis) + + +def 
sigmoid_kl_with_logits(logits, targets): + # broadcasts the same target value across the whole batch + # this is implemented so awkwardly because tensorflow lacks an x log x op + assert isinstance(targets, float) + if targets in [0., 1.]: + entropy = 0. + else: + entropy = - targets * np.log(targets) - (1. - targets) * np.log(1. - targets) + return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits) * targets) - entropy + + +def spectral_normed_weight(W, u=None, num_iters=1): + SPECTRAL_NORMALIZATION_VARIABLES = 'spectral_normalization_variables' + + # Usually num_iters = 1 will be enough + W_shape = W.shape.as_list() + W_reshaped = tf.reshape(W, [-1, W_shape[-1]]) + if u is None: + u = tf.get_variable("u", [1, W_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False) + + def l2normalize(v, eps=1e-12): + return v / (tf.norm(v) + eps) + + def power_iteration(i, u_i, v_i): + v_ip1 = l2normalize(tf.matmul(u_i, tf.transpose(W_reshaped))) + u_ip1 = l2normalize(tf.matmul(v_ip1, W_reshaped)) + return i + 1, u_ip1, v_ip1 + _, u_final, v_final = tf.while_loop( + cond=lambda i, _1, _2: i < num_iters, + body=power_iteration, + loop_vars=(tf.constant(0, dtype=tf.int32), + u, tf.zeros(dtype=tf.float32, shape=[1, W_reshaped.shape.as_list()[0]])) + ) + sigma = tf.squeeze(tf.matmul(tf.matmul(v_final, W_reshaped), tf.transpose(u_final))) + W_bar_reshaped = W_reshaped / sigma + W_bar = tf.reshape(W_bar_reshaped, W_shape) + + if u not in tf.get_collection(SPECTRAL_NORMALIZATION_VARIABLES): + tf.add_to_collection(SPECTRAL_NORMALIZATION_VARIABLES, u) + tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u.assign(u_final)) + return W_bar + + +def get_activation_layer(layer_type): + if layer_type == 'relu': + layer = tf.nn.relu + elif layer_type == 'elu': + layer = tf.nn.elu + else: + raise ValueError('Invalid activation layer %s' % layer_type) + return layer + + +def get_norm_layer(layer_type): + if layer_type == 'batch': + layer = tf.layers.batch_normalization + elif layer_type == 'layer': + layer = tf.contrib.layers.layer_norm + elif layer_type == 'instance': + from video_prediction.layers import fused_instance_norm + layer = fused_instance_norm + elif layer_type == 'none': + layer = tf.identity + else: + raise ValueError('Invalid normalization layer %s' % layer_type) + return layer + + +def get_upsample_layer(layer_type): + if layer_type == 'deconv2d': + layer = deconv2d + elif layer_type == 'upsample_conv2d': + layer = upsample_conv2d + elif layer_type == 'upsample_conv2d_v2': + layer = upsample_conv2d_v2 + else: + raise ValueError('Invalid upsampling layer %s' % layer_type) + return layer + + +def get_downsample_layer(layer_type): + if layer_type == 'conv2d': + layer = conv2d + elif layer_type == 'conv_pool2d': + layer = conv_pool2d + elif layer_type == 'conv_pool2d_v2': + layer = conv_pool2d_v2 + else: + raise ValueError('Invalid downsampling layer %s' % layer_type) + return layer diff --git a/video_prediction/rnn_ops.py b/video_prediction/rnn_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..970fae3ec656dd3e677e906602f0d29d6650b2c7 --- /dev/null +++ b/video_prediction/rnn_ops.py @@ -0,0 +1,267 @@ +# Copyright 2016 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Convolutional LSTM implementation.""" + +import tensorflow as tf +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.ops import variable_scope as vs + + +class BasicConv2DLSTMCell(rnn_cell_impl.RNNCell): + """2D Convolutional LSTM cell with (optional) normalization and recurrent dropout. + + The implementation is based on: tf.contrib.rnn.LayerNormBasicLSTMCell. + + It does not allow cell clipping, a projection layer, and does not + use peep-hole connections: it is the basic baseline. + """ + def __init__(self, input_shape, filters, kernel_size, + forget_bias=1.0, activation_fn=math_ops.tanh, + normalizer_fn=None, separate_norms=True, + norm_gain=1.0, norm_shift=0.0, + dropout_keep_prob=1.0, dropout_prob_seed=None, + skip_connection=False, reuse=None): + """Initializes the basic convolutional LSTM cell. + + Args: + input_shape: int tuple, Shape of the input, excluding the batch size. + filters: int, The number of filters of the conv LSTM cell. + kernel_size: int tuple, The kernel size of the conv LSTM cell. + forget_bias: float, The bias added to forget gates (see above). + activation_fn: Activation function of the inner states. + normalizer_fn: If specified, this normalization will be applied before the + internal nonlinearities. + separate_norms: If set to `False`, the normalizer_fn is applied to the + concatenated tensor that follows the convolution, i.e. before splitting + the tensor. This case is slightly faster but it might be functionally + different, depending on the normalizer_fn (it's functionally the same + for instance norm but not for layer norm). Default: `True`. + norm_gain: float, The layer normalization gain initial value. If + `normalizer_fn` is `None`, this argument will be ignored. + norm_shift: float, The layer normalization shift initial value. If + `normalizer_fn` is `None`, this argument will be ignored. + dropout_keep_prob: unit Tensor or float between 0 and 1 representing the + recurrent dropout probability value. If float and 1.0, no dropout will + be applied. + dropout_prob_seed: (optional) integer, the randomness seed. + skip_connection: If set to `True`, concatenate the input to the + output of the conv LSTM. Default: `False`. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. 
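+
+        A minimal usage sketch (shapes and names are illustrative); the cell can be
+        driven time-major with tf.nn.dynamic_rnn, as done elsewhere in this package:
+
+            cell = BasicConv2DLSTMCell(input_shape=[16, 16, 32], filters=32, kernel_size=(3, 3))
+            # inputs: float32 tensor of shape [time, batch, 16, 16, 32]
+            outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32, time_major=True)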
+ """ + super(BasicConv2DLSTMCell, self).__init__(_reuse=reuse) + + self._input_shape = input_shape + self._filters = filters + self._kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + self._forget_bias = forget_bias + self._activation_fn = activation_fn + self._normalizer_fn = normalizer_fn + self._separate_norms = separate_norms + self._g = norm_gain + self._b = norm_shift + self._keep_prob = dropout_keep_prob + self._seed = dropout_prob_seed + self._skip_connection = skip_connection + self._reuse = reuse + + if self._skip_connection: + output_channels = self._filters + self._input_shape[-1] + else: + output_channels = self._filters + cell_size = tensor_shape.TensorShape(self._input_shape[:-1] + [self._filters]) + self._output_size = tensor_shape.TensorShape(self._input_shape[:-1] + [output_channels]) + self._state_size = rnn_cell_impl.LSTMStateTuple(cell_size, self._output_size) + + @property + def output_size(self): + return self._output_size + + @property + def state_size(self): + return self._state_size + + def _norm(self, inputs, scope): + shape = inputs.get_shape()[-1:] + gamma_init = init_ops.constant_initializer(self._g) + beta_init = init_ops.constant_initializer(self._b) + with vs.variable_scope(scope): + # Initialize beta and gamma for use by normalizer. + vs.get_variable("gamma", shape=shape, initializer=gamma_init) + vs.get_variable("beta", shape=shape, initializer=beta_init) + normalized = self._normalizer_fn(inputs, reuse=True, scope=scope) + return normalized + + def _conv2d(self, inputs): + output_filters = 4 * self._filters + input_shape = inputs.get_shape().as_list() + kernel_shape = list(self._kernel_size) + [input_shape[-1], output_filters] + kernel = vs.get_variable("kernel", kernel_shape, dtype=dtypes.float32, + initializer=init_ops.truncated_normal_initializer(stddev=0.02)) + outputs = nn_ops.conv2d(inputs, kernel, [1] * 4, padding='SAME') + if not self._normalizer_fn: + bias = vs.get_variable('bias', [output_filters], dtype=dtypes.float32, + initializer=init_ops.zeros_initializer()) + outputs = nn_ops.bias_add(outputs, bias) + return outputs + + def _dense(self, inputs): + num_units = 4 * self._filters + input_shape = inputs.shape.as_list() + kernel_shape = [input_shape[-1], num_units] + kernel = vs.get_variable("weights", kernel_shape, dtype=dtypes.float32, + initializer=init_ops.truncated_normal_initializer(stddev=0.02)) + outputs = tf.matmul(inputs, kernel) + return outputs + + def call(self, inputs, state): + """2D Convolutional LSTM cell with (optional) normalization and recurrent dropout.""" + c, h = state + tile_concat = isinstance(inputs, (list, tuple)) + if tile_concat: + inputs, inputs_non_spatial = inputs + args = array_ops.concat([inputs, h], -1) + concat = self._conv2d(args) + if tile_concat: + concat = concat + self._dense(inputs_non_spatial)[:, None, None, :] + + if self._normalizer_fn and not self._separate_norms: + concat = self._norm(concat, "input_transform_forget_output") + i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=-1) + if self._normalizer_fn and self._separate_norms: + i = self._norm(i, "input") + j = self._norm(j, "transform") + f = self._norm(f, "forget") + o = self._norm(o, "output") + + g = self._activation_fn(j) + if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: + g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) + + new_c = (c * math_ops.sigmoid(f + self._forget_bias) + + math_ops.sigmoid(i) * g) + if self._normalizer_fn: + new_c = 
self._norm(new_c, "state") + new_h = self._activation_fn(new_c) * math_ops.sigmoid(o) + + if self._skip_connection: + new_h = array_ops.concat([new_h, inputs], axis=-1) + + new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) + return new_h, new_state + + +class Conv2DGRUCell(tf.nn.rnn_cell.RNNCell): + """2D Convolutional GRU cell with (optional) normalization. + + Modified from these: + https://github.com/carlthome/tensorflow-convlstm-cell/blob/master/cell.py + https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn_cell_impl.py + """ + def __init__(self, input_shape, filters, kernel_size, + activation_fn=tf.tanh, + normalizer_fn=None, separate_norms=True, + bias_initializer=None, reuse=None): + super(Conv2DGRUCell, self).__init__(_reuse=reuse) + self._input_shape = input_shape + self._filters = filters + self._kernel_size = list(kernel_size) if isinstance(kernel_size, (tuple, list)) else [kernel_size] * 2 + self._activation_fn = activation_fn + self._normalizer_fn = normalizer_fn + self._separate_norms = separate_norms + self._bias_initializer = bias_initializer + self._size = tensor_shape.TensorShape(self._input_shape[:-1] + [self._filters]) + + @property + def state_size(self): + return self._size + + @property + def output_size(self): + return self._size + + def _norm(self, inputs, scope, bias_initializer): + shape = inputs.get_shape()[-1:] + gamma_init = init_ops.ones_initializer() + beta_init = bias_initializer + with vs.variable_scope(scope): + # Initialize beta and gamma for use by normalizer. + vs.get_variable("gamma", shape=shape, initializer=gamma_init) + vs.get_variable("beta", shape=shape, initializer=beta_init) + normalized = self._normalizer_fn(inputs, reuse=True, scope=scope) + return normalized + + def _conv2d(self, inputs, output_filters, bias_initializer): + input_shape = inputs.get_shape().as_list() + kernel_shape = list(self._kernel_size) + [input_shape[-1], output_filters] + kernel = vs.get_variable("kernel", kernel_shape, dtype=dtypes.float32, + initializer=init_ops.truncated_normal_initializer(stddev=0.02)) + outputs = nn_ops.conv2d(inputs, kernel, [1] * 4, padding='SAME') + if not self._normalizer_fn: + bias = vs.get_variable('bias', [output_filters], dtype=dtypes.float32, + initializer=bias_initializer) + outputs = nn_ops.bias_add(outputs, bias) + return outputs + + def _dense(self, inputs, num_units): + input_shape = inputs.shape.as_list() + kernel_shape = [input_shape[-1], num_units] + kernel = vs.get_variable("weights", kernel_shape, dtype=dtypes.float32, + initializer=init_ops.truncated_normal_initializer(stddev=0.02)) + outputs = tf.matmul(inputs, kernel) + return outputs + + def call(self, inputs, state): + bias_ones = self._bias_initializer + if self._bias_initializer is None: + bias_ones = init_ops.ones_initializer() + tile_concat = isinstance(inputs, (list, tuple)) + if tile_concat: + inputs, inputs_non_spatial = inputs + with vs.variable_scope('gates'): + inputs = array_ops.concat([inputs, state], axis=-1) + concat = self._conv2d(inputs, 2 * self._filters, bias_ones) + if tile_concat: + concat = concat + self._dense(inputs_non_spatial, concat.shape[-1].value)[:, None, None, :] + if self._normalizer_fn and not self._separate_norms: + concat = self._norm(concat, "reset_update", bias_ones) + r, u = array_ops.split(concat, 2, axis=-1) + if self._normalizer_fn and self._separate_norms: + r = self._norm(r, "reset", bias_ones) + u = self._norm(u, "update", bias_ones) + r, u = math_ops.sigmoid(r), math_ops.sigmoid(u) + + bias_zeros = 
self._bias_initializer + if self._bias_initializer is None: + bias_zeros = init_ops.zeros_initializer() + with vs.variable_scope('candidate'): + inputs = array_ops.concat([inputs, r * state], axis=-1) + candidate = self._conv2d(inputs, self._filters, bias_zeros) + if tile_concat: + candidate = candidate + self._dense(inputs_non_spatial, candidate.shape[-1].value)[:, None, None, :] + if self._normalizer_fn: + candidate = self._norm(candidate, "state", bias_zeros) + + c = self._activation_fn(candidate) + new_h = u * state + (1 - u) * c + return new_h, new_h diff --git a/video_prediction/utils/__init__.py b/video_prediction/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/video_prediction/utils/ffmpeg_gif.py b/video_prediction/utils/ffmpeg_gif.py new file mode 100644 index 0000000000000000000000000000000000000000..724933f7e1057e2e0b6174401021e6941bda5959 --- /dev/null +++ b/video_prediction/utils/ffmpeg_gif.py @@ -0,0 +1,95 @@ +import os + +import numpy as np + + +def save_gif(gif_fname, images, fps): + """ + To generate a gif from image files, first generate palette from images + and then generate the gif from the images and the palette. + ffmpeg -i input_%02d.jpg -vf palettegen -y palette.png + ffmpeg -i input_%02d.jpg -i palette.png -lavfi paletteuse -y output.gif + + Alternatively, use a filter to map the input images to both the palette + and gif commands, while also passing the palette to the gif command. + ffmpeg -i input_%02d.jpg -filter_complex "[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse" -y output.gif + + To directly pass in numpy images, use rawvideo format and `-i -` option. + """ + from subprocess import Popen, PIPE + head, tail = os.path.split(gif_fname) + if head and not os.path.exists(head): + os.makedirs(head) + h, w, c = images[0].shape + cmd = ['ffmpeg', '-y', + '-f', 'rawvideo', + '-vcodec', 'rawvideo', + '-r', '%.02f' % fps, + '-s', '%dx%d' % (w, h), + '-pix_fmt', {1: 'gray', 3: 'rgb24', 4: 'rgba'}[c], + '-i', '-', + '-filter_complex', '[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse', + '-r', '%.02f' % fps, + '%s' % gif_fname] + proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) + for image in images: + proc.stdin.write(image.tostring()) + out, err = proc.communicate() + if proc.returncode: + err = '\n'.join([' '.join(cmd), err.decode('utf8')]) + raise IOError(err) + del proc + + +def encode_gif(images, fps): + """Encodes numpy images into gif string. + Args: + images: A 5-D `uint8` `np.array` (or a list of 4-D images) of shape + `[batch_size, time, height, width, channels]` where `channels` is 1 or 3. + fps: frames per second of the animation + Returns: + The encoded gif string. + Raises: + IOError: If the ffmpeg command returns an error. 
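+
+    A minimal usage sketch, mirroring main() below (assumes ffmpeg is on the PATH):
+
+        frames = np.random.randint(256, size=(12, 64, 64, 3)).astype(np.uint8)
+        gif_bytes = encode_gif(frames, fps=4)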
+ """ + from subprocess import Popen, PIPE + h, w, c = images[0].shape + cmd = ['ffmpeg', '-y', + '-f', 'rawvideo', + '-vcodec', 'rawvideo', + '-r', '%.02f' % fps, + '-s', '%dx%d' % (w, h), + '-pix_fmt', {1: 'gray', 3: 'rgb24'}[c], + '-i', '-', + '-filter_complex', '[0:v]split[x][z];[z]palettegen[y];[x][y]paletteuse', + '-r', '%.02f' % fps, + '-f', 'gif', + '-'] + proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) + for image in images: + proc.stdin.write(image.tostring()) + out, err = proc.communicate() + if proc.returncode: + err = '\n'.join([' '.join(cmd), err.decode('utf8')]) + raise IOError(err) + del proc + return out + + +def main(): + images_shape = (12, 64, 64, 3) # num_frames, height, width, channels + images = np.random.randint(256, size=images_shape).astype(np.uint8) + + save_gif('output_save.gif', images, 4) + with open('output_save.gif', 'rb') as f: + string_save = f.read() + + string_encode = encode_gif(images, 4) + with open('output_encode.gif', 'wb') as f: + f.write(string_encode) + + print(np.all(string_save == string_encode)) + + +if __name__ == '__main__': + main() diff --git a/video_prediction/utils/gif_summary.py b/video_prediction/utils/gif_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..55f89987855c0288de827326107c9d72abc8ba6c --- /dev/null +++ b/video_prediction/utils/gif_summary.py @@ -0,0 +1,114 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow.python.ops import summary_op_util +#from tensorflow.python.distribute.summary_op_util import skip_summary TODO: IMPORT ERRORS IN juwels +from video_prediction.utils import ffmpeg_gif + + +def py_gif_summary(tag, images, max_outputs, fps): + """Outputs a `Summary` protocol buffer with gif animations. + Args: + tag: Name of the summary. + images: A 5-D `uint8` `np.array` of shape `[batch_size, time, height, width, + channels]` where `channels` is 1 or 3. + max_outputs: Max number of batch elements to generate gifs for. + fps: frames per second of the animation + Returns: + The serialized `Summary` protocol buffer. + Raises: + ValueError: If `images` is not a 5-D `uint8` array with 1 or 3 channels. 
+ """ + is_bytes = isinstance(tag, bytes) + if is_bytes: + tag = tag.decode("utf-8") + images = np.asarray(images) + if images.dtype != np.uint8: + raise ValueError("Tensor must have dtype uint8 for gif summary.") + if images.ndim != 5: + raise ValueError("Tensor must be 5-D for gif summary.") + batch_size, _, height, width, channels = images.shape + if channels not in (1, 3): + raise ValueError("Tensors must have 1 or 3 channels for gif summary.") + + summ = tf.Summary() + num_outputs = min(batch_size, max_outputs) + for i in range(num_outputs): + image_summ = tf.Summary.Image() + image_summ.height = height + image_summ.width = width + image_summ.colorspace = channels # 1: grayscale, 3: RGB + try: + image_summ.encoded_image_string = ffmpeg_gif.encode_gif(images[i], fps) + except (IOError, OSError) as e: + tf.logging.warning( + "Unable to encode images to a gif string because either ffmpeg is " + "not installed or ffmpeg returned an error: %s. Falling back to an " + "image summary of the first frame in the sequence.", e) + try: + from PIL import Image # pylint: disable=g-import-not-at-top + import io # pylint: disable=g-import-not-at-top + with io.BytesIO() as output: + Image.fromarray(images[i][0]).save(output, "PNG") + image_summ.encoded_image_string = output.getvalue() + except: + tf.logging.warning( + "Gif summaries requires ffmpeg or PIL to be installed: %s", e) + image_summ.encoded_image_string = "".encode('utf-8') if is_bytes else "" + if num_outputs == 1: + summ_tag = "{}/gif".format(tag) + else: + summ_tag = "{}/gif/{}".format(tag, i) + summ.value.add(tag=summ_tag, image=image_summ) + summ_str = summ.SerializeToString() + return summ_str + + +def gif_summary(name, tensor, max_outputs=3, fps=10, collections=None, + family=None): + """Outputs a `Summary` protocol buffer with gif animations. + Args: + name: Name of the summary. + tensor: A 5-D `uint8` `Tensor` of shape `[batch_size, time, height, width, + channels]` where `channels` is 1 or 3. + max_outputs: Max number of batch elements to generate gifs for. + fps: frames per second of the animation + collections: Optional list of tf.GraphKeys. The collections to add the + summary to. Defaults to [tf.GraphKeys.SUMMARIES] + family: Optional; if provided, used as the prefix of the summary tag name, + which controls the tab name used for display on Tensorboard. + Returns: + A scalar `Tensor` of type `string`. The serialized `Summary` protocol + buffer. 
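+
+    A minimal usage sketch (tensor and tag names are illustrative):
+
+        # gen_images: uint8 tensor of shape [batch, time, height, width, 3]
+        gif_summary('gen_images', gen_images, max_outputs=4, fps=10)
+        summary_op = tf.summary.merge_all()  # collected under tf.GraphKeys.SUMMARIES by default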
+ """ + tensor = tf.convert_to_tensor(tensor) + # if skip_summary(): TODO: skipo summary errors happend in JUEWLS + # return tf.constant("") + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + val = tf.py_func( + py_gif_summary, + [tag, tensor, max_outputs, fps], + tf.string, + stateful=False, + name=scope) + summary_op_util.collect(val, collections, [tf.GraphKeys.SUMMARIES]) + return val diff --git a/video_prediction/utils/html.py b/video_prediction/utils/html.py new file mode 100755 index 0000000000000000000000000000000000000000..4334f87b56446c238d5fd15a34ecebef889e003b --- /dev/null +++ b/video_prediction/utils/html.py @@ -0,0 +1,105 @@ +import os + +import dominate +from dominate.tags import * + + +class HTML: + def __init__(self, web_dir, title, reflesh=0): + self.title = title + self.web_dir = web_dir + self.img_dir = os.path.join(self.web_dir, 'images') + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + # print(self.img_dir) + + self.doc = dominate.document(title=title) + if reflesh > 0: + with self.doc.head: + meta(http_equiv="reflesh", content=str(reflesh)) + self.t = None + + def get_image_dir(self): + return self.img_dir + + def add_header1(self, str): + with self.doc: + h1(str) + + def add_header2(self, str): + with self.doc: + h2(str) + + def add_header3(self, str): + with self.doc: + h3(str) + + def add_table(self, border=1): + self.t = table(border=border, style="table-layout: fixed;") + self.doc.add(self.t) + + def add_row(self, txts, colspans=None): + if self.t is None: + self.add_table() + with self.t: + with tr(): + if colspans: + assert len(txts) == len(colspans) + colspans = [dict(colspan=str(colspan)) for colspan in colspans] + else: + colspans = [dict()] * len(txts) + for txt, colspan in zip(txts, colspans): + style = "word-break: break-all;" if len(str(txt)) > 80 else "word-wrap: break-word;" + with td(style=style, halign="center", valign="top", **colspan): + with p(): + if txt is not None: + p(txt) + + def add_images(self, ims, txts, links, colspans=None, height=None, width=400): + image_style = '' + if height is not None: + image_style += "height:%dpx;" % height + if width is not None: + image_style += "width:%dpx;" % width + if self.t is None: + self.add_table() + with self.t: + with tr(): + if colspans: + assert len(txts) == len(colspans) + colspans = [dict(colspan=str(colspan)) for colspan in colspans] + else: + colspans = [dict()] * len(txts) + for im, txt, link, colspan in zip(ims, txts, links, colspans): + with td(style="word-wrap: break-word;", halign="center", valign="top", **colspan): + with p(): + if im is not None and link is not None: + with a(href=os.path.join('images', link)): + img(style=image_style, src=os.path.join('images', im)) + if im is not None and link is not None and txt is not None: + br() + if txt is not None: + p(txt) + + def save(self): + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims = [] + txts = [] + links = [] + for n in range(4): + ims.append('image_%d.jpg' % n) + txts.append('text_%d' % n) + links.append('image_%d.jpg' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/video_prediction/utils/mcnet_utils.py b/video_prediction/utils/mcnet_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..7bad0131218f02e96d0e39132bc0d677547041da --- /dev/null +++ b/video_prediction/utils/mcnet_utils.py @@ -0,0 +1,156 @@ +""" +Some codes from https://github.com/Newmu/dcgan_code +""" + +import cv2 +import random +import imageio +import scipy.misc +import numpy as np + + +def transform(image): + return image/127.5 - 1. + + +def inverse_transform(images): + return (images+1.)/2. + + +def save_images(images, size, image_path): + return imsave(inverse_transform(images)*255., size, image_path) + + +def merge(images, size): + h, w = images.shape[1], images.shape[2] + img = np.zeros((h * size[0], w * size[1], 3)) + + for idx, image in enumerate(images): + i = idx % size[1] + j = idx / size[1] + img[j*h:j*h+h, i*w:i*w+w, :] = image + + return img + + +def imsave(images, size, path): + return scipy.misc.imsave(path, merge(images, size)) + + +def get_minibatches_idx(n, minibatch_size, shuffle=False): + """ + Used to shuffle the dataset at each iteration. + """ + idx_list = np.arange(n, dtype="int32") + + if shuffle: + random.shuffle(idx_list) + + minibatches = [] + minibatch_start = 0 + for i in range(n // minibatch_size): + minibatches.append(idx_list[minibatch_start:minibatch_start + minibatch_size]) + minibatch_start += minibatch_size + + if (minibatch_start != n): + # Make a minibatch out of what is left + minibatches.append(idx_list[minibatch_start:]) + + return zip(range(len(minibatches)), minibatches) + + +def draw_frame(img, is_input): + if img.shape[2] == 1: + img = np.repeat(img, [3], axis=2) + if is_input: + img[:2,:,0] = img[:2,:,2] = 0 + img[:,:2,0] = img[:,:2,2] = 0 + img[-2:,:,0] = img[-2:,:,2] = 0 + img[:,-2:,0] = img[:,-2:,2] = 0 + img[:2,:,1] = 255 + img[:,:2,1] = 255 + img[-2:,:,1] = 255 + img[:,-2:,1] = 255 + else: + img[:2,:,0] = img[:2,:,1] = 0 + img[:,:2,0] = img[:,:2,2] = 0 + img[-2:,:,0] = img[-2:,:,1] = 0 + img[:,-2:,0] = img[:,-2:,1] = 0 + img[:2,:,2] = 255 + img[:,:2,2] = 255 + img[-2:,:,2] = 255 + img[:,-2:,2] = 255 + + return img + + +def load_kth_data(f_name, data_path, image_size, K, T): + flip = np.random.binomial(1,.5,1)[0] + tokens = f_name.split() + vid_path = data_path + tokens[0] + "_uncomp.avi" + vid = imageio.get_reader(vid_path,"ffmpeg") + low = int(tokens[1]) + high = np.min([int(tokens[2]),vid.get_length()])-K-T+1 + if low == high: + stidx = 0 + else: + if low >= high: print(vid_path) + stidx = np.random.randint(low=low, high=high) + seq = np.zeros((image_size, image_size, K+T, 1), dtype="float32") + for t in xrange(K+T): + img = cv2.cvtColor(cv2.resize(vid.get_data(stidx+t), + (image_size,image_size)), + cv2.COLOR_RGB2GRAY) + seq[:,:,t] = transform(img[:,:,None]) + + if flip == 1: + seq = seq[:,::-1] + + diff = np.zeros((image_size, image_size, K-1, 1), dtype="float32") + for t in xrange(1,K): + prev = inverse_transform(seq[:,:,t-1]) + next = inverse_transform(seq[:,:,t]) + diff[:,:,t-1] = next.astype("float32")-prev.astype("float32") + + return seq, diff + + +def load_s1m_data(f_name, data_path, trainlist, K, T): + flip = np.random.binomial(1,.5,1)[0] + vid_path = data_path + f_name + img_size = [240,320] + + while True: + try: + vid = imageio.get_reader(vid_path,"ffmpeg") + low = 1 + high = vid.get_length()-K-T+1 + if low == high: + stidx = 0 + else: + stidx = np.random.randint(low=low, high=high) + seq = np.zeros((img_size[0], img_size[1], K+T, 3), + dtype="float32") + for t in xrange(K+T): + img = cv2.resize(vid.get_data(stidx+t), + (img_size[1],img_size[0]))[:,:,::-1] + seq[:,:,t] = transform(img) + + if flip == 
+                seq = seq[:,::-1]
+
+            diff = np.zeros((img_size[0], img_size[1], K-1, 1),
+                            dtype="float32")
+            for t in range(1,K):
+                prev = inverse_transform(seq[:,:,t-1])*255
+                prev = cv2.cvtColor(prev.astype("uint8"),cv2.COLOR_BGR2GRAY)
+                next = inverse_transform(seq[:,:,t])*255
+                next = cv2.cvtColor(next.astype("uint8"),cv2.COLOR_BGR2GRAY)
+                diff[:,:,t-1,0] = (next.astype("float32")-prev.astype("float32"))/255.
+            break
+        except Exception:
+            # In case the current video is bad, load a random one instead
+            rep_idx = np.random.randint(low=0, high=len(trainlist))
+            f_name = trainlist[rep_idx]
+            vid_path = data_path + f_name
+
+    return seq, diff
diff --git a/video_prediction/utils/tf_utils.py b/video_prediction/utils/tf_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e49a54b11fb07833f0ef33278d2e894b905afc
--- /dev/null
+++ b/video_prediction/utils/tf_utils.py
@@ -0,0 +1,609 @@
+import itertools
+import os
+from collections import OrderedDict
+
+import numpy as np
+import six
+import tensorflow as tf
+import tensorflow.contrib.graph_editor as ge
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.training import device_setter
+from tensorflow.python.util import nest
+
+from video_prediction.utils import ffmpeg_gif
+from video_prediction.utils import gif_summary
+
+IMAGE_SUMMARIES = "image_summaries"
+EVAL_SUMMARIES = "eval_summaries"
+
+
+def local_device_setter(num_devices=1,
+                        ps_device_type='cpu',
+                        worker_device='/cpu:0',
+                        ps_ops=None,
+                        ps_strategy=None):
+    if ps_ops is None:
+        ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']
+
+    if ps_strategy is None:
+        ps_strategy = device_setter._RoundRobinStrategy(num_devices)
+    if not six.callable(ps_strategy):
+        raise TypeError("ps_strategy must be callable")
+
+    def _local_device_chooser(op):
+        current_device = pydev.DeviceSpec.from_string(op.device or "")
+
+        node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
+        if node_def.op in ps_ops:
+            ps_device_spec = pydev.DeviceSpec.from_string(
+                '/{}:{}'.format(ps_device_type, ps_strategy(op)))
+
+            ps_device_spec.merge_from(current_device)
+            return ps_device_spec.to_string()
+        else:
+            worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
+            worker_device_spec.merge_from(current_device)
+            return worker_device_spec.to_string()
+
+    return _local_device_chooser
+
+
+def replace_read_ops(loss_or_losses, var_list):
+    """
+    Replaces the read ops of each variable in `var_list` with new read ops
+    obtained from `read_value()`, thus forcing the most up-to-date values of
+    the variables to be read (which might incur copies across devices).
+    The graph is seeded from the tensor(s) `loss_or_losses`.
+ """ + # ops between var ops and the loss + ops = set(ge.get_walks_intersection_ops([var.op for var in var_list], loss_or_losses)) + if not ops: # loss_or_losses doesn't depend on any var in var_list, so there is nothiing to replace + return + + # filter out variables that are not involved in computing the loss + var_list = [var for var in var_list if var.op in ops] + + for var in var_list: + output, = var.op.outputs + read_ops = set(output.consumers()) & ops + for read_op in read_ops: + with tf.name_scope('/'.join(read_op.name.split('/')[:-1])): + with tf.device(read_op.device): + read_t, = read_op.outputs + consumer_ops = set(read_t.consumers()) & ops + # consumer_sgv might have multiple inputs, but we only care + # about replacing the input that is read_t + consumer_sgv = ge.sgv(consumer_ops) + consumer_sgv = consumer_sgv.remap_inputs([list(consumer_sgv.inputs).index(read_t)]) + ge.connect(ge.sgv(var.read_value().op), consumer_sgv) + + +def print_loss_info(losses, *tensors): + def get_descendants(tensor, tensors): + descendants = [] + for child in tensor.op.inputs: + if child in tensors: + descendants.append(child) + else: + descendants.extend(get_descendants(child, tensors)) + return descendants + + name_to_tensors = itertools.chain(*[tensor.items() for tensor in tensors]) + tensor_to_names = OrderedDict([(v, k) for k, v in name_to_tensors]) + + print(tf.get_default_graph().get_name_scope()) + for name, (loss, weight) in losses.items(): + print(' %s (%r)' % (name, weight)) + descendant_names = [] + for descendant in set(get_descendants(loss, tensor_to_names.keys())): + descendant_names.append(tensor_to_names[descendant]) + for descendant_name in sorted(descendant_names): + print(' %s' % descendant_name) + + +def with_flat_batch(flat_batch_fn, ndims=4): + def fn(x, *args, **kwargs): + shape = tf.shape(x) + flat_batch_shape = tf.concat([[-1], shape[-(ndims-1):]], axis=0) + flat_batch_shape.set_shape([ndims]) + flat_batch_x = tf.reshape(x, flat_batch_shape) + flat_batch_r = flat_batch_fn(flat_batch_x, *args, **kwargs) + r = nest.map_structure(lambda x: tf.reshape(x, tf.concat([shape[:-(ndims-1)], tf.shape(x)[1:]], axis=0)), + flat_batch_r) + return r + return fn + + +def transpose_batch_time(x): + if isinstance(x, tf.Tensor) and x.shape.ndims >= 2: + return tf.transpose(x, [1, 0] + list(range(2, x.shape.ndims))) + else: + return x + + +def dimension(inputs, axis=0): + shapes = [input_.shape for input_ in nest.flatten(inputs)] + s = tf.TensorShape([None]) + for shape in shapes: + s = s.merge_with(shape[axis:axis + 1]) + dim = s[0].value + return dim + + +def unroll_rnn(cell, inputs, scope=None, use_dynamic_rnn=True): + """Chooses between dynamic_rnn and static_rnn if the leading time dimension is dynamic or not.""" + dim = dimension(inputs, axis=0) + if use_dynamic_rnn or dim is None: + return tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32, + swap_memory=False, time_major=True, scope=scope) + else: + return static_rnn(cell, inputs, scope=scope) + + +def static_rnn(cell, inputs, scope=None): + """Simple version of static_rnn.""" + with tf.variable_scope(scope or "rnn") as varscope: + batch_size = dimension(inputs, axis=1) + state = cell.zero_state(batch_size, tf.float32) + flat_inputs = nest.flatten(inputs) + flat_inputs = list(zip(*[tf.unstack(flat_input, axis=0) for flat_input in flat_inputs])) + flat_outputs = [] + for time, flat_input in enumerate(flat_inputs): + if time > 0: + varscope.reuse_variables() + input_ = nest.pack_sequence_as(inputs, flat_input) + output, state = 
cell(input_, state) + flat_output = nest.flatten(output) + flat_outputs.append(flat_output) + flat_outputs = [tf.stack(flat_output, axis=0) for flat_output in zip(*flat_outputs)] + outputs = nest.pack_sequence_as(output, flat_outputs) + return outputs, state + + +def maybe_pad_or_slice(tensor, desired_length): + length = tensor.shape.as_list()[0] + if length < desired_length: + paddings = [[0, desired_length - length]] + [[0, 0]] * (tensor.shape.ndims - 1) + tensor = tf.pad(tensor, paddings) + elif length > desired_length: + tensor = tensor[:desired_length] + assert tensor.shape.as_list()[0] == desired_length + return tensor + + +def tensor_to_clip(tensor): + if tensor.shape.ndims == 6: + # concatenate last dimension vertically + tensor = tf.concat(tf.unstack(tensor, axis=-1), axis=-3) + if tensor.shape.ndims == 5: + # concatenate batch dimension horizontally + tensor = tf.concat(tf.unstack(tensor, axis=0), axis=2) + if tensor.shape.ndims == 4: + # keep up to the first 3 channels + tensor = tf.image.convert_image_dtype(tensor, dtype=tf.uint8, saturate=True) + else: + raise NotImplementedError + return tensor + + +def tensor_to_image_batch(tensor): + if tensor.shape.ndims == 6: + # concatenate last dimension vertically + tensor= tf.concat(tf.unstack(tensor, axis=-1), axis=-3) + if tensor.shape.ndims == 5: + # concatenate time dimension horizontally + tensor = tf.concat(tf.unstack(tensor, axis=1), axis=2) + if tensor.shape.ndims == 4: + # keep up to the first 3 channels + tensor = tf.image.convert_image_dtype(tensor, dtype=tf.uint8, saturate=True) + else: + raise NotImplementedError + return tensor + + +def _as_name_scope_map(values): + name_scope_to_values = {} + for name, value in values.items(): + name_scope = name.split('/')[0] + name_scope_to_values.setdefault(name_scope, {}) + name_scope_to_values[name_scope][name] = value + return name_scope_to_values + + +def add_image_summaries(outputs, max_outputs=8, collections=None): + if collections is None: + collections = [tf.GraphKeys.SUMMARIES, IMAGE_SUMMARIES] + for name_scope, outputs in _as_name_scope_map(outputs).items(): + with tf.name_scope(name_scope): + for name, output in outputs.items(): + if max_outputs: + output = output[:max_outputs] + output = tensor_to_image_batch(output) + if output.shape[-1] not in (1, 3): + # these are feature maps, so just skip them + continue + tf.summary.image(name, output, collections=collections) + + +def add_gif_summaries(outputs, max_outputs=8, collections=None): + if collections is None: + collections = [tf.GraphKeys.SUMMARIES, IMAGE_SUMMARIES] + for name_scope, outputs in _as_name_scope_map(outputs).items(): + with tf.name_scope(name_scope): + for name, output in outputs.items(): + if max_outputs: + output = output[:max_outputs] + output = tensor_to_clip(output) + if output.shape[-1] not in (1, 3): + # these are feature maps, so just skip them + continue + gif_summary.gif_summary(name, output[None], fps=4, collections=collections) + + +def add_scalar_summaries(losses_or_metrics, collections=None): + for name_scope, losses_or_metrics in _as_name_scope_map(losses_or_metrics).items(): + with tf.name_scope(name_scope): + for name, loss_or_metric in losses_or_metrics.items(): + if isinstance(loss_or_metric, tuple): + loss_or_metric, _ = loss_or_metric + tf.summary.scalar(name, loss_or_metric, collections=collections) + + +def add_summaries(outputs, collections=None): + scalar_outputs = OrderedDict() + image_outputs = OrderedDict() + gif_outputs = OrderedDict() + for name, output in outputs.items(): + if 
not isinstance(output, tf.Tensor): + continue + if output.shape.ndims == 0: + scalar_outputs[name] = output + elif output.shape.ndims == 4: + image_outputs[name] = output + elif output.shape.ndims > 4 and output.shape[4].value in (1, 3): + gif_outputs[name] = output + add_scalar_summaries(scalar_outputs, collections=collections) + add_image_summaries(image_outputs, collections=collections) + add_gif_summaries(gif_outputs, collections=collections) + + +def plot_buf(y): + def _plot_buf(y): + from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas + from matplotlib.figure import Figure + import io + fig = Figure(figsize=(3, 3)) + canvas = FigureCanvas(fig) + ax = fig.add_subplot(111) + ax.plot(y) + ax.grid(axis='y') + fig.tight_layout(pad=0) + + buf = io.BytesIO() + fig.savefig(buf, format='png') + buf.seek(0) + return buf.getvalue() + + s = tf.py_func(_plot_buf, [y], tf.string) + return s + + +def add_plot_image_summaries(metrics, collections=None): + if collections is None: + collections = [IMAGE_SUMMARIES] + for name_scope, metrics in _as_name_scope_map(metrics).items(): + with tf.name_scope(name_scope): + for name, metric in metrics.items(): + try: + buf = plot_buf(metric) + except: + continue + image = tf.image.decode_png(buf, channels=4) + image = tf.expand_dims(image, axis=0) + tf.summary.image(name, image, max_outputs=1, collections=collections) + + +def plot_summary(name, x, y, display_name=None, description=None, collections=None): + """ + Hack that uses pr_curve summaries for 2D plots. + + Args: + x: 1-D tensor with values in increasing order. + y: 1-D tensor with static shape. + + Note: tensorboard needs to be modified and compiled from source to disable + default axis range [-0.05, 1.05]. + """ + from tensorboard import summary as summary_lib + x = tf.convert_to_tensor(x) + y = tf.convert_to_tensor(y) + with tf.control_dependencies([ + tf.assert_equal(tf.shape(x), tf.shape(y)), + tf.assert_equal(y.shape.ndims, 1), + ]): + y = tf.identity(y) + num_thresholds = y.shape[0].value + if num_thresholds is None: + raise ValueError('Size of y needs to be statically defined for num_thresholds argument') + summary = summary_lib.pr_curve_raw_data_op( + name, + true_positive_counts=tf.ones(num_thresholds), + false_positive_counts=tf.ones(num_thresholds), + true_negative_counts=tf.ones(num_thresholds), + false_negative_counts=tf.ones(num_thresholds), + precision=y[::-1], + recall=x[::-1], + num_thresholds=num_thresholds, + display_name=display_name, + description=description, + collections=collections) + return summary + + +def add_plot_summaries(metrics, x_offset=0, collections=None): + for name_scope, metrics in _as_name_scope_map(metrics).items(): + with tf.name_scope(name_scope): + for name, metric in metrics.items(): + plot_summary(name, x_offset + tf.range(tf.shape(metric)[0]), metric, collections=collections) + + +def add_plot_and_scalar_summaries(metrics, x_offset=0, collections=None): + for name_scope, metrics in _as_name_scope_map(metrics).items(): + with tf.name_scope(name_scope): + for name, metric in metrics.items(): + tf.summary.scalar(name, tf.reduce_mean(metric), collections=collections) + plot_summary(name, x_offset + tf.range(tf.shape(metric)[0]), metric, collections=collections) + + +def convert_tensor_to_gif_summary(summ): + if isinstance(summ, bytes): + summary_proto = tf.Summary() + summary_proto.ParseFromString(summ) + summ = summary_proto + + summary = tf.Summary() + for value in summ.value: + tag = value.tag + try: + images_arr = 
tf.make_ndarray(value.tensor)
+        except TypeError:
+            summary.value.add(tag=tag, image=value.image)
+            continue
+
+        if len(images_arr.shape) == 5:
+            images_arr = np.concatenate(list(images_arr), axis=-2)
+        if len(images_arr.shape) != 4:
+            raise ValueError('Tensors must be 4-D or 5-D for gif summary.')
+        channels = images_arr.shape[-1]
+        if channels < 1 or channels > 4:
+            raise ValueError('Tensors must have 1, 2, 3, or 4 color channels for gif summary.')
+
+        encoded_image_string = ffmpeg_gif.encode_gif(images_arr, fps=4)
+
+        image = tf.Summary.Image()
+        image.height = images_arr.shape[-3]
+        image.width = images_arr.shape[-2]
+        image.colorspace = channels  # 1: grayscale, 2: grayscale + alpha, 3: RGB, 4: RGBA
+        image.encoded_image_string = encoded_image_string
+        summary.value.add(tag=tag, image=image)
+    return summary
+
+
+def compute_averaged_gradients(opt, tower_loss, **kwargs):
+    tower_gradvars = []
+    for loss in tower_loss:
+        with tf.device(loss.device):
+            gradvars = opt.compute_gradients(loss, **kwargs)
+            tower_gradvars.append(gradvars)
+
+    # Now compute global loss and gradients.
+    gradvars = []
+    with tf.name_scope('gradient_averaging'):
+        all_grads = {}
+        for grad, var in itertools.chain(*tower_gradvars):
+            if grad is not None:
+                all_grads.setdefault(var, []).append(grad)
+        for var, grads in all_grads.items():
+            # Average gradients on the same device as the variables
+            # to which they apply.
+            with tf.device(var.device):
+                if len(grads) == 1:
+                    avg_grad = grads[0]
+                else:
+                    avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
+            gradvars.append((avg_grad, var))
+    return gradvars
+
+
+# the next 3 functions are from tensorpack:
+# https://github.com/tensorpack/tensorpack/blob/master/tensorpack/graph_builder/utils.py
+def split_grad_list(grad_list):
+    """
+    Args:
+        grad_list: K x N x 2
+
+    Returns:
+        K x N: gradients
+        K x N: variables
+    """
+    g = []
+    v = []
+    for tower in grad_list:
+        g.append([x[0] for x in tower])
+        v.append([x[1] for x in tower])
+    return g, v
+
+
+def merge_grad_list(all_grads, all_vars):
+    """
+    Args:
+        all_grads (K x N): gradients
+        all_vars (K x N): variables
+
+    Returns:
+        K x N x 2: list of list of (grad, var) pairs
+    """
+    return [list(zip(gs, vs)) for gs, vs in zip(all_grads, all_vars)]
+
+
+def allreduce_grads(all_grads, average):
+    """
+    All-reduce average the gradients among K devices. Results are broadcast to all devices.
+
+    Args:
+        all_grads (K x N): List of list of gradients. N is the number of variables.
+        average (bool): average gradients or not.
+
+    Returns:
+        K x N: same as input, but each grad is replaced by the average over K devices.
+ """ + from tensorflow.contrib import nccl + nr_tower = len(all_grads) + if nr_tower == 1: + return all_grads + new_all_grads = [] # N x K + for grads in zip(*all_grads): + summed = nccl.all_sum(grads) + + grads_for_devices = [] # K + for g in summed: + with tf.device(g.device): + # tensorflow/benchmarks didn't average gradients + if average: + g = tf.multiply(g, 1.0 / nr_tower) + grads_for_devices.append(g) + new_all_grads.append(grads_for_devices) + + # transpose to K x N + ret = list(zip(*new_all_grads)) + return ret + + +def _reduce_entries(*entries): + num_gpus = len(entries) + if entries[0] is None: + assert all(entry is None for entry in entries[1:]) + reduced_entry = None + elif isinstance(entries[0], tf.Tensor): + if entries[0].shape.ndims == 0: + reduced_entry = tf.add_n(entries) / tf.to_float(num_gpus) + else: + reduced_entry = tf.concat(entries, axis=0) + elif np.isscalar(entries[0]) or isinstance(entries[0], np.ndarray): + if np.isscalar(entries[0]) or entries[0].ndim == 0: + reduced_entry = sum(entries) / float(num_gpus) + else: + reduced_entry = np.concatenate(entries, axis=0) + elif isinstance(entries[0], tuple) and len(entries[0]) == 2: + losses, weights = zip(*entries) + loss = tf.add_n(losses) / tf.to_float(num_gpus) + if isinstance(weights[0], tf.Tensor): + with tf.control_dependencies([tf.assert_equal(weight, weights[0]) for weight in weights[1:]]): + weight = tf.identity(weights[0]) + else: + assert all(weight == weights[0] for weight in weights[1:]) + weight = weights[0] + reduced_entry = (loss, weight) + else: + raise NotImplementedError + return reduced_entry + + +def reduce_tensors(structures, shallow=False): + if len(structures) == 1: + reduced_structure = structures[0] + else: + if shallow: + if isinstance(structures[0], dict): + shallow_tree = type(structures[0])([(k, None) for k in structures[0]]) + else: + shallow_tree = type(structures[0])([None for _ in structures[0]]) + reduced_structure = nest.map_structure_up_to(shallow_tree, _reduce_entries, *structures) + else: + reduced_structure = nest.map_structure(_reduce_entries, *structures) + return reduced_structure + + +def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=False, restore_to_checkpoint_mapping=None): + + + if os.path.isdir(checkpoint): + # latest_checkpoint doesn't work when the path has special characters + checkpoint = tf.train.latest_checkpoint(checkpoint) + checkpoint_reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint) + checkpoint_var_names = checkpoint_reader.get_variable_to_shape_map().keys() + restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name, _: name.split(':')[0]) + if not var_list: + var_list = tf.global_variables() + restore_vars = {restore_to_checkpoint_mapping(var.name, checkpoint_var_names): var for var in var_list} + if skip_global_step and 'global_step' in restore_vars: + del restore_vars['global_step'] + # restore variables that are both in the global graph and in the checkpoint + restore_and_checkpoint_vars = {name: var for name, var in restore_vars.items() if name in checkpoint_var_names} + #restore_saver = tf.train.Saver(max_to_keep=1, var_list=restore_and_checkpoint_vars, filename=checkpoint) + # print out information regarding variables that were not restored or used for restoring + restore_not_in_checkpoint_vars = {name: var for name, var in restore_vars.items() if + name not in checkpoint_var_names} + checkpoint_not_in_restore_var_names = [name for name in checkpoint_var_names if name not in restore_vars] + 
if skip_global_step and 'global_step' in checkpoint_not_in_restore_var_names: + checkpoint_not_in_restore_var_names.remove('global_step') + if restore_not_in_checkpoint_vars: + print("global variables that were not restored because they are " + "not in the checkpoint:") + for name, _ in sorted(restore_not_in_checkpoint_vars.items()): + print(" ", name) + if checkpoint_not_in_restore_var_names: + print("checkpoint variables that were not used for restoring " + "because they are not in the graph:") + for name in sorted(checkpoint_not_in_restore_var_names): + print(" ", name) + + + restore_saver = tf.train.Saver(max_to_keep=1, var_list=restore_and_checkpoint_vars, filename=checkpoint) + + return restore_saver, checkpoint + + +def pixel_distribution(pos, height, width): + batch_size = pos.get_shape().as_list()[0] + y, x = tf.unstack(pos, 2, axis=1) + + x0 = tf.cast(tf.floor(x), 'int32') + x1 = x0 + 1 + y0 = tf.cast(tf.floor(y), 'int32') + y1 = y0 + 1 + + Ia = tf.reshape(tf.one_hot(y0 * width + x0, height * width), [batch_size, height, width]) + Ib = tf.reshape(tf.one_hot(y1 * width + x0, height * width), [batch_size, height, width]) + Ic = tf.reshape(tf.one_hot(y0 * width + x1, height * width), [batch_size, height, width]) + Id = tf.reshape(tf.one_hot(y1 * width + x1, height * width), [batch_size, height, width]) + + x0_f = tf.cast(x0, 'float32') + x1_f = tf.cast(x1, 'float32') + y0_f = tf.cast(y0, 'float32') + y1_f = tf.cast(y1, 'float32') + wa = ((x1_f - x) * (y1_f - y))[:, None, None] + wb = ((x1_f - x) * (y - y0_f))[:, None, None] + wc = ((x - x0_f) * (y1_f - y))[:, None, None] + wd = ((x - x0_f) * (y - y0_f))[:, None, None] + + return tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id]) + + +def flow_to_rgb(flows): + """The last axis should have dimension 2, for x and y values.""" + + def cartesian_to_polar(x, y): + magnitude = tf.sqrt(tf.square(x) + tf.square(y)) + angle = tf.atan2(y, x) + return magnitude, angle + + mag, ang = cartesian_to_polar(*tf.unstack(flows, axis=-1)) + ang_normalized = (ang + np.pi) / (2 * np.pi) + mag_min = tf.reduce_min(mag) + mag_max = tf.reduce_max(mag) + mag_normalized = (mag - mag_min) / (mag_max - mag_min) + hsv = tf.stack([ang_normalized, tf.ones_like(ang), mag_normalized], axis=-1) + rgb = tf.image.hsv_to_rgb(hsv) + return rgb
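
A minimal usage sketch for two of the helpers in video_prediction/utils/tf_utils.py, assuming TensorFlow 1.x graph mode and that the repository root is on PYTHONPATH; the shapes and values below are illustrative only, not part of the patch above:

import tensorflow as tf
from video_prediction.utils.tf_utils import flow_to_rgb, pixel_distribution

# a random batch of 2-channel (x, y) flow fields, visualised as RGB via the HSV encoding
flows = tf.random_uniform([4, 64, 64, 2], minval=-3.0, maxval=3.0)
rgb = flow_to_rgb(flows)  # -> [4, 64, 64, 3], values in [0, 1]

# two sub-pixel (y, x) positions spread over a 32x32 grid with bilinear weights
pos = tf.constant([[10.3, 20.7], [5.0, 5.5]])
dist = pixel_distribution(pos, height=32, width=32)  # -> [2, 32, 32]
mass = tf.reduce_sum(dist, axis=[1, 2])              # the four bilinear weights sum to 1 per position

with tf.Session() as sess:
    rgb_val, mass_val = sess.run([rgb, mass])
    print(rgb_val.shape, mass_val)  # (4, 64, 64, 3) [1. 1.]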