diff --git a/.gitignore b/.gitignore
index 027f7926e575fe674bbb2f703a436999d63d7d26..a7dfc3d7d776b25730741f12db6f98cdfa127751 100644
--- a/.gitignore
+++ b/.gitignore
@@ -88,6 +88,7 @@ celerybeat-schedule
 venv/
 ENV/
 virtual_env*/
+virt_env*/
 
 # Spyder project settings
 .spyderproject
@@ -108,8 +109,8 @@ virtual_env*/
 *.DS_Store
 
 # Ignore log- and errorfiles
-*-err.???????
-*-out.???????
+*-err.[0-9]*
+*-out.[0-9]*
 
 #Ignore the results files
 
diff --git a/video_prediction_savp/HPC_scripts/DataExtraction.sh b/video_prediction_savp/HPC_scripts/DataExtraction.sh
index 0c1491a3f044919f01771ca6ff94ca09a18d3a56..b44065e7babb0411cda6d2849ec429f3672c60d5 100755
--- a/video_prediction_savp/HPC_scripts/DataExtraction.sh
+++ b/video_prediction_savp/HPC_scripts/DataExtraction.sh
@@ -1,4 +1,5 @@
 #!/bin/bash -x
+## Controlling Batch-job
 #SBATCH --account=deepacf
 #SBATCH --nodes=1
 #SBATCH --ntasks=13
@@ -6,18 +7,40 @@
 #SBATCH --cpus-per-task=1
 #SBATCH --output=DataExtraction-out.%j
 #SBATCH --error=DataExtraction-err.%j
-#SBATCH --time=00:20:00
+#SBATCH --time=05:00:00
 #SBATCH --partition=devel
 #SBATCH --mail-type=ALL
-#SBATCH --mail-user=s.stadtler@fz-juelich.de
-##jutil env activate -p deepacf
-
-module purge
-module use $OTHERSTAGES
-module load Stages/2019a
-module addad Intel/2019.3.199-GCC-8.3.0 ParaStationMPI/5.2.2-1
-module load h5py/2.9.0-Python-3.6.8
-module load mpi4py/3.0.1-Python-3.6.8
-module load netcdf4-python/1.5.0.1-Python-3.6.8
-
-srun python ../../workflow_parallel_frame_prediction/DataExtraction/mpi_stager_v2.py --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/2017/ --destination_dir /p/scratch/deepacf/scarlet/extractedData
+#SBATCH --mail-user=b.gong@fz-juelich.de
+
+
+jutil env activate -p deepacf
+
+# Name of virtual environment
+VIRT_ENV_NAME="virt_env_hdfml"
+
+# Loading modules
+source ../env_setup/modules_preprocess.sh
+# Activate virtual environment if needed (and possible)
+if [ -z ${VIRTUAL_ENV} ]; then
+  if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+    echo "Activating virtual environment..."
+    source ../${VIRT_ENV_NAME}/bin/activate
+  else
+    echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
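+    # abort the job, since the data-extraction script below depends on packages from the virtual environment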
+    exit 1
+  fi
+fi
+
+# Declare path-variables
+source_dir="/p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/"
+dest_dir="/p/scratch/deepacf/video_prediction_shared_folder/extractedData/"
+
+year="2010"
+
+# Run data extraction
+srun python ../../workflow_parallel_frame_prediction/DataExtraction/mpi_stager_v2.py --source_dir ${source_dir}/${year}/ --destination_dir ${dest_dir}/${year}/
+
+
+
+# 2tier pystager
+#srun python ../../workflow_parallel_frame_prediction/DataExtraction/main_single_master.py --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/${year}/ --destination_dir ${SAVE_DIR}/extractedData/${year}
diff --git a/video_prediction_savp/HPC_scripts/DataPreprocess.sh b/video_prediction_savp/HPC_scripts/DataPreprocess.sh
index 48ed0581802fe5f629019e729425a7ca1445af4f..aa84de9de7dce7015b26f040aaec48d0b096a816 100755
--- a/video_prediction_savp/HPC_scripts/DataPreprocess.sh
+++ b/video_prediction_savp/HPC_scripts/DataPreprocess.sh
@@ -1,4 +1,5 @@
 #!/bin/bash -x
+## Controlling Batch-job
 #SBATCH --account=deepacf
 #SBATCH --nodes=1
 #SBATCH --ntasks=12
@@ -6,38 +7,58 @@
 #SBATCH --cpus-per-task=1
 #SBATCH --output=DataPreprocess-out.%j
 #SBATCH --error=DataPreprocess-err.%j
-#SBATCH --time=02:20:00
-#SBATCH --partition=batch
+#SBATCH --time=00:20:00
+#SBATCH --partition=devel
 #SBATCH --mail-type=ALL
 #SBATCH --mail-user=b.gong@fz-juelich.de
 
-module --force purge
-module use $OTHERSTAGES
-module load Stages/2019a
-module load Intel/2019.3.199-GCC-8.3.0 ParaStationMPI/5.2.2-1
-module load h5py/2.9.0-Python-3.6.8
-module load mpi4py/3.0.1-Python-3.6.8
+# Name of virtual environment
+VIRT_ENV_NAME="virt_env_hdfml"
 
-srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-    --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2015/ \
-    --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2015/ \
-    --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+# Activate virtual environment if needed (and possible)
+if [ -z ${VIRTUAL_ENV} ]; then
+  if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+    echo "Activating virtual environment..."
+    source ../${VIRT_ENV_NAME}/bin/activate
+  else
+    echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
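+    # hint: the virtual environment can be set up beforehand by sourcing env_setup/create_env.sh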
+    exit 1
+  fi
+fi
+# Loading modules
+source ../env_setup/modules_preprocess.sh
 
-srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-    --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2016/ \
-    --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2016/ \
-    --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+source_dir=${SAVE_DIR}/extractedData
+destination_dir=${SAVE_DIR}/preprocessedData/era5-Y2015to2017M01to12
+script_dir=`pwd`
 
-srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-    --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017/ \
-    --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2017/ \
-    --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+declare -a years=("2222"
+                  "2010_1"
+                  "2012"
+                  "2013_complete"
+                  "2015"
+                  "2016"
+                  "2017"
+                  "2019"
+                 )
+declare -a years=(
+                  "2015"
+                  "2016"
+                  "2017"
+                 )
 
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-#    --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017 \
-#    --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016toY2017M01to12-128x160-74d00N71d0E-T_MSL_gph500/2017 \
-#    --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+# execute Python-scripts
+for year in "${years[@]}"; do
+    echo "Year $year"
+    echo "source_dir ${source_dir}/${year}"
+    srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
+        --source_dir ${source_dir} -scr_dir ${script_dir} \
+        --destination_dir ${destination_dir} --years ${year} --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
+  done
+
+
+#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir} --varnames T2 MSL gph500
diff --git a/video_prediction_savp/HPC_scripts/DataPreprocess_dev.sh b/video_prediction_savp/HPC_scripts/DataPreprocess_dev.sh
deleted file mode 100755
index b5aa2010cbe2b5b9f87b5b65bade29db974bcc8d..0000000000000000000000000000000000000000
--- a/video_prediction_savp/HPC_scripts/DataPreprocess_dev.sh
+++ /dev/null
@@ -1,68 +0,0 @@
-#!/bin/bash -x
-#SBATCH --account=deepacf
-#SBATCH --nodes=1
-#SBATCH --ntasks=12
-##SBATCH --ntasks-per-node=12
-#SBATCH --cpus-per-task=1
-#SBATCH --output=DataPreprocess-out.%j
-#SBATCH --error=DataPreprocess-err.%j
-#SBATCH --time=00:20:00
-#SBATCH --partition=devel
-#SBATCH --mail-type=ALL
-#SBATCH --mail-user=m.langguth@fz-juelich.de
-
-module --force purge
-module use $OTHERSTAGES
-module load Stages/2019a
-module load Intel/2019.3.199-GCC-8.3.0 ParaStationMPI/5.2.2-1
-module load h5py/2.9.0-Python-3.6.8
-module load mpi4py/3.0.1-Python-3.6.8
-
-source_dir=/p/scratch/deepacf/video_prediction_shared_folder/extractedData
-destination_dir=/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle
-declare -a years=("2015"
-                  "2016"
-                  "2017"
-                 )
-
-
-
-for year in "${years[@]}";
-    do
-        echo "Year $year"
-        echo "source_dir ${source_dir}/${year}"
-        srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-        --source_dir ${source_dir}/${year}/ \
-        --destination_dir ${destination_dir}/${year}/ --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
-    done
-
-
-srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir}
-
-
-
-
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-#    --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2015/ \
-#    --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2015/ \
-#    --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
-
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-#    --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2016/ \
-#    --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2016/ \
-#    --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
-
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-#    --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017/ \
-#    --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/2017/ \
-#    --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
-
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py \
-#--destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-#T_MSL_gph500/
-
-
-
-#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
-#    --source_dir /p/scratch/deepacf/video_prediction_shared_folder/extractedData/2017 \
-#    --destination_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016toY2017M01to12-128x160-74d00N71d0E-T_MSL_gph500/2017 \
-#    --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710
diff --git a/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf.sh b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf.sh
index 6f541b9d31f582dfd8b9318f7980930716c6c09b..bcf950e93145bcc8b0d15892a606d4cc5d7dd66e 100755
--- a/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf.sh
+++ b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf.sh
@@ -1,8 +1,8 @@
 #!/bin/bash -x
 #SBATCH --account=deepacf
 #SBATCH --nodes=1
-#SBATCH --ntasks=12
-##SBATCH --ntasks-per-node=12
+#SBATCH --ntasks=13
+##SBATCH --ntasks-per-node=13
 #SBATCH --cpus-per-task=1
 #SBATCH --output=DataPreprocess_to_tf-out.%j
 #SBATCH --error=DataPreprocess_to_tf-err.%j
@@ -11,12 +11,26 @@
 #SBATCH --mail-type=ALL
 #SBATCH --mail-user=b.gong@fz-juelich.de
 
-module purge
-module use $OTHERSTAGES
-module load Stages/2019a
-module load Intel/2019.3.199-GCC-8.3.0 ParaStationMPI/5.2.2-1
-module load h5py/2.9.0-Python-3.6.8
-module load mpi4py/3.0.1-Python-3.6.8
-module load TensorFlow/1.13.1-GPU-Python-3.6.8
-srun python ../video_prediction/datasets/era5_dataset_v2.py /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016M01to12-128_160-74.00N710E-T_T_T/splits/ /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/Y2016M01to12-128_160-74.00N710E-T_T_T/tfrecords/ -vars T2 T2 T2
+# Name of virtual environment
+VIRT_ENV_NAME="vp"
+
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z ${VIRTUAL_ENV} ]; then
+  if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+    echo "Activating virtual environment..."
+    source ../${VIRT_ENV_NAME}/bin/activate
+  else
+    echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+    exit 1
+  fi
+fi
+
+# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/
+destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/
+
+# run Preprocessing (step 2 where TF-records are generated)
+srun python ../video_prediction/datasets/era5_dataset_v2.py ${source_dir}/pickle ${destination_dir}/tfrecords -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20
diff --git a/video_prediction_savp/HPC_scripts/generate_era5.sh b/video_prediction_savp/HPC_scripts/generate_era5.sh
index c6121ebfc4e441d587cd97fabb366c106324a8be..bb36609129d2c45c5c0cbfaaf0675a7e338eb09c 100755
--- a/video_prediction_savp/HPC_scripts/generate_era5.sh
+++ b/video_prediction_savp/HPC_scripts/generate_era5.sh
@@ -13,19 +13,33 @@
 #SBATCH --mail-user=b.gong@fz-juelich.de
 ##jutil env activate -p cjjsc42
 
+# Name of virtual environment
+VIRT_ENV_NAME="vp"
 
-module purge
-module load GCC/8.3.0
-module load ParaStationMPI/5.2.2-1
-module load TensorFlow/1.13.1-GPU-Python-3.6.8
-module load netcdf4-python/1.5.0.1-Python-3.6.8
-module load h5py/2.9.0-Python-3.6.8
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z ${VIRTUAL_ENV} ]; then
+  if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+    echo "Activating virtual environment..."
+    source ../${VIRT_ENV_NAME}/bin/activate
+  else
+    echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
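+    # without the activated environment, the postprocessing script below cannot import its Python packages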
+    exit 1
+  fi
+fi
+
+# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
+source_dir=/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/
+checkpoint_dir=/p/scratch/deepacf/video_prediction_shared_folder/models/
+results_dir=/p/scratch/deepacf/video_prediction_shared_folder/results/
 
-python -u ../scripts/generate_transfer_learning_finetune.py \
---input_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/tfrecords/ \
---dataset_hparams sequence_length=20 --checkpoint /p/scratch/deepacf/video_prediction_shared_folder/models/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/ours_gan \
---mode test --results_dir /p/scratch/deepacf/video_prediction_shared_folder/results/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T \
---batch_size 4 --dataset era5 > generate_era5-out.out
+# name of model
+model=convLSTM
+
+# run postprocessing/generation of model results including evaluation metrics
+srun python -u ../scripts/generate_transfer_learning_finetune.py \
+--input_dir ${source_dir}/tfrecords --dataset_hparams sequence_length=20 --checkpoint ${checkpoint_dir}/${model} \
+--mode test --model ${model} --results_dir ${results_dir}/${model}/ --batch_size 2 --dataset era5 > generate_era5-out.out
 
 #srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/video_prediction_savp/HPC_scripts/hyperparam_setup.sh b/video_prediction_savp/HPC_scripts/hyperparam_setup.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a6c24a062ca30b06879641806d771beacc4b34f8
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/hyperparam_setup.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+
+# for choosing the model
+export model=convLSTM
+export model_hparams=../hparams/era5/${model}/model_hparams.json
+
+# create a subfolder named by creation time and user name, which serves as the hyperparameter-tuning folder. This avoids overwriting a previously trained model that used different hyperparameters
+export hyperdir="$(date +"%Y%m%dT%H%M")_"$USER""
+
+echo "model: ${model}"
+echo "hparams: ${model_hparams}"
+echo "experiment dir: ${hyperdir}"
diff --git a/video_prediction_savp/HPC_scripts/reset_dirs.sh b/video_prediction_savp/HPC_scripts/reset_dirs.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8de5247e044150d1c01eccfa512b9ae1c0e4cdfa
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/reset_dirs.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+sed -i "s|source_dir=.*|source_dir=${SAVE_DIR}preprocessedData/|g" DataPreprocess_to_tf.sh
+sed -i "s|destination_dir=.*|destination_dir=${SAVE_DIR}preprocessedData/|g" DataPreprocess_to_tf.sh
+
+sed -i "s|source_dir=.*|source_dir=${SAVE_DIR}preprocessedData/|g" train_era5.sh
+sed -i "s|destination_dir=.*|destination_dir=${SAVE_DIR}models/|g" train_era5.sh
+
+sed -i "s|source_dir=.*|source_dir=${SAVE_DIR}preprocessedData/|g" generate_era5.sh
+sed -i "s|checkpoint_dir=.*|checkpoint_dir=${SAVE_DIR}models/|g" generate_era5.sh
+sed -i "s|results_dir=.*|results_dir=${SAVE_DIR}results/|g" generate_era5.sh
diff --git a/video_prediction_savp/HPC_scripts/train_era5.sh b/video_prediction_savp/HPC_scripts/train_era5.sh
index ef060b0d985aa0141a1a1cdb974bf04f37b7204b..f605866056f6b2d9fa179a00850468fee0c72d87 100755
--- a/video_prediction_savp/HPC_scripts/train_era5.sh
+++ b/video_prediction_savp/HPC_scripts/train_era5.sh
@@ -7,21 +7,41 @@
 #SBATCH --output=train_era5-out.%j
 #SBATCH --error=train_era5-err.%j
 #SBATCH --time=00:20:00
-#SBATCH --gres=gpu:1
+#SBATCH --gres=gpu:2
 #SBATCH --partition=develgpus
 #SBATCH --mail-type=ALL
 #SBATCH --mail-user=b.gong@fz-juelich.de
 ##jutil env activate -p cjjsc42
 
-module --force purge
-module use $OTHERSTAGES
-module load Stages/2019a
-module load GCCcore/.8.3.0
-module load mpi4py/3.0.1-Python-3.6.8
-module load h5py/2.9.0-serial-Python-3.6.8
-module load TensorFlow/1.13.1-GPU-Python-3.6.8
-module load cuDNN/7.5.1.10-CUDA-10.1.105
+# Name of virtual environment
+VIRT_ENV_NAME="vp"
 
-srun python ../scripts/train_v2.py --input_dir /p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/2017M01to12-64_64-50.00N11.50E-T_T_T/tfrecords --dataset era5 --model savp --model_hparams_dict ../hparams/kth/ours_savp/model_hparams.json --output_dir /p/scratch/deepacf/video_prediction_shared_folder/models/2017M01to12-64_64-50.00N11.50E-T_T_T/ours_savp
-#srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z ${VIRTUAL_ENV} ]; then
+  if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+    echo "Activating virtual environment..."
+    source ../${VIRT_ENV_NAME}/bin/activate
+  else
+    echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
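+    # note: training requires the same environment as the TFRecord-generation step (see DataPreprocess_to_tf.sh)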
+    exit 1
+  fi
+fi
+
+
+
+
+# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/
+destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/
+
+# for choosing the model
+model=convLSTM
+model_hparams=../hparams/era5/${model}/model_hparams.json
+
+# run training
+srun python ../scripts/train_dummy.py --input_dir ${source_dir}/tfrecords/ --dataset era5 --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}/
+
+
diff --git a/video_prediction_savp/env_setup/create_env.sh b/video_prediction_savp/env_setup/create_env.sh
old mode 100755
new mode 100644
index 7d0f0a10bd8586e59fe129198a5a6f7c21121502..ad388826caf1d077c1c6434acae29d6cbaa9c6fc
--- a/video_prediction_savp/env_setup/create_env.sh
+++ b/video_prediction_savp/env_setup/create_env.sh
@@ -1,39 +1,111 @@
 #!/usr/bin/env bash
+#
+# __authors__ = Bing Gong, Michael Langguth
+# __date__  = '2020_07_24'
 
-if [[ ! -n "$1" ]]; then
-  echo "Provide the user name, which will be taken as folder name"
+# This script can be used for setting up the virtual environment needed for the ambs-project
+# or to simply activate it.
+#
+# some first sanity checks
+if [[ ${BASH_SOURCE[0]} == ${0} ]]; then
+  echo "ERROR: 'create_env.sh' must be sourced, i.e. executed via 'source create_env.sh [virt_env_name]'"
   exit 1
 fi
 
-if [[ ! -n "$2" ]]; then
-  echo "Provide the env name, which will be taken as folder name"
-  exit 1
+# from now on, just return if something unexpected occurs instead of exiting
+# as the latter would close the terminal including logging out
+if [[ ! -n "$1" ]]; then
+  echo "ERROR: Provide a name to set up the virtual environment, i.e. execute via 'source create_env.sh [virt_env_name]'"
+  return
 fi
 
-ENV_NAME=$2
-FOLDER_NAME=$1
-WORKING_DIR=/p/project/deepacf/deeprain/${FOLDER_NAME}/video_prediction_savp
-ENV_SETUP_DIR=${WORKING_DIR}/env_setup
+HOST_NAME=`hostname`
+ENV_NAME=$1
+ENV_SETUP_DIR=`pwd`
+WORKING_DIR="$(dirname "$ENV_SETUP_DIR")"
+EXE_DIR="$(basename "$ENV_SETUP_DIR")"
 ENV_DIR=${WORKING_DIR}/${ENV_NAME}
 
-source ${ENV_SETUP_DIR}/modules.sh
-# Install additional Python packages.
-python3 -m venv $ENV_DIR
-source ${ENV_DIR}/bin/activate
-pip3 install -r ${ENV_SETUP_DIR}/requirements.txt
-#pip3 install --user netCDF4
-#pip3 install --user numpy
+# further sanity checks:
+# * ensure execution from env_setup-directory
+# * check if virtual env has already been set up
+
+if [[ "${EXE_DIR}" != "env_setup" ]]; then
+  echo "ERROR: The setup-script for the virtual environment must be run from the env_setup-directory!"
+  return
+fi
+
+if [[ -d ${ENV_DIR} ]]; then
+  echo "Virtual environment has already been set up under ${ENV_DIR}. The present virtual environment is activated now."
+  echo "NOTE: If you wish to set up a new virtual environment, delete the existing one or provide a different name."
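+  # only flag the existing environment here; it is activated at the end of this script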
+
+  ENV_EXIST=1
+else
+  ENV_EXIST=0
+fi
 
-#Copy the hickle package from bing's account
-cp -r /p/project/deepacf/deeprain/bing/hickle ${WORKING_DIR}
+# add personal email-address to Batch-scripts
+if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
+  USER_EMAIL=$(jutil user show -o json | grep email | cut -f2 -d':' | cut -f1 -d',' | cut -f2 -d'"')
+  #replace the email in the sbatch scripts with USER_EMAIL
+  sed -i "s/--mail-user=.*/--mail-user=$USER_EMAIL/g" ../HPC_scripts/*.sh
+  # load modules and check for their availability
+  echo "***** Checking modules required during the workflow... *****"
+  source ${ENV_SETUP_DIR}/modules_preprocess.sh
+  source ${ENV_SETUP_DIR}/modules_train.sh
 
-source ${ENV_SETUP_DIR}/modules.sh
-source ${ENV_DIR}/bin/activate
+elif [[ "${HOST_NAME}" == "zam347" ]]; then
+  unset PYTHONPATH
+fi
 
-export PYTHONPATH=${WORKING_DIR}/hickle/lib/python3.6/site-packages:$PYTHONPATH
-export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH
-export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:$PYTHONPATH
-#export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:$PYTHONPATH
-export PYTHONPATH=${WORKING_DIR}/lpips-tensorflow:$PYTHONPATH
+if [[ "$ENV_EXIST" == 0 ]]; then
+  # Activate virtual environment and install additional Python packages.
+  echo "Configuring and activating virtual environment on ${HOST_NAME}"
+
+  python3 -m venv $ENV_DIR
+
+  activate_virt_env=${ENV_DIR}/bin/activate
+  echo ${activate_virt_env}
+
+  source ${activate_virt_env}
+
+  # install some requirements and/or check for modules
+  if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
+    # check module availability for the first time on known HPC-systems
+    echo "***** Start installing additional Python modules with pip... *****"
+    pip3 install --no-cache-dir --ignore-installed -r ${ENV_SETUP_DIR}/requirements.txt
+    #pip3 install --user netCDF4
+    #pip3 install --user numpy
+  elif [[ "${HOST_NAME}" == "zam347" ]]; then
+    echo "***** Start installing additional Python modules with pip... *****"
+    pip3 install --upgrade pip
+    pip3 install -r ${ENV_SETUP_DIR}/requirements.txt
+    pip3 install mpi4py
+    pip3 install netCDF4
+    pip3 install numpy
+    pip3 install h5py
+    pip3 install tensorflow-gpu==1.13.1
+  fi
+  # expand PYTHONPATH...
+  export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH >> ${activate_virt_env}
+  #export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:$PYTHONPATH
+  export PYTHONPATH=${WORKING_DIR}/external_package/lpips-tensorflow:$PYTHONPATH >> ${activate_virt_env}
+
+  if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
+    export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:$PYTHONPATH >> ${activate_virt_env}
+  fi
+  # ...and ensure that this is also done whenever the virtual environment is activated again
+  echo "" >> ${activate_virt_env}
+  echo "# Expand PYTHONPATH..." >> ${activate_virt_env}
+  echo "export PYTHONPATH=${WORKING_DIR}:\$PYTHONPATH" >> ${activate_virt_env}
+  #export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:\$PYTHONPATH
+  echo "export PYTHONPATH=${WORKING_DIR}/external_package/lpips-tensorflow:\$PYTHONPATH" >> ${activate_virt_env}
+
+  if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
+    echo "export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:\$PYTHONPATH" >> ${activate_virt_env}
+  fi
+elif [[ "$ENV_EXIST" == 1 ]]; then
+  # activating the virtual env is sufficient
+  source ${ENV_DIR}/bin/activate
+fi
diff --git a/video_prediction_savp/env_setup/modules.sh b/video_prediction_savp/env_setup/modules.sh
deleted file mode 100755
index e6793787ad59988cfc6646dc8dd789d1573c6b23..0000000000000000000000000000000000000000
--- a/video_prediction_savp/env_setup/modules.sh
+++ /dev/null
@@ -1,12 +0,0 @@
-#!/usr/bin/env bash
-module purge
-module use $OTHERSTAGES
-module load Stages/2019a
-module load GCC/8.3.0
-module load MVAPICH2/.2.3.1-GDR
-module load GCCcore/.8.3.0
-module load mpi4py/3.0.1-Python-3.6.8
-module load h5py/2.9.0-serial-Python-3.6.8
-module load TensorFlow/1.13.1-GPU-Python-3.6.8
-module load cuDNN/7.5.1.10-CUDA-10.1.105
-
diff --git a/video_prediction_savp/env_setup/modules_preprocess.sh b/video_prediction_savp/env_setup/modules_preprocess.sh
new file mode 100755
index 0000000000000000000000000000000000000000..c4a242b0a739eb65c7a340dd8e3a3fbb57b04408
--- /dev/null
+++ b/video_prediction_savp/env_setup/modules_preprocess.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+# __author__ = Bing Gong, Michael Langguth
+# __date__  = '2020_06_26'
+
+# This script loads the required modules for ambs on Juwels and HDF-ML.
+# Note that some other packages have to be installed into a venv (see create_env.sh and requirements.txt).
+
+HOST_NAME=`hostname`
+
+echo "Start loading modules on ${HOST_NAME} required for preprocessing..."
+echo "This script is used by: "
+echo "* DataExtraction.sh"
+echo "* DataPreprocess.sh"
+
+module purge
+module use $OTHERSTAGES
+module load Stages/2019a
+module load GCC/8.3.0
+module load ParaStationMPI/5.2.2-1
+module load mpi4py/3.0.1-Python-3.6.8
+# serialized version is not available on HDF-ML
+# see https://gitlab.version.fz-juelich.de/haf/Wiki/-/wikis/HDF-ML%20System
+if [[ "${HOST_NAME}" == hdfml* ]]; then
+  module load h5py/2.9.0-serial-Python-3.6.8
+elif [[ "${HOST_NAME}" == juwels* ]]; then
+  module load h5py/2.9.0-Python-3.6.8
+fi
+module load netcdf4-python/1.5.0.1-Python-3.6.8
+
diff --git a/video_prediction_savp/env_setup/modules_train.sh b/video_prediction_savp/env_setup/modules_train.sh
new file mode 100755
index 0000000000000000000000000000000000000000..cdbc436e1976773132aa636a3f1cedfa506d3a14
--- /dev/null
+++ b/video_prediction_savp/env_setup/modules_train.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/env bash
+
+# __author__ = Bing Gong, Michael Langguth
+# __date__  = '2020_06_26'
+
+# This script loads the required modules for ambs on Juwels and HDF-ML.
+# Note that some other packages have to be installed into a venv (see create_env.sh and requirements.txt).
+
+HOST_NAME=`hostname`
+
+echo "Start loading modules on ${HOST_NAME}..."
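+# Note: this script is sourced by DataPreprocess_to_tf.sh, train_era5.sh and generate_era5.sh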
+
+module purge
+module use $OTHERSTAGES
+module load Stages/2019a
+module load GCC/8.3.0
+module load GCCcore/.8.3.0
+module load ParaStationMPI/5.2.2-1
+module load mpi4py/3.0.1-Python-3.6.8
+# the serialized version of HDF5 is used since only this version is compatible with TensorFlow/1.13.1-GPU-Python-3.6.8
+module load h5py/2.9.0-serial-Python-3.6.8
+module load TensorFlow/1.13.1-GPU-Python-3.6.8
+module load cuDNN/7.5.1.10-CUDA-10.1.105
+module load netcdf4-python/1.5.0.1-Python-3.6.8
+
diff --git a/video_prediction_savp/env_setup/requirements.txt b/video_prediction_savp/env_setup/requirements.txt
index 76dd1f57d64577cc565968bb7106656e53687261..4bf2f0b25d082c4c503bbd56f46d28360a48df43 100644
--- a/video_prediction_savp/env_setup/requirements.txt
+++ b/video_prediction_savp/env_setup/requirements.txt
@@ -2,3 +2,4 @@ opencv-python
 scipy
 scikit-image
 pandas
+hickle
diff --git a/video_prediction_savp/hparams/era5/model_hparams.json b/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
similarity index 56%
rename from video_prediction_savp/hparams/era5/model_hparams.json
rename to video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
index b121ee2f005b6db753b2536deb804204dd41b78d..c2edaad9f9ac158f6e7b8d94bb81db16d55d05e8 100644
--- a/video_prediction_savp/hparams/era5/model_hparams.json
+++ b/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
@@ -1,11 +1,12 @@
+
 {
-    "batch_size": 8,
+    "batch_size": 10,
     "lr": 0.001,
-    "nz": 16,
-    "max_steps":500,
+    "max_epochs":2,
     "context_frames":10,
     "sequence_length":20
 }
+
diff --git a/video_prediction_savp/metadata.py b/video_prediction_savp/metadata.py
index b65db76d71a549b3405f9f10b26dcccb09b30230..8f61d5766169c08d793b665253a8cd1e86a80548 100644
--- a/video_prediction_savp/metadata.py
+++ b/video_prediction_savp/metadata.py
@@ -4,6 +4,7 @@ Classes and routines to retrieve and handle meta-data
 
 import os
 import sys
+import time
 import numpy as np
 import json
 from netCDF4 import Dataset
@@ -23,7 +24,9 @@ class MetaData:
         method_name = MetaData.__init__.__name__+" of Class "+MetaData.__name__
 
         if not json_file is None:
-            MetaData.get_metadata_from_file(json_file)
+            print(json_file)
+            print(type(json_file))
+            MetaData.get_metadata_from_file(self,json_file)
 
         else:
             # No dictionary from json-file available, all other arguments have to set
@@ -90,9 +93,11 @@ class MetaData:
 
         self.nx, self.ny = np.abs(slices['lon_e'] - slices['lon_s']), np.abs(slices['lat_e'] - slices['lat_s'])
         sw_c = [float(datafile.variables['lat'][slices['lat_e']-1]),float(datafile.variables['lon'][slices['lon_s']])]   # meridional axis lat is oriented from north to south (i.e. monotonically decreasing)
         self.sw_c = sw_c
+        self.lat = datafile.variables['lat'][slices['lat_s']:slices['lat_e']]
+        self.lon = datafile.variables['lon'][slices['lon_s']:slices['lon_e']]
 
-        # Now start constructing exp_dir-string
-        # switch sign and coordinate-flags to avoid negative values appearing in exp_dir-name
+        # Now start constructing expdir-string
+        # switch sign and coordinate-flags to avoid negative values appearing in expdir-name
         if sw_c[0] < 0.:
             sw_c[0] = np.abs(sw_c[0])
             flag_coords[0] = "S"
@@ -112,7 +117,7 @@ class MetaData:
 
         expdir, expname = path_parts[0], path_parts[1]
 
-        # extend exp_dir_in successively (splitted up for better readability)
+        # extend expdir_in successively (split up for better readability)
         expname += "-"+str(self.nx) + "x" + str(self.ny)
         expname += "-"+(("{0: 05.2f}"+flag_coords[0]+"{1:05.2f}"+flag_coords[1]).format(*sw_c)).strip().replace(".","")+"-"
@@ -139,10 +144,15 @@ class MetaData:
             "expdir" : self.expdir}
 
         meta_dict["sw_corner_frame"] = {
-            "lat" : self.sw_c[0],
-            "lon" : self.sw_c[1]
+            "lat" : np.around(self.sw_c[0],decimals=2),
+            "lon" : np.around(self.sw_c[1],decimals=2)
             }
 
+        meta_dict["coordinates"] = {
+            "lat" : np.around(self.lat,decimals=2).tolist(),
+            "lon" : np.around(self.lon,decimals=2).tolist()
+            }
+
         meta_dict["frame_size"] = {
             "nx" : int(self.nx),
             "ny" : int(self.ny)
@@ -150,7 +160,7 @@ class MetaData:
 
         meta_dict["variables"] = []
         for i in range(len(self.varnames)):
-            print(self.varnames[i])
+            #print(self.varnames[i])
             meta_dict["variables"].append(
                 {"var"+str(i+1) : self.varnames[i]})
 
@@ -163,14 +173,19 @@ class MetaData:
 
         meta_fname = os.path.join(dest_dir,"metadata.json")
 
-        if os.path.exists(meta_fname):   # check if a metadata-file already exists and check its content
+        if os.path.exists(meta_fname):   # check if a metadata-file already exists and check its content
+            print(method_name+": json-file ('"+meta_fname+"') already exists. Its content will be checked...")
             self.status = "old"          # set status to old in order to prevent repeated modification of shell-/Batch-scripts
             with open(meta_fname,'r') as js_file:
                 dict_dupl = json.load(js_file)
                 if dict_dupl != meta_dict:
-                    print(method_name+": Already existing metadata (see '"+meta_fname+") do not fit data being processed right now. Ensure a common data base.")
-                    sys.exit(1)
+                    meta_fname_dbg = os.path.join(dest_dir,"metadata_debug.json")
+                    print(method_name+": Already existing metadata (see '"+meta_fname+"') do not fit data being processed right now (see '" \
+                          +meta_fname_dbg+"'). Ensure a common data base.")
+                    with open(meta_fname_dbg,'w') as js_file:
+                        json.dump(meta_dict,js_file)
+                    raise ValueError
                 else:  #do not need to do anything
                     pass
         else:
@@ -189,20 +204,24 @@ class MetaData:
         with open(js_file) as js_file:
             dict_in = json.load(js_file)
 
-            self.exp_dir = dict_in["exp_dir"]
+            self.expdir = dict_in["expdir"]
 
             self.sw_c = [dict_in["sw_corner_frame"]["lat"],dict_in["sw_corner_frame"]["lon"] ]
+            self.lat = dict_in["coordinates"]["lat"]
+            self.lon = dict_in["coordinates"]["lon"]
 
             self.nx = dict_in["frame_size"]["nx"]
             self.ny = dict_in["frame_size"]["ny"]
-
-            self.variables = [dict_in["variables"][ivar] for ivar in dict_in["variables"].keys()]
-
+            # dict_in["variables"] is a list like [{var1: varname1},{var2: varname2},...]
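+            # e.g. [{"var1": "T2"}, {"var2": "MSL"}, {"var3": "gph500"}] yields ["T2", "MSL", "gph500"] below
+            # (the variable names here are only an illustrative example)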
+            list_of_dict_aux = dict_in["variables"]
+            # iterate through the list with an integer ivar
+            # note: the naming of the variables starts with var1, thus add 1 to the iterator
+            self.variables = [list_of_dict_aux[ivar]["var"+str(ivar+1)] for ivar in range(len(list_of_dict_aux))]
 
     def write_dirs_to_batch_scripts(self,batch_script):
         """
-        Expands ('known') directory-variables in batch_script by exp_dir-attribute of class instance
+        Expands ('known') directory-variables in batch_script by expdir-attribute of class instance
         """
 
         paths_to_mod = ["source_dir=","destination_dir=","checkpoint_dir=","results_dir="]   # known directory-variables in batch-scripts
@@ -224,6 +243,7 @@ class MetaData:
     def write_destdir_jsontmp(dest_dir, tmp_dir = None):
         """
         Writes dest_dir to temporary json-file (temp.json) stored in the current working directory.
+        To be executed by Master node in parallel mode.
        """
 
         if not tmp_dir: tmp_dir = os.getcwd()
@@ -259,6 +279,34 @@ class MetaData:
         else:
             return(dict_tmp.get("destination_dir"))
 
+    @staticmethod
+    def wait_for_jsontmp(tmp_dir = None, waittime = 10, delay=0.5):
+        """
+        Waits at max. waittime (in sec) until temp.json-file becomes available
+        """
+
+        method_name = MetaData.wait_for_jsontmp.__name__+" of Class "+MetaData.__name__
+
+        if not tmp_dir: tmp_dir = os.getcwd()
+
+        file_tmp = os.path.join(tmp_dir,"temp.json")
+
+        counter_max = waittime/delay
+        counter = 0
+        status = "not_ok"
+
+        while (counter <= counter_max):
+            if os.path.isfile(file_tmp):
+                status = "ok"
+                break
+            else:
+                time.sleep(delay)
+
+            counter += 1
+
+        if status != "ok": raise IOError(method_name+": '"+file_tmp+ \
+                                         "' does not exist after waiting for "+str(waittime)+" sec.")
+
     @staticmethod
     def issubset(a,b):
diff --git a/video_prediction_savp/scripts/generate_anomaly.py b/video_prediction_savp/scripts/generate_anomaly.py
deleted file mode 100644
index 9c8e555c2e223c51bfb555c7dd0fa0eb089610d8..0000000000000000000000000000000000000000
--- a/video_prediction_savp/scripts/generate_anomaly.py
+++ /dev/null
@@ -1,518 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import errno
-import json
-import os
-import math
-import random
-import cv2
-import numpy as np
-import tensorflow as tf
-import seaborn as sns
-import pickle
-from random import seed
-import random
-import json
-import numpy as np
-#from six.moves import cPickle
-import matplotlib
-matplotlib.use('Agg')
-import matplotlib.pyplot as plt
-import matplotlib.gridspec as gridspec
-import matplotlib.animation as animation
-import seaborn as sns
-import pandas as pd
-from video_prediction import datasets, models
-from matplotlib.colors import LinearSegmentedColormap
-from matplotlib.ticker import MaxNLocator
-from video_prediction.utils.ffmpeg_gif import save_gif
-
-with open("./splits_size_64_64_1/geo_info.json","r") as json_file:
-    geo = json.load(json_file)
-    lat = [round(i,2) for i in geo["lat"]]
-    lon = [round(i,2) for i in geo["lon"]]
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
-                                                                     "train, val, test, etc, or a directory containing "
-                                                                     "the tfrecords")
-    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified")
-    parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified")
-    parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified")
-    parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is "
-                                                 "results_gif_dir/model_fname")
-    parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is "
-                                                 "results_png_dir/model_fname")
-    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
-
-    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')
-
-    parser.add_argument("--dataset", type=str, help="dataset class name")
-    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
-    parser.add_argument("--model", type=str, help="model class name")
-    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
-
-    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
-    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
-    parser.add_argument("--num_epochs", type=int, default=1)
-
-    parser.add_argument("--num_stochastic_samples", type=int, default=1)  #Bing original is 5, change to 1
-    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
-    parser.add_argument("--fps", type=int, default=4)
-
-    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
-    parser.add_argument("--seed", type=int, default=7)
-
-    args = parser.parse_args()
-
-    if args.seed is not None:
-        tf.set_random_seed(args.seed)
-        np.random.seed(args.seed)
-        random.seed(args.seed)
-
-    args.results_gif_dir = args.results_gif_dir or args.results_dir
-    args.results_png_dir = args.results_png_dir or args.results_dir
-    dataset_hparams_dict = {}
-    model_hparams_dict = {}
-    if args.checkpoint:
-        checkpoint_dir = os.path.normpath(args.checkpoint)
-        if not os.path.isdir(args.checkpoint):
-            checkpoint_dir, _ = os.path.split(checkpoint_dir)
-        if not os.path.exists(checkpoint_dir):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
-        with open(os.path.join(checkpoint_dir, "options.json")) as f:
-            print("loading options from checkpoint %s" % args.checkpoint)
-            options = json.loads(f.read())
-            args.dataset = args.dataset or options['dataset']
-            args.model = args.model or options['model']
-        try:
-            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
-                dataset_hparams_dict = json.loads(f.read())
-        except FileNotFoundError:
-            print("dataset_hparams.json was not loaded because it does not exist")
-        try:
-            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
-                model_hparams_dict = json.loads(f.read())
-        except FileNotFoundError:
-            print("model_hparams.json was not loaded because it does not exist")
-        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1])
-        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1])
-    else:
-        if not args.dataset:
-            raise ValueError('dataset is required when checkpoint is not specified')
-        if not args.model:
-            raise ValueError('model is required when checkpoint is not specified')
-        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
-        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)
-
-    print('----------------------------------- Options ------------------------------------')
-    for k, v in args._get_kwargs():
-        print(k, "=", v)
-    print('------------------------------------- End --------------------------------------')
-
-    VideoDataset = datasets.get_dataset_class(args.dataset)
-    dataset = VideoDataset(
-        args.input_dir,
-        mode=args.mode,
-        num_epochs=args.num_epochs,
-        seed=args.seed,
-        hparams_dict=dataset_hparams_dict,
-        hparams=args.dataset_hparams)
-    VideoPredictionModel = models.get_model_class(args.model)
-    hparams_dict = dict(model_hparams_dict)
-    hparams_dict.update({
-        'context_frames': dataset.hparams.context_frames,
-        'sequence_length': dataset.hparams.sequence_length,
-        'repeat': dataset.hparams.time_shift,
-    })
-    model = VideoPredictionModel(
-        mode=args.mode,
-        hparams_dict=hparams_dict,
-        hparams=args.model_hparams)
-
-    sequence_length = model.hparams.sequence_length
-    context_frames = model.hparams.context_frames
-    future_length = sequence_length - context_frames
-
-    if args.num_samples:
-        if args.num_samples > dataset.num_examples_per_epoch():
-            raise ValueError('num_samples cannot be larger than the dataset')
-        num_examples_per_epoch = args.num_samples
-    else:
-        #Bing: error occurs here, cheats a little bit here
-        #num_examples_per_epoch = dataset.num_examples_per_epoch()
-        num_examples_per_epoch = args.batch_size * 8
-    if num_examples_per_epoch % args.batch_size != 0:
-        #bing
-        #raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
-        pass
-    #Bing if it is era 5 data we used dataset.make_batch_v2
-    #inputs = dataset.make_batch(args.batch_size)
-    inputs, inputs_mean = dataset.make_batch_v2(args.batch_size)
-    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
-    with tf.variable_scope(''):
-        model.build_graph(input_phs)
-
-    for output_dir in (args.output_gif_dir, args.output_png_dir):
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-        with open(os.path.join(output_dir, "options.json"), "w") as f:
-            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
-        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
-            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
-        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
-            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
-
-    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
-    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
-    sess = tf.Session(config=config)
-    sess.graph.as_default()
-    model.restore(sess, args.checkpoint)
-    sample_ind = 0
-    gen_images_all = []
-    input_images_all = []
-
-    while True:
-        if args.num_samples and sample_ind >= args.num_samples:
-            break
-        try:
-            input_results = sess.run(inputs)
-            input_mean_results = sess.run(inputs_mean)
-            input_final = input_results["images"] + input_mean_results["images"]
-
-        except tf.errors.OutOfRangeError:
-            break
-        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
-        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
-        for stochastic_sample_ind in range(args.num_stochastic_samples):
-            print("Stochastic sample id", stochastic_sample_ind)
-            gen_anomaly = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
-            gen_images = gen_anomaly + input_mean_results["images"][:,1:,:,:]
-            #input_images = sess.run(inputs["images"])
-            #Bing: Add evaluation metrics
-            # fetches = {'images': model.inputs['images']}
-            # fetches.update(model.eval_outputs.items())
-            # fetches.update(model.eval_metrics.items())
-            # results = sess.run(fetches, feed_dict = feed_dict)
-            # input_images = results["images"]  #shape (batch_size,future_frames,height,width,channel)
-            # only keep the future frames
-            #gen_images = gen_images[:, -future_length:]  #(8,10,64,64,1) (batch_size, sequences, height, width, channel)
-            #input_images = input_results["images"][:,-future_length:,:,:]
-            #input_images = input_results["images"][:,1:,:,:,:]
-            input_images = input_final [:,1:,:,:,:]
-            #gen_mse_avg = results["eval_mse/avg"]  #shape (batch_size,future_frames)
-            print("Finish sample ind",stochastic_sample_ind)
-            input_gen_diff_ = input_images - gen_images
-            #diff_image_range = pd.cut(input_gen_diff_.flatten(), bins = 4, labels = [-10, -5, 0, 5], right = False)
-            #diff_image_range = np.reshape(np.array(diff_image_range),input_gen_diff_.shape)
-            gen_images_all.extend(gen_images)
-            input_images_all.extend(input_images)
-
-            colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
-            cmap_name = 'my_list'
-            if sample_ind < 100:
-                for i in range(len(gen_images)):
-                    name = 'Batch_id_' + str(sample_ind) + " + Sample_" + str(i)
-                    gen_images_ = gen_images[i, :]
-                    gen_mse_avg_ = [np.mean(input_gen_diff_[i, frame, :, :, :]**2) for frame in
-                                    range(19)]  # return the list with 10 (sequence) mse
-
-                    input_gen_diff = input_gen_diff_[i,:,:,:,:]
-                    input_images_ = input_images[i, :]
-                    #gen_mse_avg_ = gen_mse_avg[i, :]
-                    fig = plt.figure()
-                    gs = gridspec.GridSpec(4,6)
-                    gs.update(wspace = 0.7,hspace=0.8)
-                    ax1 = plt.subplot(gs[0:2,0:3])
-                    ax2 = plt.subplot(gs[0:2,3:],sharey=ax1)
-                    ax3 = plt.subplot(gs[2:4,0:3])
-                    ax4 = plt.subplot(gs[2:4,3:])
-                    xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
-                    ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))]
-                    plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels)
-                    ax1.title.set_text("(a) Ground Truth")
-                    ax2.title.set_text("(b) SAVP")
-                    ax3.title.set_text("(c) Diff.")
-                    ax4.title.set_text("(d) MSE")
-
-                    ax1.xaxis.set_tick_params(labelsize=7)
-                    ax1.yaxis.set_tick_params(labelsize = 7)
-                    ax2.xaxis.set_tick_params(labelsize=7)
-                    ax2.yaxis.set_tick_params(labelsize = 7)
-                    ax3.xaxis.set_tick_params(labelsize=7)
-                    ax3.yaxis.set_tick_params(labelsize = 7)
-
-                    init_images = np.zeros((input_images_.shape[1], input_images_.shape[2]))
-                    print("inti images shape", init_images.shape)
-                    xdata, ydata = [], []
-                    plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
-                    plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
-                    #x = np.linspace(0, 64, 64)
-                    #y = np.linspace(0, 64, 64)
-                    #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
-                    #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
-                    fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7)
-                    fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7)
-
-                    cm = LinearSegmentedColormap.from_list(
-                        cmap_name, "bwr", N = 5)
-
-                    plot3 = ax3.imshow(init_images, vmin=-10, vmax=10, cmap=cm)#cmap = 'PuBu_r',
-
-                    plot4, = ax4.plot([], [], color = "r")
-                    ax4.set_xlim(0, len(gen_mse_avg_)-1)
-                    ax4.set_ylim(0, 10)
-                    ax4.set_xlabel("Frames", fontsize=10)
-                    #ax4.set_ylabel("MSE", fontsize=10)
-                    ax4.xaxis.set_tick_params(labelsize=7)
-                    ax4.yaxis.set_tick_params(labelsize=7)
-
-
-                    plots = [plot1, plot2, plot3, plot4]
-
-                    #fig.colorbar(plots[1], ax = [ax1, ax2])
-
-                    fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7)
-                    #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7)
-                    #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7)
-
-                    def animation_sample(t):
-                        input_image = input_images_[t, :, :, 0]
-                        gen_image = gen_images_[t, :, :, 0]
-                        diff_image = input_gen_diff[t,:,:,0]
-
-                        data = gen_mse_avg_[:t + 1]
-                        # x = list(range(len(gen_mse_avg_)))[:t+1]
-                        xdata.append(t)
-                        print("xdata", xdata)
-                        ydata.append(gen_mse_avg_[t])
-
-                        print("ydata", ydata)
-                        # p = sns.lineplot(x=x,y=data,color="b")
-                        # p.tick_params(labelsize=17)
-                        # plt.setp(p.lines, linewidth=6)
-                        plots[0].set_data(input_image)
-                        plots[1].set_data(gen_image)
-                        #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
-                        #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
-                        plots[2].set_data(diff_image)
-                        plots[3].set_data(xdata, ydata)
-                        fig.suptitle("Frame " + str(t+1))
-
-                        return plots
-
-                    ani = animation.FuncAnimation(fig, animation_sample, frames = len(gen_mse_avg_), interval = 1000,
-                                                  repeat_delay = 2000)
-                    ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
-
-            else:
-                pass
-
-            # # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
-            # #     ims = []
-            # #     fig = plt.figure()
-            # #     plt.xlim(0,len(gen_mse_avg_))
-            # #     plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg))
-            # #     plt.xlabel("Frames")
-            # #     plt.ylabel("MSE_AVG")
-            # #     #X = list(range(len(gen_mse_avg_)))
-            # #     #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
-            # #     def animate_metric(j):
-            # #         data = gen_mse_avg_[:(j+1)]
-            # #         x = list(range(len(gen_mse_avg_)))[:(j+1)]
-            # #         p = sns.lineplot(x=x,y=data,color="b")
-            # #         p.tick_params(labelsize=17)
-            # #         plt.setp(p.lines, linewidth=6)
-            # #     ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000)
-            # #     ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif"))
-            # #
-            # #
-            # # for i, input_images_ in enumerate(input_images):
-            # #     #context_images_ = (input_results['images'][i])
-            # #     #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
-            # #     ims = []
-            # #     fig = plt.figure()
-            # #     for t, input_image in enumerate(input_images_):
-            # #         im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')
-            # #         ttl = plt.text(1.5, 2,"Frame_" + str(t))
-            # #         ims.append([im,ttl])
-            # #     ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000)
-            # #     ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif"))
-            # #     #plt.show()
-            # #
-            # # for i,gen_images_ in enumerate(gen_images):
-            # #     ims = []
-            # #     fig = plt.figure()
-            # #     for t, gen_image in enumerate(gen_images_):
-            # #         im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
-            # #         ttl = plt.text(1.5, 2, "Frame_" + str(t))
-            # #         ims.append([im, ttl])
-            # #     ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
-            # #     ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif"))
-            #
-            #
-            # # for i, gen_images_ in enumerate(gen_images):
-            # #     #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
-            # #     #gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
-            # #     #bing
-            # #     context_images_ = (input_results['images'][i])
-            # #     gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
-            # #     context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
-            # #     plt.figure(figsize = (10,2))
-            # #     gs = gridspec.GridSpec(2,10)
-            # #     gs.update(wspace=0.,hspace=0.)
-            # #     for t, gen_image in enumerate(gen_images_):
-            # #         gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1)))
-            # #         gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
-            # #         plt.subplot(gs[t])
-            # #         plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')  # the last index sets the channel. 0 = t2
-            # #         # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
-            # #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
-            # #                         right = False, labelbottom = False, labelleft = False)
-            # #         if t == 0: plt.ylabel('Actual', fontsize = 10)
-            # #
-            # #         plt.subplot(gs[t + 10])
-            # #         plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
-            # #         # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
-            # #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
-            # #                         right = False, labelbottom = False, labelleft = False)
-            # #         if t == 0: plt.ylabel('Predicted', fontsize = 10)
-            # #     plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png')
-            # #     plt.clf()
-            #
-            # # if args.gif_length:
-            # #     context_and_gen_images = context_and_gen_images[:args.gif_length]
-            # # save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
-            # #          context_and_gen_images, fps=args.fps)
-            # #
-            # # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
-            # # for t, gen_image in enumerate(gen_images_):
-            # #     gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
-            # #     if gen_image.shape[-1] == 1:
-            # #         gen_image = np.tile(gen_image, (1, 1, 3))
-            # #     else:
-            # #         gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
-            # #     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
-
-        sample_ind += args.batch_size
-
-
-    with open(os.path.join(args.output_png_dir, "input_images_all"),"wb") as input_files:
-        pickle.dump(input_images_all,input_files)
-
-    with open(os.path.join(args.output_png_dir, "gen_images_all"),"wb") as gen_files:
-        pickle.dump(gen_images_all,gen_files)
-
-    with open(os.path.join(args.output_png_dir, "input_images_all"),"rb") as input_files:
-        input_images_all = pickle.load(input_files)
-
-    with open(os.path.join(args.output_png_dir, "gen_images_all"),"rb") as gen_files:
-        gen_images_all=pickle.load(gen_files)
-    ims = []
-    fig = plt.figure()
-    for frame in range(19):
-        input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # Get the first prediction frame (batch,height, width, channel)
-        #pix_mean = np.mean(input_gen_diff, axis = 0)
-        #pix_std = np.std(input_gen_diff, axis=0)
-        im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu')
-        if frame == 0:
-            fig.colorbar(im)
-        ttl = plt.text(1.5, 2, "Frame_" + str(frame +1))
-        ims.append([im, ttl])
-    ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
-    ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4"))
-    plt.close("all")
-    ims = []
-    fig = plt.figure()
-    for frame in range(19):
-        pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # Get the first prediction frame (batch,height, width, channel)
-        #pix_mean = np.mean(input_gen_diff, axis = 0)
-        #pix_std = np.std(input_gen_diff, axis=0)
-        im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu')
-        if frame == 0:
-            fig.colorbar(im)
-        ttl = plt.text(1.5, 2, "Frame_" + str(frame+1))
-        ims.append([im, ttl])
-    ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
-    ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4"))
-
-    gen_images_all = np.array(gen_images_all)
-    input_images_all = np.array(input_images_all)
-    # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
-    # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)
-    # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 )
-
-    mse_model = np.mean((input_images_all[:, :10,:,:,0] - gen_images_all[:, :10,:,:,0])**2)  # look at all timesteps except the first
-    mse_model_last = np.mean((input_images_all[:,10,:,:,0] - gen_images_all[:, 10,:,:,0])**2)
-    mse_prev = np.mean((input_images_all[:, :9, :, :, 0] - input_images_all[:, 1:10, :, :, 0])**2 )
-
-    def psnr(img1, img2):
-        mse = np.mean((img1 - img2) ** 2)
-        if mse == 0: return 100
-        PIXEL_MAX = 1
-        return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
-
-    psnr_model = psnr(input_images_all[:, :10, :, :, 0], gen_images_all[:, :10, :, :, 0])
-    psnr_model_last = psnr(input_images_all[:, 10, :, :, 0], gen_images_all[:,10, :, :, 0])
-    psnr_prev = psnr(input_images_all[:, :9, :, :, 0], input_images_all[:, 1:10, :, :, 0])
-    f = open(os.path.join(args.output_png_dir,'prediction_scores_4prediction.txt'), 'w')
-    f.write("Model MSE: %f\n" % mse_model)
-    f.write("Model MSE from only last prediction in sequence: %f\n" % mse_model_last)
-    f.write("Previous Frame MSE: %f\n" % mse_prev)
-    f.write("Model PSNR: %f\n" % psnr_model)
-    f.write("Model PSNR from only last prediction in sequence: %f\n" % psnr_model_last)
-    f.write("Previous frame PSNR: %f\n" % psnr_prev)
-    f.write("Shape of X_test: " + str(input_images_all.shape))
-    f.write("")
-    f.write("Shape of X_hat: " + str(gen_images_all.shape))
-    f.close()
-
-    seed(1)
-    s = random.sample(range(len(gen_images_all)), 100)
-    print("******KDP******")
-    #kernel density plot for checking the model collapse
-    fig = plt.figure()
-    kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images")
-    kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True")
-    kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability')
-    plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400)
-    plt.clf()
-
-    #line plot for evaluating the prediction and groud-truth
-    for i in [0,3,6,9,12,15,18]:
-        fig = plt.figure()
-        plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3)
-        #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3)
-        plt.xlabel("Prediction")
-        plt.ylabel("Real values")
-        plt.title("Frame_{}".format(i+1))
-        plt.plot([250,300], [250,300],color="black")
-        plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i))))
-        plt.clf()
-
-
-    mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3))  #return (batch, sequence)
-    x = [str(i+1) for i in list(range(19))]
-    fig,axis = plt.subplots()
-    mean_f = np.mean(mse_model_by_frames, axis = 0)
-    median = np.median(mse_model_by_frames, axis=0)
-    q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0)
-    q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0)
-    d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0)
-    d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0)
-    plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range")
-    plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range")
-    plt.plot(x, median, color="grey", linewidth=0.6, label="Median")
-    plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean")
-    plt.title(f'MSE percentile')
-    plt.xlabel("Frames")
-    plt.legend(loc=2, fontsize=8)
-    plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png"))
-
-if __name__ == '__main__':
-    main()
diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
index 331559f6287a4f24c1c19ee9f7f4b03309a22abf..0e250b47df28d115c8cdfc77fc708eab5e094ce6 100644
--- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
+++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
@@ -12,12 +12,10 @@ import cv2
 import numpy as np
 import tensorflow as tf
 import pickle
-import hickle
 from random import seed
 import random
 import json
 import numpy as np
-#from six.moves import cPickle
 import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
@@ -37,11 +35,88 @@ import sys
 sys.path.append(path.abspath('../video_prediction/datasets/'))
 from era5_dataset_v2 import Norm_data
 from os.path import dirname
+from netCDF4 import Dataset,date2num
+from metadata import MetaData as MetaData
+
+def set_seed(seed):
+    if seed is not None:
+        tf.set_random_seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+
+def get_coordinates(metadata_fname):
+    """
+    Retrieves the latitudes and longitudes read from the metadata json file.
+    """
+    md = MetaData(json_file=metadata_fname)
+    md.get_metadata_from_file(metadata_fname)
+
+    try:
+        print("lat:",md.lat)
+        print("lon:",md.lon)
+        return md.lat, md.lon
+    except:
+        raise ValueError("Error when handling: '"+metadata_fname+"'")
+
+
+def load_checkpoints_and_create_output_dirs(checkpoint,dataset,model):
+    if checkpoint:
+        checkpoint_dir = os.path.normpath(checkpoint)
+        if not os.path.isdir(checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % checkpoint)
+            options = json.loads(f.read())
+            dataset = dataset or options['dataset']
+            model = model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+    else:
+        if not dataset:
+            raise ValueError('dataset is required when checkpoint is not specified')
+        if not model:
+            raise ValueError('model is required when checkpoint is not specified')
 
-with open("../geo_info.json","r") as json_file:
-    geo = json.load(json_file)
-    lat = [round(i,2) for i in geo["lat"]]
-    lon = [round(i,2) for i in geo["lon"]]
+    return options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict
+
+
+
+def setup_dataset(dataset,input_dir,mode,seed,num_epochs,dataset_hparams,dataset_hparams_dict):
+    VideoDataset = datasets.get_dataset_class(dataset)
+    dataset = VideoDataset(
+        input_dir,
+        mode = mode,
+        num_epochs = num_epochs,
+        seed = seed,
+        hparams_dict = dataset_hparams_dict,
+        hparams = dataset_hparams)
+    return dataset
+
+
+def setup_dirs(input_dir,results_png_dir):
+    input_dir = args.input_dir
+    temporal_dir = os.path.split(input_dir)[0] + "/hickle/splits/"
+    print ("temporal_dir:",temporal_dir)
+
+
+def update_hparams_dict(model_hparams_dict,dataset):
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    return hparams_dict
 
 
 def psnr(img1, img2):
@@ -51,6 +126,218 @@ def psnr(img1, img2):
     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
 
 
+def setup_num_samples_per_epoch(num_samples, dataset):
+    if num_samples:
+        if num_samples > dataset.num_examples_per_epoch():
+            raise ValueError('num_samples cannot be larger than the dataset')
+        num_examples_per_epoch = num_samples
+    else:
+        num_examples_per_epoch = dataset.num_examples_per_epoch()
+    #if num_examples_per_epoch % args.batch_size != 0:
+    #    raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+    return num_examples_per_epoch
+
+
+def initia_save_data():
+    sample_ind = 0
+    gen_images_all = []
+    #Bing:20200410
+    persistent_images_all = []
+    input_images_all = []
+    return sample_ind, gen_images_all,persistent_images_all, input_images_all
+
+
+def write_params_to_results_dir(args,output_dir,dataset,model):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, "options.json"), "w") as f:
+
+        f.write(json.dumps(vars(args), sort_keys = True, indent = 4))
+    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4))
+    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4))
+    return None
+
+
+def denorm_images(stat_fl, input_images_, channel, var):
+    norm_cls = Norm_data(var)
+    norm = 'minmax'
+    with open(stat_fl) as js_file:
+        norm_cls.check_and_set_norm(json.load(js_file), norm)
+    input_images_denorm = norm_cls.denorm_var(input_images_[:, :, :, channel], var, norm)
+    return input_images_denorm
+
+
+def denorm_images_all_channels(stat_fl, input_images_, varnames):
+    input_images_all_channels_denorm = []
+    input_images_ = np.array(input_images_)
+    print("input_images_ shape:", input_images_.shape)
+    for c in range(len(varnames)):
+        print("denormalising variable:", varnames[c])
+        input_images_all_channels_denorm.append(denorm_images(stat_fl, input_images_, channel=c, var=varnames[c]))
+    input_images_denorm = np.stack(input_images_all_channels_denorm, axis=-1)
+    #print("input_images_denorm shape", input_images_denorm.shape)
+    return input_images_denorm
+
+
+def get_one_seq_and_time(input_images, t_starts, i):
+    assert (len(np.array(input_images).shape) == 5)
+    input_images_ = input_images[i, :, :, :, :]
+    t_start = t_starts[i]
+    return input_images_, t_start
+
+
+def generate_seq_timestamps(t_start, len_seq=20):
+    if isinstance(t_start, int): t_start = str(t_start)
+    if isinstance(t_start, np.ndarray): t_start = str(t_start[0])
+    s_datetime = datetime.datetime.strptime(t_start, '%Y%m%d%H')
+    seq_ts = [s_datetime + datetime.timedelta(hours = i+1) for i in range(len_seq)]
+    return seq_ts
+
+
+def save_to_netcdf_per_sequence(output_dir, input_images_, gen_images_, lons, lats, ts, context_frames, future_length, model_name, fl_name="test.nc"):
+    assert (len(np.array(input_images_).shape) == len(np.array(gen_images_).shape))
+
+    y_len = len(lats)
+    x_len = len(lons)
+    ts_len = len(ts)
+    ts_input = ts[:context_frames]
+    ts_forecast = ts[context_frames:]
+    #print("context_frame:",context_frames)
+    #print("future_frame",future_length)
+    #print("length of ts input:",len(ts_input))
+
+    print("input_images_ shape in netcdf:", input_images_.shape)
+    gen_images_ = np.array(gen_images_)
+
+    output_file = os.path.join(output_dir, fl_name)
+    with Dataset(output_file, "w", format="NETCDF4") as nc_file:
+        nc_file.title = 'ERA5 hourly reanalysis data and the corresponding forecasts generated by deep learning (2-m temperature, mean sea-level pressure, 500 hPa geopotential)'
+        nc_file.author = "Bing Gong, Michael Langguth"
+        nc_file.create_date = "2020-08-04"
+
+        #create the groups forecast and analysis
+        fcst = nc_file.createGroup("forecast")
+        analgrp = nc_file.createGroup("analysis")
+
+        #create dims for all the data (forecast and analysis)
+        latD = nc_file.createDimension('lat', y_len)
+        lonD = nc_file.createDimension('lon', x_len)
+        timeD = nc_file.createDimension('time_input', context_frames)
+        timeF = nc_file.createDimension('time_forecast', future_length)
+
+        #Latitude
+        lat = nc_file.createVariable('lat', float, ('lat',), zlib = True)
+        lat.units = 'degrees_north'
+        lat[:] = lats
+
+        #Longitude
+        lon = nc_file.createVariable('lon', float, ('lon',), zlib = True)
+        lon.units = 'degrees_east'
+        lon[:] = lons
+
+        #Time for input
+        time = nc_file.createVariable('time_input', 'f8', ('time_input',), zlib = True)
+        time.units = "hours since 1970-01-01 00:00:00"
+        time.calendar = "gregorian"
+        time[:] = date2num(ts_input, units = time.units, calendar = time.calendar)
+
+        #Time for forecast
+        time_f = nc_file.createVariable('time_forecast', 'f8', ('time_forecast',), zlib = True)
+        time_f.units = "hours since 1970-01-01 00:00:00"
+        time_f.calendar = "gregorian"
+        time_f[:] = date2num(ts_forecast, units = time_f.units, calendar = time_f.calendar)
+
+        ################ analysis group #####################
+
+        #####sub group for inputs
+        # create variables for non-meta data
+        #Temperature
+        t2 = nc_file.createVariable("/analysis/inputs/T2", "f4", ("time_input","lat","lon"), zlib = True)
+        t2.units = 'K'
+        t2[:,:,:] = input_images_[:context_frames,:,:,0]
+
+        #mean sea level pressure
+        msl = nc_file.createVariable("/analysis/inputs/MSL", "f4", ("time_input","lat","lon"), zlib = True)
+        msl.units = 'Pa'
+        msl[:,:,:] = input_images_[:context_frames,:,:,1]
+
+        #Geopotential at 500 hPa
+        gph500 = nc_file.createVariable("/analysis/inputs/GPH500", "f4", ("time_input","lat","lon"), zlib = True)
+        gph500.units = 'm'
+        gph500[:,:,:] = input_images_[:context_frames,:,:,2]
+
+        #####sub group for reference (ground truth)
+        #Temperature
+        t2_r = nc_file.createVariable("/analysis/reference/T2", "f4", ("time_forecast","lat","lon"), zlib = True)
+        t2_r.units = 'K'
+        t2_r[:,:,:] = input_images_[context_frames:,:,:,0]
+
+        #mean sea level pressure
+        msl_r = nc_file.createVariable("/analysis/reference/MSL", "f4", ("time_forecast","lat","lon"), zlib = True)
+        msl_r.units = 'Pa'
+        msl_r[:,:,:] = input_images_[context_frames:,:,:,1]
+
+        #Geopotential at 500 hPa
+        gph500_r = nc_file.createVariable("/analysis/reference/GPH500", "f4", ("time_forecast","lat","lon"), zlib = True)
+        gph500_r.units = 'm'
+        gph500_r[:,:,:] = input_images_[context_frames:,:,:,2]
+
+
+        ################ forecast group #####################
+
+        #Temperature
+        t2 = nc_file.createVariable("/forecast/{}/T2".format(model_name), "f4", ("time_forecast","lat","lon"), zlib = True)
+        t2.units = 'K'
+        t2[:,:,:] = gen_images_[context_frames:,:,:,0]
+
+        #mean sea level pressure
+        msl = nc_file.createVariable("/forecast/{}/MSL".format(model_name), "f4", ("time_forecast","lat","lon"), zlib = True)
+        msl.units = 'Pa'
+        msl[:,:,:] = gen_images_[context_frames:,:,:,1]
+
+        #Geopotential at 500 hPa
+        gph500 = nc_file.createVariable("/forecast/{}/GPH500".format(model_name), "f4", ("time_forecast","lat","lon"), zlib = True)
+        gph500.units = 'm'
+        gph500[:,:,:] = gen_images_[context_frames:,:,:,2]
+
+    print("{} created".format(output_file))
+
+    return None
+
+def plot_seq_imgs(imgs, lats, lons, ts, output_png_dir, label="Ground Truth"):
+    """
+    Plot a sequence of images as one row of panels
+    """
+    if len(np.array(imgs).shape) != 3: raise ValueError("imgs must have three dims: (seq_len, lat, lon)")
+    if np.array(imgs).shape[0] != len(ts): raise ValueError("The number of timestamps must equal the image seq_len")
+    fig = plt.figure(figsize=(18,6))
+    gs = gridspec.GridSpec(1, 10)
+    gs.update(wspace = 0., hspace = 0.)
+    xlabels = [round(i,2) for i in list(np.linspace(np.min(lons), np.max(lons), 5))]
+    ylabels = [round(i,2) for i in list(np.linspace(np.max(lats), np.min(lats), 5))]
+    for i in range(len(ts)):
+        t = ts[i]
+        ax1 = plt.subplot(gs[i])
+        plt.imshow(imgs[i], cmap = 'jet', vmin=270, vmax=300)
+        ax1.title.set_text("t = " + t.strftime("%Y%m%d%H"))
+        plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+        if i == 0:
+            plt.setp([ax1], xticks = list(np.linspace(0, len(lons), 5)), xticklabels = xlabels, yticks = list(np.linspace(0, len(lats), 5)), yticklabels = ylabels)
+            plt.ylabel(label, fontsize=10)
+    # build the file name once so that the saved file and the log message agree
+    output_fname = label + "_TS_" + ts[0].strftime("%Y%m%d%H") + ".jpg"
+    plt.savefig(os.path.join(output_png_dir, output_fname))
+    plt.clf()
+    print("image {} saved".format(output_fname))
+
+
+def get_persistence(ts):
+    """
+    Persistence forecast (re-using the most recent observations); not implemented yet.
+    """
+    pass
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--input_dir", type = str, required = True,
@@ -59,106 +346,52 @@
                         "the tfrecords")
     parser.add_argument("--results_dir", type = str, default = 'results',
                         help = "ignored if output_gif_dir is specified")
-    parser.add_argument("--results_gif_dir", type = str,
-                        help = "default is results_dir. ignored if output_gif_dir is specified")
-    parser.add_argument("--results_png_dir", type = str,
-                        help = "default is results_dir. ignored if output_png_dir is specified")
-    parser.add_argument("--output_gif_dir", help = "output directory where samples are saved as gifs. default is "
-                        "results_gif_dir/model_fname")
-    parser.add_argument("--output_png_dir", help = "output directory where samples are saved as pngs. default is "
-                        "results_png_dir/model_fname")
     parser.add_argument("--checkpoint", help = "directory with checkpoint or checkpoint name (e.g. 
checkpoint_dir/model-200000)") - parser.add_argument("--mode", type = str, choices = ['train','val', 'test'], default = 'val', help = 'mode for dataset, val or test.') - parser.add_argument("--dataset", type = str, help = "dataset class name") parser.add_argument("--dataset_hparams", type = str, help = "a string of comma separated list of dataset hyperparameters") parser.add_argument("--model", type = str, help = "model class name") parser.add_argument("--model_hparams", type = str, help = "a string of comma separated list of model hyperparameters") - parser.add_argument("--batch_size", type = int, default = 8, help = "number of samples in batch") parser.add_argument("--num_samples", type = int, help = "number of samples in total (all of them by default)") parser.add_argument("--num_epochs", type = int, default = 1) - parser.add_argument("--num_stochastic_samples", type = int, default = 1) parser.add_argument("--gif_length", type = int, help = "default is sequence_length") parser.add_argument("--fps", type = int, default = 4) - parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use") parser.add_argument("--seed", type = int, default = 7) - args = parser.parse_args() + set_seed(args.seed) - if args.seed is not None: - tf.set_random_seed(args.seed) - np.random.seed(args.seed) - random.seed(args.seed) - #Bing:20200518 - input_dir = args.input_dir - temporal_dir = os.path.split(input_dir)[0] + "/hickle/splits/" - print ("temporal_dir:",temporal_dir) - args.results_gif_dir = args.results_gif_dir or args.results_dir - args.results_png_dir = args.results_png_dir or args.results_dir dataset_hparams_dict = {} model_hparams_dict = {} - if args.checkpoint: - checkpoint_dir = os.path.normpath(args.checkpoint) - if not os.path.isdir(args.checkpoint): - checkpoint_dir, _ = os.path.split(checkpoint_dir) - if not os.path.exists(checkpoint_dir): - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) - with open(os.path.join(checkpoint_dir, "options.json")) as f: - print("loading options from checkpoint %s" % args.checkpoint) - options = json.loads(f.read()) - args.dataset = args.dataset or options['dataset'] - args.model = args.model or options['model'] - try: - with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: - dataset_hparams_dict = json.loads(f.read()) - except FileNotFoundError: - print("dataset_hparams.json was not loaded because it does not exist") - try: - with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: - model_hparams_dict = json.loads(f.read()) - except FileNotFoundError: - print("model_hparams.json was not loaded because it does not exist") - args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, - os.path.split(checkpoint_dir)[1]) - args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, - os.path.split(checkpoint_dir)[1]) - else: - if not args.dataset: - raise ValueError('dataset is required when checkpoint is not specified') - if not args.model: - raise ValueError('model is required when checkpoint is not specified') - args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model) - args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model) + + options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict = load_checkpoints_and_create_output_dirs(args.checkpoint,args.dataset,args.model) + print("Step 1 finished") 
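Note: the checkpoint-loading step above amounts to reading up to three JSON files from the checkpoint directory, with a graceful fallback when the optional hparams files are missing. A minimal standalone sketch of that pattern (the checkpoint path below is a placeholder, not taken from this diff):

    import json
    import os

    def load_json_if_exists(checkpoint_dir, fname):
        # return the parsed JSON from checkpoint_dir/fname, or {} when the file is absent
        try:
            with open(os.path.join(checkpoint_dir, fname)) as f:
                return json.load(f)
        except FileNotFoundError:
            print("%s was not loaded because it does not exist" % fname)
            return {}

    # hypothetical checkpoint directory
    ckpt_dir = "results/model-200000"
    options = load_json_if_exists(ckpt_dir, "options.json")
    dataset_hparams_dict = load_json_if_exists(ckpt_dir, "dataset_hparams.json")
    model_hparams_dict = load_json_if_exists(ckpt_dir, "model_hparams.json")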
print('----------------------------------- Options ------------------------------------') for k, v in args._get_kwargs(): print(k, "=", v) print('------------------------------------- End --------------------------------------') - VideoDataset = datasets.get_dataset_class(args.dataset) - dataset = VideoDataset( - args.input_dir, - mode = args.mode, - num_epochs = args.num_epochs, - seed = args.seed, - hparams_dict = dataset_hparams_dict, - hparams = args.dataset_hparams) - - VideoPredictionModel = models.get_model_class(args.model) + #setup dataset and model object + input_dir_tf = os.path.join(args.input_dir, "tfrecords") # where tensorflow records are stored + dataset = setup_dataset(dataset,input_dir_tf,args.mode,args.seed,args.num_epochs,args.dataset_hparams,dataset_hparams_dict) + + print("Step 2 finished") + VideoPredictionModel = models.get_model_class(model) + hparams_dict = dict(model_hparams_dict) hparams_dict.update({ 'context_frames': dataset.hparams.context_frames, 'sequence_length': dataset.hparams.sequence_length, 'repeat': dataset.hparams.time_shift, }) + model = VideoPredictionModel( mode = args.mode, hparams_dict = hparams_dict, @@ -166,32 +399,23 @@ def main(): sequence_length = model.hparams.sequence_length context_frames = model.hparams.context_frames - future_length = sequence_length - context_frames - - if args.num_samples: - if args.num_samples > dataset.num_examples_per_epoch(): - raise ValueError('num_samples cannot be larger than the dataset') - num_examples_per_epoch = args.num_samples - else: - num_examples_per_epoch = dataset.num_examples_per_epoch() - if num_examples_per_epoch % args.batch_size != 0: - raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch) + future_length = sequence_length - context_frames #context_Frames is the number of input frames + num_examples_per_epoch = setup_num_samples_per_epoch(args.num_samples,dataset) + inputs = dataset.make_batch(args.batch_size) + print("inputs",inputs) input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} + print("input_phs",input_phs) + + + # Build graph with tf.variable_scope(''): model.build_graph(input_phs) - for output_dir in (args.output_gif_dir, args.output_png_dir): - if not os.path.exists(output_dir): - os.makedirs(output_dir) - with open(os.path.join(output_dir, "options.json"), "w") as f: - f.write(json.dumps(vars(args), sort_keys = True, indent = 4)) - with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f: - f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4)) - with open(os.path.join(output_dir, "model_hparams.json"), "w") as f: - f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4)) - + #Write the update hparameters into results_dir + write_params_to_results_dir(args=args,output_dir=args.results_dir,dataset=dataset,model=model) + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_mem_frac) config = tf.ConfigProto(gpu_options = gpu_options, allow_soft_placement = True) sess = tf.Session(config = config) @@ -199,169 +423,375 @@ def main(): sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) model.restore(sess, args.checkpoint) + + #model.restore(sess, args.checkpoint)#Bing: Todo: 20200728 Let's only focus on true and persistend data + sample_ind, gen_images_all, persistent_images_all, input_images_all = initia_save_data() - sample_ind = 0 - gen_images_all = [] - #Bing:20200410 - persistent_images_all = [] - 
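Note: the placeholder indirection set up above (input_phs) separates graph construction from the dataset iterator, so the numpy batches returned by sess.run(inputs) can be fed back through the same graph. The pattern, isolated as a sketch (TF1-style, matching the rest of this script; names mirror the surrounding code):

    import tensorflow as tf

    def make_input_placeholders(inputs):
        # one placeholder per tensor returned by dataset.make_batch()
        return {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}

    def make_feed_dict(input_phs, input_results):
        # input_results: dict of numpy arrays obtained from sess.run(inputs)
        return {ph: input_results[name] for name, ph in input_phs.items()}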
input_images_all = [] - #Bing:20201417 - print ("temporal_dir:",temporal_dir) - test_temporal_pkl = pickle.load(open(os.path.join(temporal_dir,"T_test.pkl"),"rb")) - #val_temporal_pkl = pickle.load(open("/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/hickle/splits/T_val.pkl","rb")) - print("test temporal_pkl file looks like folowing", test_temporal_pkl) - - #X_val = hickle.load("/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/hickle/splits/X_val.hkl") - X_test = hickle.load(os.path.join(temporal_dir,"X_test.hkl")) is_first=True - - #+++Scarlet:20200528 - norm_cls = Norm_data('T2') - norm = 'minmax' - with open(os.path.join(dirname(input_dir),"hickle/splits/statistics.json")) as js_file: - norm_cls.check_and_set_norm(json.load(js_file),norm) - #---Scarlet:20200528 - while True: - print("Sample id", sample_ind) - if sample_ind <= 24: - pass - elif sample_ind >= len(X_test): + #+++Scarlet:20200803 + lats, lons = get_coordinates(os.path.join(args.input_dir,"metadata.json")) + + #---Scarlet:20200803 + #while True: + #Change True to sample_id<=24 for debugging + + #loop for in samples + while sample_ind < 5: + gen_images_stochastic = [] + if args.num_samples and sample_ind >= args.num_samples: + break + try: + input_results = sess.run(inputs) + input_images = input_results["images"] + #get the intial times + t_starts = input_results["T_start"] + except tf.errors.OutOfRangeError: break - else: - gen_images_stochastic = [] - if args.num_samples and sample_ind >= args.num_samples: - break - try: - input_results = sess.run(inputs) - input_images = input_results["images"] - - - except tf.errors.OutOfRangeError: - break - - feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} - for stochastic_sample_ind in range(args.num_stochastic_samples): - input_images_all.extend(input_images) - with open(os.path.join(args.output_png_dir, "input_images_all.pkl"), "wb") as input_files: - pickle.dump(list(input_images_all), input_files) - - gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict) - gen_images_stochastic.append(gen_images) - #print("Stochastic_sample,", stochastic_sample_ind) - for i in range(args.batch_size): - #bing:20200417 - t_stampe = test_temporal_pkl[sample_ind+i] - print("timestamp:",type(t_stampe)) - persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1) - print ("persistent ts",persistent_ts) - persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts)) - persistent_X = X_test[persistent_idx:persistent_idx+context_frames + future_length] - print("persistent index in test set:", persistent_idx) - print("persistent_X.shape",persistent_X.shape) - persistent_images_all.append(persistent_X) + + #Get prediction values + feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} + gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel] + + #Loop in batch size + for i in range(args.batch_size): + + #get one seq and the corresponding start time point + input_images_,t_start = get_one_seq_and_time(input_images,t_starts,i) + #generate time stamps for sequences + ts = generate_seq_timestamps(t_start,len_seq=sequence_length) + + #Renormalized data for inputs + stat_fl = os.path.join(args.input_dir,"pickle/statistics.json") + input_images_denorm = denorm_images_all_channels(stat_fl,input_images_,["T2","MSL","gph500"]) + 
print("input_images_denorm",input_images_denorm[0][0]) + + #Renormalized data for inputs + gen_images_ = gen_images[i] + gen_images_denorm = denorm_images_all_channels(stat_fl,gen_images_,["T2","MSL","gph500"]) + print("gene_images_denorm:",gen_images_denorm[0][0]) + + #Save input to netCDF file + init_date_str = ts[0].strftime("%Y%m%d%H") + save_to_netcdf_per_sequence(args.results_dir,input_images_denorm,gen_images_denorm,lons,lats,ts,context_frames,future_length,args.model,fl_name="vfp_{}.nc".format(init_date_str)) + + #Generate images inputs + plot_seq_imgs(imgs=input_images_denorm[:context_frames-1,:,:,0],lats=lats,lons=lons,ts=ts[:context_frames-1],label="Ground Truth",output_png_dir=args.results_dir) + + #Generate forecast images + plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) + + #TODO: Scaret plot persistence image + #implment get_persistence() function - - #print("batch", i) - #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] - - - cmap_name = 'my_list' - if sample_ind < 100: - #name = '_Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str( - # sample_ind) + " + Sample_" + str(i) - name = '_Stochastic_id_' + str(stochastic_sample_ind) + "_Time_"+ t_stampe[0].strftime("%Y%m%d-%H%M%S") - print ("name",name) - gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :])) - #gen_images_ = gen_images[i, :] - input_images_ = input_images[i, :] - #Bing:20200417 - #persistent_images = ? - #+++Scarlet:20200528 - #print('Scarlet1') - input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm) - persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm) - #---Scarlet:20200528 - gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in - range(sequence_length)] # return the list with 10 (sequence) mse - persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in - range(sequence_length)] # return the list with 10 (sequence) mse - - fig = plt.figure(figsize=(18,6)) - gs = gridspec.GridSpec(1, 10) - gs.update(wspace = 0., hspace = 0.) - ts = list(range(10,20)) #[10,11,12,..] - xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] - ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] - - for t in ts: - - #if t==0 : ax1=plt.subplot(gs[t]) - ax1 = plt.subplot(gs[ts.index(t)]) - #+++Scarlet:20200528 - #print('Scarlet2') - input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm) - #---Scarlet:20200528 - plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300) - ax1.title.set_text("t = " + str(t+1-10)) - plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) - if t == 0: - plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels) - plt.ylabel("Ground Truth", fontsize=10) - plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg")) - plt.clf() - - fig = plt.figure(figsize=(12,6)) - gs = gridspec.GridSpec(1, 10) - gs.update(wspace = 0., hspace = 0.) 
- - for t in ts: - #if t==0 : ax1=plt.subplot(gs[t]) - ax1 = plt.subplot(gs[ts.index(t)]) - #+++Scarlet:20200528 - #print('Scarlet3') - gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm) - #---Scarlet:20200528 - plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300) - ax1.title.set_text("t = " + str(t+1-10)) - plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) - - plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg")) - plt.clf() - - - fig = plt.figure(figsize=(12,6)) - gs = gridspec.GridSpec(1, 10) - gs.update(wspace = 0., hspace = 0.) - for t in ts: - #if t==0 : ax1=plt.subplot(gs[t]) - ax1 = plt.subplot(gs[ts.index(t)]) - #persistent_image = persistent_X[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 - plt.imshow(persistent_X[t, :, :, 0], cmap = 'jet', vmin=270, vmax=300) - ax1.title.set_text("t = " + str(t+1-10)) - plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) - - plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg")) - plt.clf() + #in case of generate the images for all the input, we just generate the first 5 sampe_ind examples for visuliation + + sample_ind += args.batch_size + + + #for input_image in input_images_: + +# for stochastic_sample_ind in range(args.num_stochastic_samples): +# input_images_all.extend(input_images) +# with open(os.path.join(args.output_png_dir, "input_images_all.pkl"), "wb") as input_files: +# pickle.dump(list(input_images_all), input_files) + + +# gen_images_stochastic.append(gen_images) +# #print("Stochastic_sample,", stochastic_sample_ind) +# for i in range(args.batch_size): +# #bing:20200417 +# t_stampe = test_temporal_pkl[sample_ind+i] +# print("timestamp:",type(t_stampe)) +# persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1) +# print ("persistent ts",persistent_ts) +# persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts)) +# persistent_X = X_test[persistent_idx:persistent_idx+context_frames + future_length] +# print("persistent index in test set:", persistent_idx) +# print("persistent_X.shape",persistent_X.shape) +# persistent_images_all.append(persistent_X) + +# cmap_name = 'my_list' +# if sample_ind < 100: +# #name = '_Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str( +# # sample_ind) + " + Sample_" + str(i) +# name = '_Stochastic_id_' + str(stochastic_sample_ind) + "_Time_"+ t_stampe[0].strftime("%Y%m%d-%H%M%S") +# print ("name",name) +# gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :])) +# #gen_images_ = gen_images[i, :] +# input_images_ = input_images[i, :] +# #Bing:20200417 +# #persistent_images = ? +# #+++Scarlet:20200528 +# #print('Scarlet1') +# input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm) +# persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm) +# #---Scarlet:20200528 +# gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in +# range(sequence_length)] # return the list with 10 (sequence) mse +# persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in +# range(sequence_length)] # return the list with 10 (sequence) mse + +# fig = plt.figure(figsize=(18,6)) +# gs = gridspec.GridSpec(1, 10) +# gs.update(wspace = 0., hspace = 0.) +# ts = list(range(10,20)) #[10,11,12,..] 
+# xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] +# ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] + +# for t in ts: + +# #if t==0 : ax1=plt.subplot(gs[t]) +# ax1 = plt.subplot(gs[ts.index(t)]) +# #+++Scarlet:20200528 +# #print('Scarlet2') +# input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm) +# #---Scarlet:20200528 +# plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300) +# ax1.title.set_text("t = " + str(t+1-10)) +# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) +# if t == 0: +# plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels) +# plt.ylabel("Ground Truth", fontsize=10) +# plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg")) +# plt.clf() + +# fig = plt.figure(figsize=(12,6)) +# gs = gridspec.GridSpec(1, 10) +# gs.update(wspace = 0., hspace = 0.) + +# for t in ts: +# #if t==0 : ax1=plt.subplot(gs[t]) +# ax1 = plt.subplot(gs[ts.index(t)]) +# #+++Scarlet:20200528 +# #print('Scarlet3') +# gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm) +# #---Scarlet:20200528 +# plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300) +# ax1.title.set_text("t = " + str(t+1-10)) +# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + +# plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg")) +# plt.clf() + + +# fig = plt.figure(figsize=(12,6)) +# gs = gridspec.GridSpec(1, 10) +# gs.update(wspace = 0., hspace = 0.) +# for t in ts: +# #if t==0 : ax1=plt.subplot(gs[t]) +# ax1 = plt.subplot(gs[ts.index(t)]) +# #persistent_image = persistent_X[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 +# plt.imshow(persistent_X[t, :, :, 0], cmap = 'jet', vmin=270, vmax=300) +# ax1.title.set_text("t = " + str(t+1-10)) +# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + +# plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg")) +# plt.clf() - with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files: - pickle.dump(list(persistent_images_all), input_files) - print ("Save persistent all") - if is_first: - gen_images_all = gen_images_stochastic - is_first = False - else: - gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) - - if args.num_stochastic_samples == 1: - with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files: - pickle.dump(list(gen_images_all[0]), gen_files) - print ("Save generate all") - else: - with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: - pickle.dump(list(gen_images_stochastic), gen_files) - with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: - pickle.dump(list(gen_images_all), gen_files) +# with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files: +# pickle.dump(list(persistent_images_all), input_files) +# print ("Save persistent all") +# if is_first: +# gen_images_all = gen_images_stochastic +# is_first = False +# else: +# gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) + +# if args.num_stochastic_samples == 1: +# with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files: 
+# pickle.dump(list(gen_images_all[0]), gen_files) +# print ("Save generate all") +# else: +# with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: +# pickle.dump(list(gen_images_stochastic), gen_files) +# with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: +# pickle.dump(list(gen_images_all), gen_files) + +# sample_ind += args.batch_size + + +# with open(os.path.join(args.output_png_dir, "input_images_all.pkl"),"rb") as input_files: +# input_images_all = pickle.load(input_files) + +# with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"),"rb") as gen_files: +# gen_images_all = pickle.load(gen_files) + +# with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files: +# persistent_images_all = pickle.load(gen_files) + +# #+++Scarlet:20200528 +# #print('Scarlet4') +# input_images_all = np.array(input_images_all) +# input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm) +# #---Scarlet:20200528 +# persistent_images_all = np.array(persistent_images_all) +# if len(np.array(gen_images_all).shape) == 6: +# for i in range(len(gen_images_all)): +# #+++Scarlet:20200528 +# #print('Scarlet5') +# gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:] +# gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm) +# #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 +# #---Scarlet:20200528 +# mse_all = [] +# psnr_all = [] +# ssim_all = [] +# f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w') +# for i in range(future_length): +# mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all_stochastic[:, i + 9, :, :, +# 0]) ** 2) # look at all timesteps except the first +# psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0]) +# ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0], +# data_range = max(gen_images_all_stochastic[:, i + 9, :, :, 0].flatten()) - min( +# input_images_all[:, i + 10, :, :, 0].flatten())) +# mse_all.extend([mse_model]) +# psnr_all.extend([psnr_model]) +# ssim_all.extend([ssim_model]) +# results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all} +# f.write("##########Predicted Frame {}\n".format(str(i + 1))) +# f.write("Model MSE: %f\n" % mse_model) +# # f.write("Previous Frame MSE: %f\n" % mse_prev) +# f.write("Model PSNR: %f\n" % psnr_model) +# f.write("Model SSIM: %f\n" % ssim_model) + + +# pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb")) +# # f.write("Previous frame PSNR: %f\n" % psnr_prev) +# f.write("Shape of X_test: " + str(input_images_all.shape)) +# f.write("") +# f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape)) + +# else: +# #+++Scarlet:20200528 +# #print('Scarlet6') +# gen_images_all = np.array(gen_images_all) +# gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm) +# #---Scarlet:20200528 + +# # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2) # look at all timesteps except the first +# # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2) +# # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 ) +# mse_all = [] +# psnr_all 
= [] +# ssim_all = [] +# persistent_mse_all = [] +# persistent_psnr_all = [] +# persistent_ssim_all = [] +# f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w') +# for i in range(future_length): +# mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :, +# 0]) ** 2) # look at all timesteps except the first +# persistent_mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - persistent_images_all[:, i + 9, :, :, +# 0]) ** 2) # look at all timesteps except the first + +# psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0]) +# ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0], +# data_range = max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min( +# input_images_all[:, i + 10, :, :, 0].flatten())) +# persistent_psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0]) +# persistent_ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0], +# data_range = max(gen_images_all[:1268, i + 9, :, :, 0].flatten()) - min(input_images_all[:1268, i + 10, :, :, 0].flatten())) +# mse_all.extend([mse_model]) +# psnr_all.extend([psnr_model]) +# ssim_all.extend([ssim_model]) +# persistent_mse_all.extend([persistent_mse_model]) +# persistent_psnr_all.extend([persistent_psnr_model]) +# persistent_ssim_all.extend([persistent_ssim_model]) +# results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all} + +# persistent_results = {"mse": persistent_mse_all, "psnr": persistent_psnr_all, "ssim": persistent_ssim_all} +# f.write("##########Predicted Frame {}\n".format(str(i + 1))) +# f.write("Model MSE: %f\n" % mse_model) +# # f.write("Previous Frame MSE: %f\n" % mse_prev) +# f.write("Model PSNR: %f\n" % psnr_model) +# f.write("Model SSIM: %f\n" % ssim_model) + +# pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb")) +# pickle.dump(persistent_results, open(os.path.join(args.output_png_dir, "persistent_results.pkl"), "wb")) +# # f.write("Previous frame PSNR: %f\n" % psnr_prev) +# f.write("Shape of X_test: " + str(input_images_all.shape)) +# f.write("") +# f.write("Shape of X_hat: " + str(gen_images_all.shape) + +if __name__ == '__main__': + main() + + #psnr_model = psnr(input_images_all[:, :10, :, :, 0], gen_images_all[:, :10, :, :, 0]) + #psnr_model_last = psnr(input_images_all[:, 10, :, :, 0], gen_images_all[:,10, :, :, 0]) + #psnr_prev = psnr(input_images_all[:, :, :, :, 0], input_images_all[:, 1:10, :, :, 0]) + + # ims = [] + # fig = plt.figure() + # for frame in range(20): + # input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel) + # #pix_mean = np.mean(input_gen_diff, axis = 0) + # #pix_std = np.std(input_gen_diff, axis=0) + # im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu') + # if frame == 0: + # fig.colorbar(im) + # ttl = plt.text(1.5, 2, "Frame_" + str(frame +1)) + # ims.append([im, ttl]) + # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000) + # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4")) + # plt.close("all") + + # ims = [] + # fig = plt.figure() + # for frame in range(19): + # pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0] # Get the first prediction frame (batch,height, width, channel) 
+ # #pix_mean = np.mean(input_gen_diff, axis = 0) + # #pix_std = np.std(input_gen_diff, axis=0) + # im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu') + # if frame == 0: + # fig.colorbar(im) + # ttl = plt.text(1.5, 2, "Frame_" + str(frame+1)) + # ims.append([im, ttl]) + # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4")) + + # seed(1) + # s = random.sample(range(len(gen_images_all)), 100) + # print("******KDP******") + # #kernel density plot for checking the model collapse + # fig = plt.figure() + # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images") + # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True") + # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability') + # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400) + # plt.clf() + + #line plot for evaluating the prediction and groud-truth + # for i in [0,3,6,9,12,15,18]: + # fig = plt.figure() + # plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3) + # #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3) + # plt.xlabel("Prediction") + # plt.ylabel("Real values") + # plt.title("Frame_{}".format(i+1)) + # plt.plot([250,300], [250,300],color="black") + # plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i)))) + # plt.clf() + # + # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence) + # x = [str(i+1) for i in list(range(19))] + # fig,axis = plt.subplots() + # mean_f = np.mean(mse_model_by_frames, axis = 0) + # median = np.median(mse_model_by_frames, axis=0) + # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0) + # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0) + # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0) + # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0) + # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range") + # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range") + # plt.plot(x, median, color="grey", linewidth=0.6, label="Median") + # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean") + # plt.title(f'MSE percentile') + # plt.xlabel("Frames") + # plt.legend(loc=2, fontsize=8) + # plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png")) + + ## ## ## # fig = plt.figure() @@ -457,7 +887,9 @@ def main(): #### else: #### pass ## - sample_ind += args.batch_size + + + # # for i, gen_mse_avg_ in enumerate(gen_mse_avg): @@ -545,187 +977,3 @@ def main(): # # else: # # gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR) # # cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image) - - - - with open(os.path.join(args.output_png_dir, "input_images_all.pkl"),"rb") as input_files: - input_images_all = pickle.load(input_files) - - with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"),"rb") as gen_files: - gen_images_all = pickle.load(gen_files) - - with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files: - persistent_images_all = pickle.load(gen_files) - - #+++Scarlet:20200528 - #print('Scarlet4') - input_images_all = np.array(input_images_all) - input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm) - 
#---Scarlet:20200528 - persistent_images_all = np.array(persistent_images_all) - if len(np.array(gen_images_all).shape) == 6: - for i in range(len(gen_images_all)): - #+++Scarlet:20200528 - #print('Scarlet5') - gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:] - gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm) - #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 - #---Scarlet:20200528 - mse_all = [] - psnr_all = [] - ssim_all = [] - f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w') - for i in range(future_length): - mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all_stochastic[:, i + 9, :, :, - 0]) ** 2) # look at all timesteps except the first - psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0]) - ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0], - data_range = max(gen_images_all_stochastic[:, i + 9, :, :, 0].flatten()) - min( - input_images_all[:, i + 10, :, :, 0].flatten())) - mse_all.extend([mse_model]) - psnr_all.extend([psnr_model]) - ssim_all.extend([ssim_model]) - results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all} - f.write("##########Predicted Frame {}\n".format(str(i + 1))) - f.write("Model MSE: %f\n" % mse_model) - # f.write("Previous Frame MSE: %f\n" % mse_prev) - f.write("Model PSNR: %f\n" % psnr_model) - f.write("Model SSIM: %f\n" % ssim_model) - - - pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb")) - # f.write("Previous frame PSNR: %f\n" % psnr_prev) - f.write("Shape of X_test: " + str(input_images_all.shape)) - f.write("") - f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape)) - - else: - #+++Scarlet:20200528 - #print('Scarlet6') - gen_images_all = np.array(gen_images_all) - gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm) - #---Scarlet:20200528 - - # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2) # look at all timesteps except the first - # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2) - # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 ) - mse_all = [] - psnr_all = [] - ssim_all = [] - persistent_mse_all = [] - persistent_psnr_all = [] - persistent_ssim_all = [] - f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w') - for i in range(future_length): - mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :, - 0]) ** 2) # look at all timesteps except the first - persistent_mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - persistent_images_all[:, i + 9, :, :, - 0]) ** 2) # look at all timesteps except the first - - psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0]) - ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0], - data_range = max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min( - input_images_all[:, i + 10, :, :, 0].flatten())) - persistent_psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0]) - persistent_ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, 
:, :, 0], - data_range = max(gen_images_all[:1268, i + 9, :, :, 0].flatten()) - min(input_images_all[:1268, i + 10, :, :, 0].flatten())) - mse_all.extend([mse_model]) - psnr_all.extend([psnr_model]) - ssim_all.extend([ssim_model]) - persistent_mse_all.extend([persistent_mse_model]) - persistent_psnr_all.extend([persistent_psnr_model]) - persistent_ssim_all.extend([persistent_ssim_model]) - results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all} - - persistent_results = {"mse": persistent_mse_all, "psnr": persistent_psnr_all, "ssim": persistent_ssim_all} - f.write("##########Predicted Frame {}\n".format(str(i + 1))) - f.write("Model MSE: %f\n" % mse_model) - # f.write("Previous Frame MSE: %f\n" % mse_prev) - f.write("Model PSNR: %f\n" % psnr_model) - f.write("Model SSIM: %f\n" % ssim_model) - - pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb")) - pickle.dump(persistent_results, open(os.path.join(args.output_png_dir, "persistent_results.pkl"), "wb")) - # f.write("Previous frame PSNR: %f\n" % psnr_prev) - f.write("Shape of X_test: " + str(input_images_all.shape)) - f.write("") - f.write("Shape of X_hat: " + str(gen_images_all.shape)) - - - - #psnr_model = psnr(input_images_all[:, :10, :, :, 0], gen_images_all[:, :10, :, :, 0]) - #psnr_model_last = psnr(input_images_all[:, 10, :, :, 0], gen_images_all[:,10, :, :, 0]) - #psnr_prev = psnr(input_images_all[:, :, :, :, 0], input_images_all[:, 1:10, :, :, 0]) - - # ims = [] - # fig = plt.figure() - # for frame in range(20): - # input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel) - # #pix_mean = np.mean(input_gen_diff, axis = 0) - # #pix_std = np.std(input_gen_diff, axis=0) - # im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu') - # if frame == 0: - # fig.colorbar(im) - # ttl = plt.text(1.5, 2, "Frame_" + str(frame +1)) - # ims.append([im, ttl]) - # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000) - # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4")) - # plt.close("all") - - # ims = [] - # fig = plt.figure() - # for frame in range(19): - # pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0] # Get the first prediction frame (batch,height, width, channel) - # #pix_mean = np.mean(input_gen_diff, axis = 0) - # #pix_std = np.std(input_gen_diff, axis=0) - # im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu') - # if frame == 0: - # fig.colorbar(im) - # ttl = plt.text(1.5, 2, "Frame_" + str(frame+1)) - # ims.append([im, ttl]) - # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) - # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4")) - - # seed(1) - # s = random.sample(range(len(gen_images_all)), 100) - # print("******KDP******") - # #kernel density plot for checking the model collapse - # fig = plt.figure() - # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images") - # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True") - # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability') - # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400) - # plt.clf() - - #line plot for evaluating the prediction and groud-truth - # for i in [0,3,6,9,12,15,18]: - # fig = plt.figure() - # 
plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3) - # #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3) - # plt.xlabel("Prediction") - # plt.ylabel("Real values") - # plt.title("Frame_{}".format(i+1)) - # plt.plot([250,300], [250,300],color="black") - # plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i)))) - # plt.clf() - # - # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence) - # x = [str(i+1) for i in list(range(19))] - # fig,axis = plt.subplots() - # mean_f = np.mean(mse_model_by_frames, axis = 0) - # median = np.median(mse_model_by_frames, axis=0) - # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0) - # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0) - # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0) - # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0) - # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range") - # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range") - # plt.plot(x, median, color="grey", linewidth=0.6, label="Median") - # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean") - # plt.title(f'MSE percentile') - # plt.xlabel("Frames") - # plt.legend(loc=2, fontsize=8) - # plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png")) - -if __name__ == '__main__': - main() diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/train_dummy.py index 2f892f69c901f1eaa0a7ce2e57a3d0f6f131a7f9..805e81d79707338665f0fe2b1fad3b212359c233 100644 --- a/video_prediction_savp/scripts/train_dummy.py +++ b/video_prediction_savp/scripts/train_dummy.py @@ -11,6 +11,16 @@ import time import numpy as np import tensorflow as tf from video_prediction import datasets, models +import matplotlib.pyplot as plt +from json import JSONEncoder +import pickle as pkl + + +class NumpyArrayEncoder(JSONEncoder): + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + return JSONEncoder.default(self, obj) def add_tag_suffix(summary, tag_suffix): @@ -23,41 +33,11 @@ def add_tag_suffix(summary, tag_suffix): value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:]) return summary.SerializeToString() - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " - "train, val, test, etc, or a directory containing " - "the tfrecords") - parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir") - parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified") - parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. " - "default is logs_dir/model_fname, where model_fname consists of " - "information from model and model_hparams") - parser.add_argument("--output_dir_postfix", default="") - parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. 
checkpoint_dir/model-200000)") - parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.') - - parser.add_argument("--dataset", type=str, help="dataset class name") - parser.add_argument("--model", type=str, help="model class name") - parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") - parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters") - - # parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training") - parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") - parser.add_argument("--seed", type=int) - - args = parser.parse_args() - - if args.seed is not None: - tf.set_random_seed(args.seed) - np.random.seed(args.seed) - random.seed(args.seed) - - if args.output_dir is None: +def generate_output_dir(output_dir, model,model_hparams,logs_dir,output_dir_postfix): + if output_dir is None: list_depth = 0 model_fname = '' - for t in ('model=%s,%s' % (args.model, args.model_hparams)): + for t in ('model=%s,%s' % (model, model_hparams)): if t == '[': list_depth += 1 if t == ']': @@ -69,52 +49,78 @@ def main(): if t in '[]': t = '' model_fname += t - args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix - - if args.resume: - if args.checkpoint: + output_dir = os.path.join(logs_dir, model_fname) + output_dir_postfix + return output_dir + + +def get_model_hparams_dict(model_hparams_dict): + """ + Get model_hparams_dict from json file + """ + model_hparams_dict_load = {} + if model_hparams_dict: + with open(model_hparams_dict) as f: + model_hparams_dict_load.update(json.loads(f.read())) + return model_hparams_dict + +def resume_checkpoint(resume,checkpoint,output_dir): + """ + Resume the existing model checkpoints and set checkpoint directory + """ + if resume: + if checkpoint: raise ValueError('resume and checkpoint cannot both be specified') - args.checkpoint = args.output_dir + checkpoint = output_dir + return checkpoint +def set_seed(seed): + if seed is not None: + tf.set_random_seed(seed) + np.random.seed(seed) + random.seed(seed) - model_hparams_dict = {} - if args.model_hparams_dict: - with open(args.model_hparams_dict) as f: - model_hparams_dict.update(json.loads(f.read())) - if args.checkpoint: - checkpoint_dir = os.path.normpath(args.checkpoint) - if not os.path.isdir(args.checkpoint): +def load_params_from_checkpoints_dir(model_hparams_dict,checkpoint,dataset,model): + + model_hparams_dict_load = {} + if model_hparams_dict: + with open(model_hparams_dict) as f: + model_hparams_dict_load.update(json.loads(f.read())) + + if checkpoint: + checkpoint_dir = os.path.normpath(checkpoint) + if not os.path.isdir(checkpoint): checkpoint_dir, _ = os.path.split(checkpoint_dir) if not os.path.exists(checkpoint_dir): raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) with open(os.path.join(checkpoint_dir, "options.json")) as f: print("loading options from checkpoint %s" % args.checkpoint) options = json.loads(f.read()) - args.dataset = args.dataset or options['dataset'] - args.model = args.model or options['model'] + dataset = dataset or options['dataset'] + model = model or options['model'] try: with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: - model_hparams_dict.update(json.loads(f.read())) + 
-    print('----------------------------------- Options ------------------------------------')
-    for k, v in args._get_kwargs():
-        print(k, "=", v)
-    print('------------------------------------- End --------------------------------------')
-
-    VideoDataset = datasets.get_dataset_class(args.dataset)
+def setup_dataset(dataset, input_dir, val_input_dir):
+    VideoDataset = datasets.get_dataset_class(dataset)
     train_dataset = VideoDataset(
-        args.input_dir,
+        input_dir,
         mode='train')
     val_dataset = VideoDataset(
-        args.val_input_dir or args.input_dir,
+        val_input_dir or input_dir,
         mode='val')
-
     variable_scope = tf.get_variable_scope()
     variable_scope.set_use_resource(True)
+    return train_dataset, val_dataset, variable_scope
 
-    VideoPredictionModel = models.get_model_class(args.model)
+def setup_model(model, model_hparams_dict, train_dataset, model_hparams):
+    """
+    Set up the model instance
+    """
+    VideoPredictionModel = models.get_model_class(model)
     hparams_dict = dict(model_hparams_dict)
     hparams_dict.update({
         'context_frames': train_dataset.hparams.context_frames,
@@ -123,9 +129,21 @@ def main():
     })
     model = VideoPredictionModel(
         hparams_dict=hparams_dict,
-        hparams=args.model_hparams)
+        hparams=model_hparams)
+    return model
 
-    batch_size = model.hparams.batch_size
+def save_dataset_model_params_to_checkpoint_dir(args, output_dir, train_dataset, model):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
+    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+    return None
+
+def make_dataset_iterator(train_dataset, val_dataset, batch_size):
     train_tf_dataset = train_dataset.make_dataset_v2(batch_size)
     train_iterator = train_tf_dataset.make_one_shot_iterator()
     # The `Iterator.string_handle()` method returns a tensor that can be evaluated
@@ -138,18 +156,87 @@ def main():
     #     train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
     inputs = train_iterator.get_next()
     val = val_iterator.get_next()
-
-    model.build_graph(inputs)
+    return inputs, train_handle, val_handle
+
+
+def plot_train(train_losses, val_losses, output_dir):
+    iterations = list(range(len(train_losses)))
+    plt.plot(iterations, train_losses, 'g', label='Training loss')
+    plt.plot(iterations, val_losses, 'b', label='Validation loss')
+    plt.title('Training and Validation loss')
+    plt.xlabel('Iterations')
+    plt.ylabel('Loss')
+    plt.legend()
+    plt.savefig(os.path.join(output_dir, 'plot_train.png'))
+
+def save_results_to_dict(results_dict, output_dir):
+    with open(os.path.join(output_dir, "results.json"), "w") as fp:
+        json.dump(results_dict, fp)
+
+def save_results_to_pkl(train_losses, val_losses, output_dir):
+    with open(os.path.join(output_dir, "train_losses.pkl"), "wb") as f:
+        pkl.dump(train_losses, f)
+    with open(os.path.join(output_dir, "val_losses.pkl"), "wb") as f:
+        pkl.dump(val_losses, f)
 
-    if not os.path.exists(args.output_dir):
-        os.makedirs(args.output_dir)
-    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
-        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
-    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
-        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
-    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
-        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
+    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
+    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
+                                             "default is logs_dir/model_fname, where model_fname consists of "
+                                             "information from model and model_hparams")
+    parser.add_argument("--output_dir_postfix", default="")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+    parser.add_argument("--resume", action='store_true', help='resume from latest checkpoint in output_dir.')
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
+    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
+
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", default=1234, type=int)
+
+    args = parser.parse_args()
+
+    # Set seed
+    set_seed(args.seed)
+
+    # Set up the output directory
+    args.output_dir = generate_output_dir(args.output_dir, args.model, args.model_hparams, args.logs_dir, args.output_dir_postfix)
+
+    # Resume from an existing checkpoint and point the checkpoint directory to the output directory
+    args.checkpoint = resume_checkpoint(args.resume, args.checkpoint, args.output_dir)
+
+    # Get the model hparams dict from the json file and load the dataset/model configuration
+    # that was stored in the checkpoint directory during the previous training run
+    args.dataset, args.model, model_hparams_dict = load_params_from_checkpoints_dir(args.model_hparams_dict, args.checkpoint, args.dataset, args.model)
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+    # Set up the training and validation dataset instances
+    train_dataset, val_dataset, variable_scope = setup_dataset(args.dataset, args.input_dir, args.val_input_dir)
+
+    # Set up the model instance
+    model = setup_model(args.model, model_hparams_dict, train_dataset, args.model_hparams)
+
+    batch_size = model.hparams.batch_size
+    # Create the training and validation iterators
+    inputs, train_handle, val_handle = make_dataset_iterator(train_dataset, val_dataset, batch_size)
+
+    # Build the model graph
+    model.build_graph(inputs)
+
+    # Save all model and dataset parameters to the output directory
+    save_dataset_model_params_to_checkpoint_dir(args, args.output_dir, train_dataset, model)
+
     with tf.name_scope("parameter_count"):
         # exclude trainable variables that are replicas (used in multi-gpu setting)
         trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
@@ -162,113 +249,88 @@ def main():
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
     config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
-
-    max_steps = model.hparams.max_steps
-    print ("max_steps",max_steps)
+    max_epochs = model.hparams.max_epochs  # the number of epochs
+    num_examples_per_epoch = train_dataset.num_examples_per_epoch()
+    print("number of examples per epoch:", num_examples_per_epoch)
+    steps_per_epoch = int(num_examples_per_epoch/batch_size)
+    total_steps = steps_per_epoch * max_epochs
+    # mock total_steps only for fast debugging
+    #total_steps = 10
+    print("Total steps for training:", total_steps)
+    results_dict = {}
     with tf.Session(config=config) as sess:
         print("parameter_count =", sess.run(parameter_count))
         sess.run(tf.global_variables_initializer())
         sess.run(tf.local_variables_initializer())
-        #coord = tf.train.Coordinator()
-        #threads = tf.train.start_queue_runners(sess = sess, coord = coord)
         print("Init done: {sess.run(tf.local_variables_initializer())}%")
-        model.restore(sess, args.checkpoint)
-
-        #sess.run(model.post_init_ops)
-
-        #val_handle_eval = sess.run(val_handle)
-        #print ("val_handle_val",val_handle_eval)
-        #print("val handle done")
+        #model.restore(sess, args.checkpoint)
         sess.graph.finalize()
         start_step = sess.run(model.global_step)
-
-
+        print("start_step", start_step)
         # start at one step earlier to log everything without doing any training
         # step is relative to the start_step
-        for step in range(-1, max_steps - start_step):
+        train_losses = []
+        val_losses = []
+        run_start_time = time.time()
+        for step in range(total_steps):
             global_step = sess.run(model.global_step)
             print("global_step:", global_step)
             val_handle_eval = sess.run(val_handle)
-
-            if step == 1:
-                # skip step -1 and 0 for timing purposes (for warmstarting)
-                start_time = time.time()
+            # Fetch variables in the graph
             fetches = {"global_step": model.global_step}
             fetches["train_op"] = model.train_op
-
-            # fetches["latent_loss"] = model.latent_loss
+            #fetches["latent_loss"] = model.latent_loss
             fetches["total_loss"] = model.total_loss
+
+            # fetch the model-specific losses only for McNet
             if model.__class__.__name__ == "McNetVideoPredictionModel":
                 fetches["L_p"] = model.L_p
                 fetches["L_gdl"] = model.L_gdl
                 fetches["L_GAN"] = model.L_GAN
-
-
-
-            fetches["summary"] = model.summary_op
-
-            run_start_time = time.time()
-            #Run training results
-            #X = inputs["images"].eval(session=sess)
-
+
+            if model.__class__.__name__ == "SAVP":
+                #todo
+                pass
+
+            fetches["summary"] = model.summary_op
             results = sess.run(fetches)
-
-            run_elapsed_time = time.time() - run_start_time
-            if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
-                print('running train_op took too long (%0.1fs)' % run_elapsed_time)
-
-            #Run testing results
-            #val_fetches = {"global_step":global_step}
+            train_losses.append(results["total_loss"])
+            # Fetch losses for the validation data
             val_fetches = {}
             #val_fetches["latent_loss"] = model.latent_loss
-            #val_fetches["total_loss"] = model.total_loss
+            val_fetches["total_loss"] = model.total_loss
             val_fetches["summary"] = model.summary_op
             val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
-
+            val_losses.append(val_results["total_loss"])
+
             summary_writer.add_summary(results["summary"])
             summary_writer.add_summary(val_results["summary"])
-
-
-
-
-            val_datasets = [val_dataset]
-            val_models = [model]
-
-            # for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)):
-            #     sess.run(val_model.accum_eval_metrics_reset_op)
-            #     # traverse (roughly up to rounding based on the batch size) all the validation dataset
-            #     accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size
-            #     val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op}
-            #     for update_step in range(accum_eval_summary_num_updates):
-            #         print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates))
-            #         val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
-            #     accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1))
-            #     print("recording accum eval summary")
-            #     summary_writer.add_summary(accum_eval_summary, val_results["global_step"])
             summary_writer.flush()
-
+
             # global_step will have the correct step count if we resume from a checkpoint
-            # global step is read before it's incremented
-            steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
-            #train_epoch = results["global_step"] / steps_per_epoch
+            # global step is read before it's incremented
             train_epoch = global_step/steps_per_epoch
             print("progress  global step %d  epoch %0.1f" % (global_step + 1, train_epoch))
 
-            if step > 0:
-                elapsed_time = time.time() - start_time
-                average_time = elapsed_time / step
-                images_per_sec = batch_size / average_time
-                remaining_time = (max_steps - (start_step + step + 1)) * average_time
-                print("image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
-                      (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
-
-            print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"]))
+            if model.__class__.__name__ == "McNetVideoPredictionModel":
+                print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"]))
+            elif model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
+                print("Total_loss:{}".format(results["total_loss"]))
+            else:
+                print("The model name does not exist")
 
-            print("saving model to", args.output_dir)
-            saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)##Bing: cheat here a little bit because of the global step issue
-            print("done")
-
+            #print("saving model to", args.output_dir)
+            saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)
+        train_time = time.time() - run_start_time
+        results_dict = {"train_time": train_time,
+                        "total_steps": total_steps}
+        save_results_to_dict(results_dict, args.output_dir)
+        save_results_to_pkl(train_losses, val_losses, args.output_dir)
+        print("train_losses:", train_losses)
+        print("val_losses:", val_losses)
+        plot_train(train_losses, val_losses, args.output_dir)
+        print("Done")
+
 if __name__ == '__main__':
     main()
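Note on the epoch-based step accounting that replaces the old max_steps logic: the total step count now follows from the dataset size. A quick worked example with illustrative numbers only (the real values come from train_dataset and the model hparams):

num_examples_per_epoch = 1000                               # illustrative; queried from the dataset
batch_size = 16
max_epochs = 3000
steps_per_epoch = int(num_examples_per_epoch / batch_size)  # 62; int() truncates the remainder
total_steps = steps_per_epoch * max_epochs                  # 186000 training steps

Since int() truncates, up to batch_size - 1 samples per epoch are silently dropped from the step count.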
diff --git a/video_prediction_savp/video_prediction/datasets/__init__.py b/video_prediction_savp/video_prediction/datasets/__init__.py
index 736b8202172051f36586db5579c545863c72e14d..e449a65bd48b14ef5a11e6846ce4d8f39f7ed193 100644
--- a/video_prediction_savp/video_prediction/datasets/__init__.py
+++ b/video_prediction_savp/video_prediction/datasets/__init__.py
@@ -7,7 +7,7 @@ from .kth_dataset import KTHVideoDataset
 from .ucf101_dataset import UCF101VideoDataset
 from .cartgripper_dataset import CartgripperVideoDataset
 from .era5_dataset_v2 import ERA5Dataset_v2
-from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly
+#from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly
 
 def get_dataset_class(dataset):
     dataset_mappings = {
@@ -19,7 +19,7 @@ def get_dataset_class(dataset):
         'ucf101': 'UCF101VideoDataset',
         'cartgripper': 'CartgripperVideoDataset',
         "era5": "ERA5Dataset_v2",
-        "era5_anomaly": "ERA5Dataset_v2_anomaly",
+#        "era5_anomaly": "ERA5Dataset_v2_anomaly",
     }
     dataset_class = dataset_mappings.get(dataset, dataset)
     print("dataset_class", dataset_class)
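For clarity, get_dataset_class resolves a short dataset name through this mapping and falls back to the literal name; a two-line sketch of the lookup (excerpt of the table above):

dataset_mappings = {"era5": "ERA5Dataset_v2"}
dataset_class = dataset_mappings.get("era5", "era5")  # -> "ERA5Dataset_v2"; unknown keys pass through unchanged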
diff --git a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
index 9e32e0c638b4b39c588389e906ba29be5144ee35..a3c9fc3666eb21ef02f9e5f64c8c95a29034d619 100644
--- a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
+++ b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
@@ -5,7 +5,6 @@ import os
 import pickle
 import random
 import re
-import hickle as hkl
 import numpy as np
 import json
 import tensorflow as tf
@@ -17,9 +16,14 @@ sys.path.append(path.abspath('../../workflow_parallel_frame_prediction/'))
 import DataPreprocess.process_netCDF_v2
 from DataPreprocess.process_netCDF_v2 import get_unique_vars
 from DataPreprocess.process_netCDF_v2 import Calc_data_stat
+from metadata import MetaData
 #from base_dataset import VarLenFeatureVideoDataset
 from collections import OrderedDict
 from tensorflow.contrib.training import HParams
+from mpi4py import MPI
+import glob
+
+
 
 class ERA5Dataset_v2(VarLenFeatureVideoDataset):
     def __init__(self, *args, **kwargs):
@@ -28,6 +32,7 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
         example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
         dict_message = MessageToDict(tf.train.Example.FromString(example))
         feature = dict_message['features']['feature']
+        print("features in dataset:", feature.keys())
         self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
         self.image_shape = self.video_shape[1:]
         self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape
@@ -57,7 +62,6 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
             sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
         return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
 
-
     def filter(self, serialized_example):
         return tf.convert_to_tensor(True)
 
@@ -70,7 +74,8 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
             'height': tf.FixedLenFeature([], tf.int64),
             'sequence_length': tf.FixedLenFeature([], tf.int64),
             'channels': tf.FixedLenFeature([], tf.int64),
-            # 'images/encoded': tf.FixedLenFeature([], tf.string)
+            #'t_start': tf.FixedLenFeature([], tf.string),
+            't_start': tf.VarLenFeature(tf.int64),
             'images/encoded': tf.VarLenFeature(tf.float32)
         }
 
@@ -79,28 +84,22 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
             parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
             print("Parse features", parsed_features)
             seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
-            # width = tf.sparse_tensor_to_dense(parsed_features["width"])
+            T_start = tf.sparse_tensor_to_dense(parsed_features["t_start"])
+            print("T_start in make dataset_v2", T_start)
+            #width = tf.sparse_tensor_to_dense(parsed_features["width"])
             # height = tf.sparse_tensor_to_dense(parsed_features["height"])
             # channels = tf.sparse_tensor_to_dense(parsed_features["channels"])
             # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"])
             images = []
-            # for i in range(20):
-            #     images.append(parsed_features["images/encoded"].values[i])
-            # images = parsed_features["images/encoded"]
-            # images = tf.map_fn(lambda i: tf.image.decode_jpeg(parsed_features["images/encoded"].values[i]),offsets)
-            # seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"], '')
-            # Parse the string into an array of pixels corresponding to the image
-            # images = tf.decode_raw(parsed_features["images/encoded"],tf.int32)
-
-            # images = seq
             print("Image shape {}, {},{},{}".format(self.video_shape[0], self.image_shape[0], self.image_shape[1], self.image_shape[2]))
             images = tf.reshape(seq, [self.video_shape[0], self.image_shape[0], self.image_shape[1], self.image_shape[2]], name="reshape_new")
             seqs["images"] = images
+            seqs["T_start"] = T_start
             return seqs
         filenames = self.filenames
         print("FILENAMES", filenames)
-        #TODO:
-        #temporal_filenames = self.temporal_filenames
+        #TODO:
+        #temporal_filenames = self.temporal_filenames
         shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
         if shuffle:
             random.shuffle(filenames)
@@ -121,7 +120,6 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
         # dataset = dataset.apply(tf.contrib.data.map_and_batch(
         #    _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))  # Bing: parallel data mapping; num_parallel_calls normally depends on the hardware, but should usually equal the number of usable CPUs
         dataset = dataset.prefetch(batch_size)  # Bing: buffer the data in order to save waiting time for the GPU
-
         return dataset
 
@@ -139,24 +137,26 @@ def _bytes_list_feature(values):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
 
 def _floats_feature(value):
-    return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
 
 def _int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
-def save_tf_record(output_fname, sequences):
-    print('saving sequences to %s' % output_fname)
+def save_tf_record(output_fname, sequences, T_start_points):
     with tf.python_io.TFRecordWriter(output_fname) as writer:
-        for sequence in sequences:
+        for i in range(len(sequences)):
+            sequence = sequences[i]
+            T_start = T_start_points[i][0].strftime("%Y%m%d%H")
+            print("T_start:", T_start)
             num_frames = len(sequence)
             height, width, channels = sequence[0].shape
             encoded_sequence = np.array([list(image) for image in sequence])
-
             features = tf.train.Features(feature={
                 'sequence_length': _int64_feature(num_frames),
                 'height': _int64_feature(height),
                 'width': _int64_feature(width),
                 'channels': _int64_feature(channels),
+                't_start': _int64_feature(int(T_start)),
                 'images/encoded': _floats_feature(encoded_sequence.flatten()),
             })
             example = tf.train.Example(features=features)
@@ -207,91 +207,107 @@ class Norm_data:
                 for stat_name in self.known_norms[norm]:
                     #setattr(self,varname+stat_name,stat_dict[varname][0][stat_name])
                     setattr(self, varname+stat_name, Calc_data_stat.get_stat_vars(stat_dict, stat_name, varname))
-        
+
         self.status_ok = True  # set status for normalization -> ready
-        
+
     def norm_var(self, data, varname, norm):
         """
         Performs given normalization on input data (given that the instance is already set up)
         """
-        
+
         # some sanity checks
         if not self.status_ok:
             raise ValueError("Norm_data-instance needs to be initialized and checked first.")  # status ready?
-        
+
         if not norm in self.known_norms.keys():  # valid normalization requested?
print("Please select one of the following known normalizations: ") for norm_avail in self.known_norms.keys(): print(norm_avail) raise ValueError("Passed normalization '"+norm+"' is unknown.") - + # do the normalization and return if norm == "minmax": return((data[...] - getattr(self,varname+"min"))/(getattr(self,varname+"max") - getattr(self,varname+"min"))) elif norm == "znorm": return((data[...] - getattr(self,varname+"avg"))/getattr(self,varname+"sigma")**2) - + def denorm_var(self,data,varname,norm): """ Performs given denormalization on input data (given that the instance is already set up), i.e. inverse method to norm_var """ - + # some sanity checks if not self.status_ok: raise ValueError("Norm_data-instance needs to be initialized and checked first.") # status ready? - + if not norm in self.known_norms.keys(): # valid normalization requested? print("Please select one of the following known normalizations: ") for norm_avail in self.known_norms.keys(): print(norm_avail) raise ValueError("Passed normalization '"+norm+"' is unknown.") - + # do the denormalization and return if norm == "minmax": return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"min")) elif norm == "znorm": return(data[...] * getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg")) - -def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,seq_length=20,sequences_per_file=128,height=64,width=64,channels=3,**kwargs):#Bing: original 128 + +def read_frames_and_save_tf_records(stats,output_dir,input_file, temp_input_file, vars_in,year,month,seq_length=20,sequences_per_file=128,height=64,width=64,channels=3,**kwargs):#Bing: original 128 + """ + Read pickle files based on month, to process and save to tfrecords + stats:dict, contains the stats information from pickle directory, + input_file: string, absolute path to pickle file + file_info: 1D list with three elements, partition_name(train,val or test), year, and month e.g.[train,1,2] + """ # ML 2020/04/08: # Include vars_in for more flexible data handling (normalization and reshaping) # and optional keyword argument for kind of normalization - + print ("read_frames_and_save_tf_records function") if 'norm' in kwargs: norm = kwargs.get("norm") else: norm = "minmax" print("Make use of default minmax-normalization...") - output_dir = os.path.join(output_dir,partition_name) + os.makedirs(output_dir,exist_ok=True) - + norm_cls = Norm_data(vars_in) # init normalization-instance nvars = len(vars_in) # open statistics file and feed it to norm-instance - with open(os.path.join(input_dir,"statistics.json")) as js_file: - norm_cls.check_and_set_norm(json.load(js_file),norm) - + #with open(os.path.join(input_dir,"statistics.json")) as js_file: + norm_cls.check_and_set_norm(stats,norm) sequences = [] + T_start_points = [] sequence_iter = 0 - sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') - X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl")) + #sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') + #Bing 2020/07/16 + #print ("open intput dir,",input_file) + with open(input_file, "rb") as data_file: + X_train = pickle.load(data_file) + with open(temp_input_file,"rb") as temp_file: + T_train = pickle.load(temp_file) + #print("T_train:",T_train) + #check to make sure the X_train and T_train has the same length + assert (len(X_train) == len(T_train)) + X_possible_starts = [i for i in range(len(X_train) - seq_length)] for 
+def read_frames_and_save_tf_records(stats, output_dir, input_file, temp_input_file, vars_in, year, month, seq_length=20, sequences_per_file=128, height=64, width=64, channels=3, **kwargs):  # Bing: original 128
+    """
+    Read one monthly pickle file, slice it into sequences and save them to tfrecords.
+    stats: dict, contains the statistics information from the pickle directory
+    input_file: string, absolute path to the pickle file with the image data
+    temp_input_file: string, absolute path to the pickle file with the corresponding timestamps
+    year, month: year and month the pickle file belongs to
+    """
     # ML 2020/04/08:
     # Include vars_in for more flexible data handling (normalization and reshaping)
     # and optional keyword argument for kind of normalization
-    
+    print("read_frames_and_save_tf_records function")
     if 'norm' in kwargs:
         norm = kwargs.get("norm")
     else:
         norm = "minmax"
         print("Make use of default minmax-normalization...")
-    output_dir = os.path.join(output_dir, partition_name)
+
     os.makedirs(output_dir, exist_ok=True)
-    
+
     norm_cls = Norm_data(vars_in)  # init normalization-instance
     nvars = len(vars_in)
     # open statistics file and feed it to norm-instance
-    with open(os.path.join(input_dir,"statistics.json")) as js_file:
-        norm_cls.check_and_set_norm(json.load(js_file),norm)
-
+    #with open(os.path.join(input_dir,"statistics.json")) as js_file:
+    norm_cls.check_and_set_norm(stats, norm)
     sequences = []
+    T_start_points = []
     sequence_iter = 0
-    sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
-    X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl"))
+    #sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
+    #Bing 2020/07/16
+    #print("open input dir,", input_file)
+    with open(input_file, "rb") as data_file:
+        X_train = pickle.load(data_file)
+    with open(temp_input_file, "rb") as temp_file:
+        T_train = pickle.load(temp_file)
+    #print("T_train:", T_train)
+    # check to make sure that X_train and T_train have the same length
+    assert (len(X_train) == len(T_train))
+
+    X_possible_starts = [i for i in range(len(X_train) - seq_length)]
     for X_start in X_possible_starts:
-        print("Interation", sequence_iter)
         X_end = X_start + seq_length
         #seq = X_train[X_start:X_end, :, :,:]
-        seq = X_train[X_start:X_end,:,:]
-        #print("*****len of seq ***.{}".format(len(seq)))
-        #seq = list(np.array(seq).reshape((len(seq), 64, 64, 3)))
+        seq = X_train[X_start:X_end, :, :, :]
+        # Record the start point of the timestamps
+        T_start = T_train[X_start]
+        #print("T_start:", T_start)
         seq = list(np.array(seq).reshape((seq_length, height, width, nvars)))
         if not sequences:
             last_start_sequence_iter = sequence_iter
-            print("reading sequences starting at sequence %d" % sequence_iter)
+
         sequences.append(seq)
-        sequence_iter += 1
-        sequence_lengths_file.write("%d\n" % len(seq))
-
+        T_start_points.append(T_start)
+        sequence_iter += 1
+
         if len(sequences) == sequences_per_file:
             ### Normalization should adopt the selected variables; here we used duplicated channel temperature variables
             sequences = np.array(sequences)
@@ -299,12 +315,29 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,
             for i in range(nvars):
                 sequences[:, :, :, :, i] = norm_cls.norm_var(sequences[:, :, :, :, i], vars_in[i], norm)
 
-            output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1)
+            output_fname = 'sequence_Y_{}_M_{}_{}_to_{}.tfrecords'.format(year, month, last_start_sequence_iter, sequence_iter - 1)
             output_fname = os.path.join(output_dir, output_fname)
-            save_tf_record(output_fname, list(sequences))
+            print("T_start_points:", T_start_points)
+            save_tf_record(output_fname, list(sequences), T_start_points)
+            T_start_points = []
             sequences = []
-    sequence_lengths_file.close()
+    print("Finished for input file", input_file)
+    #sequence_lengths_file.close()
+    return
 
+def write_sequence_file(output_dir, seq_length, sequences_per_file):
+    partition_names = ["train", "val", "test"]
+    for partition_name in partition_names:
+        save_output_dir = os.path.join(output_dir, partition_name)
+        tfCounter = len(glob.glob1(save_output_dir, "*.tfrecords"))
+        print("Partition_name: {}, number of tfrecords: {}".format(partition_name, tfCounter))
+        sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w')
+        for i in range(tfCounter*sequences_per_file):
+            sequence_lengths_file.write("%d\n" % seq_length)
+        sequence_lengths_file.close()
+
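The X_possible_starts logic above slides a window of seq_length frames over the monthly array, so consecutive training sequences overlap heavily. A toy illustration of the windowing:

import numpy as np
X = np.arange(10)                                 # stand-in for 10 hourly frames
seq_length = 4
starts = [i for i in range(len(X) - seq_length)]  # [0, 1, ..., 5]
windows = [X[s:s+seq_length] for s in starts]     # 6 overlapping sequences: [0 1 2 3], [1 2 3 4], ...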
 def main():
     parser = argparse.ArgumentParser()
@@ -316,16 +349,116 @@ def main():
     parser.add_argument("-height", type=int, default=64)
     parser.add_argument("-width", type=int, default=64)
     parser.add_argument("-seq_length", type=int, default=20)
+    parser.add_argument("-sequences_per_file", type=int, default=2)
     args = parser.parse_args()
     current_path = os.getcwd()
     #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
     #output_dir = "/Users/gongbing/PycharmProjects/video_prediction/data/era5"
-    partition_names = ['train','val', 'test'] #64,64,3 val has issue#
+    #partition_names = ['train','val', 'test'] #64,64,3 val has issue#
+
+    ############################################################
+    # CONTROLLING variable! Needs to be adapted manually!!!
+    ############################################################
+    partition = {
+        "train":{
+            # "2222":[1,2,3,5,6,7,8,9,10,11,12],
+            # "2010_1":[1,2,3,4,5,6,7,8,9,10,11,12],
+            # "2012":[1,2,3,4,5,6,7,8,9,10,11,12],
+            # "2013_complete":[1,2,3,4,5,6,7,8,9,10,11,12],
+            # "2015":[1,2,3,4,5,6,7,8,9,10,11,12],
+            "2017_test":[1,2,3,4,5,6,7,8,9,10]
+            },
+        "val":
+            {"2017_test":[11]
+            },
+        "test":
+            {"2017_test":[12]
+            }
+        }
+
+    # ini. MPI
+    comm = MPI.COMM_WORLD
+    my_rank = comm.Get_rank()  # rank of the node
+    p = comm.Get_size()  # number of assigned nodes
 
-    for partition_name in partition_names:
-        read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,vars_in=args.variables,partition_name=partition_name, seq_length=args.seq_length,height=args.height,width=args.width,sequences_per_file=2) #Bing: Todo need check the N_seq
-        #ead_frames_and_save_tf_records(output_dir = output_dir, input_dir = input_dir,partition_name = partition_name, N_seq=20) #Bing: TODO: first try for N_seq is 10, but it met loading data issue. let's try 5
-
+    if my_rank == 0:
+        # retrieve final statistics first (not parallelized!)
+        # some preparatory steps
+        stat_dir_prefix = args.input_dir
+        varnames = args.variables
+
+        vars_uni, varsind, nvars = get_unique_vars(varnames)
+        stat_obj = Calc_data_stat(nvars)  # init statistic-instance
+
+        # loop over whole data set (training, dev and test set) to collect the intermediate statistics
+        print("Start collecting statistics from the whole dataset to be processed...")
+        for split in partition.keys():
+            values = partition[split]
+            for year in values.keys():
+                file_dir = os.path.join(stat_dir_prefix, year)
+                for month in values[year]:
+                    # process stat-file:
+                    stat_obj.acc_stat_master(file_dir, int(month))  # process monthly statistic-file
+
+        # finalize statistics and write to json-file
+        stat_obj.finalize_stat_master(vars_uni)
+        stat_obj.write_stat_json(args.input_dir)
+
+        # organize parallelized partitioning
+        partition_year_month = []  # contains lists of lists; each list includes three elements [train, year, month]
+        partition_names = list(partition.keys())
+        print("partition_names:", partition_names)
+        broadcast_lists = []
+        for partition_name in partition_names:
+            partition_data = partition[partition_name]
+            years = list(partition_data.keys())
+            broadcast_lists.append([partition_name, years])
+        for nodes in range(1, p):
+            #broadcast_list = [partition_name,years,nodes]
+            #broadcast_lists.append(broadcast_list)
+            comm.send(broadcast_lists, dest=nodes)
+
+        message_counter = 1
+        while message_counter <= 12:
+            message_in = comm.recv()
+            message_counter = message_counter + 1
+            print("Message in from slave", message_in)
+
+        write_sequence_file(args.output_dir, args.seq_length, args.sequences_per_file)
+
+        #write_sequence_file
+    else:
+        message_in = comm.recv()
+        print("My rank,", my_rank)
+        print("message_in", message_in)
+        # open statistics file and feed it to norm-instance
+        print("Opening json-file: " + os.path.join(args.input_dir, "statistics.json"))
+        with open(os.path.join(args.input_dir, "statistics.json")) as js_file:
+            stats = json.load(js_file)
+        # loop over the partitions (train, val, test)
+        for partition in message_in:
+            print("partition on slave", partition)
+            partition_name = partition[0]
+            save_output_dir = os.path.join(args.output_dir, partition_name)
+            for year in partition[1]:
+                input_file = "X_" + '{0:02}'.format(my_rank) + ".pkl"
+                temp_file = "T_" + '{0:02}'.format(my_rank) + ".pkl"
+                input_dir = os.path.join(args.input_dir, year)
+                temp_file = os.path.join(input_dir, temp_file)
+                input_file = os.path.join(input_dir, input_file)
+                # create the tfrecords-files
+                read_frames_and_save_tf_records(year=year, month=my_rank, stats=stats, output_dir=save_output_dir, \
+                                                input_file=input_file, temp_input_file=temp_file, vars_in=args.variables, \
+                                                partition_name=partition_name, seq_length=args.seq_length, \
+                                                height=args.height, width=args.width, sequences_per_file=args.sequences_per_file)
+
finished",year) + message_out = ("Node:",str(my_rank),"finished","","\r\n") + print ("Message out for slaves:",message_out) + comm.send(message_out,dest=0) + + MPI.Finalize() + if __name__ == '__main__': - main() + main() diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py index e7753004348ae0ae60057a469de1e2d1421c3869..557280d5d7169b9212844d4739ce3e0c2df7190b 100644 --- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py @@ -17,21 +17,17 @@ from video_prediction.layers import layer_def as ld from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): - def __init__(self, mode='train',aggregate_nccl=None, hparams_dict=None, + def __init__(self, mode='train', hparams_dict=None, hparams=None, **kwargs): super(VanillaConvLstmVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) print ("Hparams_dict",self.hparams) self.mode = mode self.learning_rate = self.hparams.lr - self.gen_images_enc = None - self.recon_loss = None - self.latent_loss = None self.total_loss = None - self.context_frames = 10 - self.sequence_length = 20 + self.context_frames = self.hparams.context_frames + self.sequence_length = self.hparams.sequence_length self.predict_frames = self.sequence_length - self.context_frames - self.aggregate_nccl=aggregate_nccl - + self.max_epochs = self.hparams.max_epochs def get_default_hparams_dict(self): """ The keys of this dict define valid hyperparameters for instances of @@ -44,13 +40,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): batch_size: batch size for training. lr: learning rate. if decay steps is non-zero, this is the learning rate for steps <= decay_step. - - - max_steps: number of training steps. - - - context_frames: the number of ground-truth frames to pass in at + context_frames: the number of ground-truth frames to pass :qin at start. Must be specified during instantiation. 
diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index e7753004348ae0ae60057a469de1e2d1421c3869..557280d5d7169b9212844d4739ce3e0c2df7190b 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -17,21 +17,17 @@ from video_prediction.layers import layer_def as ld
 from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
 
 class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
-    def __init__(self, mode='train',aggregate_nccl=None, hparams_dict=None,
+    def __init__(self, mode='train', hparams_dict=None,
                  hparams=None, **kwargs):
         super(VanillaConvLstmVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
         print("Hparams_dict", self.hparams)
         self.mode = mode
         self.learning_rate = self.hparams.lr
-        self.gen_images_enc = None
-        self.recon_loss = None
-        self.latent_loss = None
         self.total_loss = None
-        self.context_frames = 10
-        self.sequence_length = 20
+        self.context_frames = self.hparams.context_frames
+        self.sequence_length = self.hparams.sequence_length
         self.predict_frames = self.sequence_length - self.context_frames
-        self.aggregate_nccl=aggregate_nccl
-
+        self.max_epochs = self.hparams.max_epochs
 
     def get_default_hparams_dict(self):
         """
         The keys of this dict define valid hyperparameters for instances of
@@ -44,13 +40,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
             batch_size: batch size for training.
             lr: learning rate. if decay steps is non-zero, this is the
                 learning rate for steps <= decay_step.
-
-
-            max_steps: number of training steps.
-
-
             context_frames: the number of ground-truth frames to pass in at
                 start. Must be specified during instantiation.
             sequence_length: the number of frames in the video sequence,
                 including the context frames, so this model predicts
@@ -62,46 +53,40 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         hparams = dict(
             batch_size=16,
             lr=0.001,
-            end_lr=0.0,
-            nz=16,
-            decay_steps=(200000, 300000),
-            max_steps=350000,
+            max_epochs=3000,
         )
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
 
     def build_graph(self, x):
         self.x = x["images"]
-
+
         self.global_step = tf.Variable(0, name='global_step', trainable=False)
         original_global_variables = tf.global_variables()
         # ARCHITECTURE
-        self.x_hat_context_frames, self.x_hat_predict_frames = self.convLSTM_network()
-        self.x_hat = tf.concat([self.x_hat_context_frames, self.x_hat_predict_frames], 1)
-
-
-        self.context_frames_loss = tf.reduce_mean(
-            tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
-        self.predict_frames_loss = tf.reduce_mean(
-            tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_predict_frames[:, :, :, :, 0]))
-        self.total_loss = self.context_frames_loss + self.predict_frames_loss
+        self.convLSTM_network()
+        #print("self.x",self.x)
+        #print("self.x_hat_context_frames,",self.x_hat_context_frames)
+        #self.context_frames_loss = tf.reduce_mean(
+        #    tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
+        # This is the loss function (MSE):
+        self.total_loss = tf.reduce_mean(
+            tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_context_frames[:, (self.context_frames-1):-1, :, :, 0]))
 
         self.train_op = tf.train.AdamOptimizer(
             learning_rate=self.learning_rate).minimize(self.total_loss, global_step=self.global_step)
         self.outputs = {}
         self.outputs["gen_images"] = self.x_hat
         # Summary op
-        self.loss_summary = tf.summary.scalar("recon_loss", self.context_frames_loss)
-        self.loss_summary = tf.summary.scalar("latent_loss", self.predict_frames_loss)
         self.loss_summary = tf.summary.scalar("total_loss", self.total_loss)
         self.summary_op = tf.summary.merge_all()
         global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
         self.saveable_variables = [self.global_step] + global_variables
-        return
+        return None
 
     @staticmethod
-    def convLSTM_cell(inputs, hidden, nz=16):
+    def convLSTM_cell(inputs, hidden):
         conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1", activate="leaky_relu")
@@ -140,23 +125,28 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
             VanillaConvLstmVideoPredictionModel.convLSTM_cell)  # make the template to share the variables
         # create network
         x_hat_context = []
-        x_hat_predict = []
-        seq_start = 1
+        x_hat = []
         hidden = None
-        for i in range(self.context_frames):
-            if i < seq_start:
+        # This is for training
+        for i in range(self.sequence_length):
+            if i < self.context_frames:
                 x_1, hidden = network_template(self.x[:, i, :, :, :], hidden)
             else:
                 x_1, hidden = network_template(x_1, hidden)
             x_hat_context.append(x_1)
-
-        for i in range(self.predict_frames):
-            x_1, hidden = network_template(x_1, hidden)
-            x_hat_predict.append(x_1)
-
+
+        # This is for generating video
+        hidden_g = None
+        for i in range(self.sequence_length):
+            if i < self.context_frames:
+                x_1_g, hidden_g = network_template(self.x[:, i, :, :, :], hidden_g)
+            else:
+                x_1_g, hidden_g = network_template(x_1_g, hidden_g)
+            x_hat.append(x_1_g)
+
         # pack them all together
         x_hat_context = tf.stack(x_hat_context)
-        x_hat_predict = tf.stack(x_hat_predict)
-        self.x_hat_context = tf.transpose(x_hat_context, [1, 
0, 2, 3, 4]) # change first dim with sec dim - self.x_hat_predict = tf.transpose(x_hat_predict, [1, 0, 2, 3, 4]) # change first dim with sec dim - return self.x_hat_context, self.x_hat_predict + x_hat = tf.stack(x_hat) + self.x_hat_context_frames = tf.transpose(x_hat_context, [1, 0, 2, 3, 4]) # change first dim with sec dim + self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim + self.x_hat_predict_frames = self.x_hat[:,self.context_frames:,:,:,:] diff --git a/workflow_parallel_frame_prediction/DataExtraction/main_single_master.py b/workflow_parallel_frame_prediction/DataExtraction/main_single_master.py index 4894408adeb1a41b106faafb75def912ca5e4ad5..fda72c671d2804e87b121a2bd62038890a7f5161 100644 --- a/workflow_parallel_frame_prediction/DataExtraction/main_single_master.py +++ b/workflow_parallel_frame_prediction/DataExtraction/main_single_master.py @@ -93,19 +93,33 @@ def main(): if clear_destination == 1: shutil.rmtree(destination_dir) os.mkdir(destination_dir) - logger.critical("Destination : {destination} exist -> Remove and Re-Cereate".format(destination=destination_dir)) - print("Destination : {destination} exist -> Remove and Re-Cereate".format(destination=destination_dir)) + logger.critical("Destination : {destination} exist -> Remove and Re-Create".format(destination=destination_dir)) + print("Destination : {destination} exist -> Remove and Re-Create".format(destination=destination_dir)) else: logger.critical("Destination : {destination} exist -> will not be removed (caution : overwrite)".format(destination=destination_dir)) print("Destination : {destination} exist -> will not be rmeoved (caution : overwrite)".format(destination=destination_dir)) + + + + # 20200630 +++ Scarlet + else: + if my_rank == 0: + os.makedirs(destination_dir) #, exist_ok=True) + logger.info("Destination : {destination} does not exist -> Create".format(destination=destination_dir)) + print("Destination : {destination} does not exist -> Create".format(destination=destination_dir)) + + # 20200630 --- Scarlet + # Create a log folder for slave-nodes to write down their processes slave_log_path = os.path.join(destination_dir,log_temp) if my_rank == 0: if os.path.exists(slave_log_path) == False: - os.mkdir(slave_log_path) + # 20200630 Scarlet + #os.mkdir(slave_log_path) + os.makedirs(slave_log_path) if my_rank == 0: # node is master diff --git a/workflow_parallel_frame_prediction/DataExtraction/prepare_era5_data.py b/workflow_parallel_frame_prediction/DataExtraction/prepare_era5_data.py index f97bdf8236edb4b4eaac80513d9fb439486705a6..653716493a0514232b21aae0c34d74ddd1d82bb1 100644 --- a/workflow_parallel_frame_prediction/DataExtraction/prepare_era5_data.py +++ b/workflow_parallel_frame_prediction/DataExtraction/prepare_era5_data.py @@ -105,7 +105,9 @@ def prepare_era5_data_one_file(src_file,directory_to_process,target_dir, target= lon_new = test.createVariable('lon', float, ('lon',), zlib = True) lon_new.units = 'degrees_east' time_new = test.createVariable('time', 'f8', ('time',), zlib = True) - time_new.units = "hours since 2000-01-01 00:00:00" + #TODO: THIS SHOULD BE CHANGED TO "since 1970-01-01 00:00:00",BECAUSE ERA5 REANALYSIS DATA IS IN PRINCIPLE FROM 1979 + #WITH "2000-01-01 00:00:00" WE WOULD END UP HANDLING NEGATIVE TIME VALUES + time_new.units = "hours since 2000-01-01 00:00:00" time_new.calendar = "gregorian" p3d_new = test.createVariable('p3d', float, ('lev', 'lat', 'lon'), zlib = True) diff --git 
a/workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py b/workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py index f6b760c7ad3c528a745975cda9b1c420aa739d77..68d1ddfbfe413aab00950b71a336d0c1a43cbbf8 100644 --- a/workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py +++ b/workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py @@ -23,22 +23,36 @@ varnames = args.varnames #for key in all_keys: # print(partition[key]) -partition = { +cv ={} +partition1 = { "train":{ - "2017":[1] + #"2222":[1,2,3,5,6,7,8,9,10,11,12], + #"2010_1":[1,2,3,4,5,6,7,8,9,10,11,12], + #"2012":[1,2,3,4,5,6,7,8,9,10,11,12], + #"2013_complete":[1,2,3,4,5,6,7,8,9,10,11,12], + #"2015":[1,2,3,4,5,6,7,8,9,10,11,12], + #"2017":[1,2,3,4,5,6,7,8,9,10,11,12] + "2015":[1,2,3,4,5,6,7,8,9,10,11,12] }, "val": - {"2017":[2] + {"2016":[1,2,3,4,5,6,7,8,9,10,11,12] }, "test": - {"2017":[2] + {"2017":[1,2,3,4,5,6,7,8,9,10,11,12] } } + + + + + +#cv["1"] = partition1 +#cv2["2"] = partition2 # ini. MPI comm = MPI.COMM_WORLD my_rank = comm.Get_rank() # rank of the node p = comm.Get_size() # number of assigned nods if my_rank == 0: # node is master - split_data_multiple_years(target_dir=target_dir,partition=partition,varnames=varnames) + split_data_multiple_years(target_dir=target_dir,partition=partition1,varnames=varnames) else: pass diff --git a/workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py b/workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py index fdc2f65a5092469c021e0c21b0606e7e7d248c5c..71c661c49ba3502b12dbd409fb76a5c9b4517087 100755 --- a/workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py +++ b/workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py @@ -109,7 +109,6 @@ def main(): # Expand destination_dir-variable by searching for netCDF-files in source_dir and processing the file from the first list element to obtain all relevant (meta-)data. if my_rank == 0: data_files_list = glob.glob(source_dir+"/**/*.nc",recursive=True) - if not data_files_list: raise ValueError("Could not find any data to be processed in '"+source_dir+"'") md = MetaData(suffix_indir=destination_dir,data_filename=data_files_list[0],slices=slices,variables=vars) @@ -118,15 +117,26 @@ def main(): md.write_dirs_to_batch_scripts(scr_dir+"/DataPreprocess_to_tf.sh") md.write_dirs_to_batch_scripts(scr_dir+"/generate_era5.sh") md.write_dirs_to_batch_scripts(scr_dir+"/train_era5.sh") - # ML 2020/06/08: Dirty workaround as long as data-splitting is done with a seperate Python-script - # called from the same parent Shell-/Batch-script - # -> work with temproary json-file in working directory - md.write_destdir_jsontmp(os.path.join(md.expdir,md.expname),tmp_dir=current_path) - #else: nothing to do + + elif (md.status == "old"): # meta-data file already exists and is ok + # check for temp.json in working directory (required by slave nodes) + tmp_file = os.path.join(current_path,"temp.json") + if os.path.isfile(tmp_file): + os.remove(tmp_file) + mess_tmp_file = "Auxiliary file '"+tmp_file+"' already exists, but is cleaned up to be updated" + \ + " for safety reasons." 
+            logging.info(mess_tmp_file)
+
+        # ML 2020/06/08: Dirty workaround as long as data-splitting is done with a separate Python-script
+        #                called from the same parent Shell-/Batch-script
+        #                -> work with temporary json-file in working directory
+        # create or update temp.json, respectively
+        md.write_destdir_jsontmp(os.path.join(md.expdir, md.expname), tmp_dir=current_path)
 
-    destination_dir= os.path.join(md.expdir,md.expname,"hickle",years)
+    # expand destination directory by pickle-subfolder and...
+    destination_dir = os.path.join(md.expdir, md.expname, "pickle", years)
 
-    # ...and create directory if necessary
+    # ...create directory if necessary
     if not os.path.exists(destination_dir):  # check if the Destination dir. is existing
         logging.critical('The Destination does not exist')
         logging.info('Create new destination dir')
@@ -218,7 +228,7 @@ def main():
             #process_era5_in_dir(job, src_dir=source_dir, target_dir=destination_dir)
             # ML 2020/06/09: workaround to get correct destination_dir obtained by the master node
-            destination_dir = os.path.join(MetaData.get_destdir_jsontmp(tmp_dir=current_path),"hickle",years)
+            destination_dir = os.path.join(MetaData.get_destdir_jsontmp(tmp_dir=current_path),"pickle",years)
             process_netCDF_in_dir(job_name=job, src_dir=source_dir, target_dir=destination_dir,slices=slices,vars=vars)
 
             if checksum_status == 1:
diff --git a/workflow_parallel_frame_prediction/DataPreprocess/process_netCDF_v2.py b/workflow_parallel_frame_prediction/DataPreprocess/process_netCDF_v2.py
index 21bacb746a1c95a576dabd3e548d5f5a00dafdb0..1a66fa51011174ab2d8adb2a4eb6b3b0ae3b4a7e 100644
--- a/workflow_parallel_frame_prediction/DataPreprocess/process_netCDF_v2.py
+++ b/workflow_parallel_frame_prediction/DataPreprocess/process_netCDF_v2.py
@@ -11,13 +11,20 @@ from netCDF4 import Dataset,num2date
 import numpy as np
 #from imageio import imread
 #from scipy.misc import imresize
-import hickle as hkl
 import json
 import pickle
 
 # Create image datasets.
 # Processes images and saves them in train, val, test splits.
 def process_data(directory_to_process, target_dir, job_name, slices, vars=("T2","MSL","gph500")):
+    '''
+    :param directory_to_process: directory where the netCDF-files to be processed are stored
+    :param target_dir: directory where the pickle-files will be stored
+    :param job_name: job_id passed and organized by PyStager
+    :param slices: indices defining the geographical region of interest
+    :param vars: variables to be processed
+    :return: Saves pickle-files which contain the sliced meteorological data and temporal information as well
+    '''
     desired_im_sz = (slices["lat_e"] - slices["lat_s"], slices["lon_e"] - slices["lon_s"])
     # ToDo: Define a convenient function to create a list containing all files.
     imageList = list(os.walk(directory_to_process, topdown=False))[-1][-1]
@@ -25,31 +32,19 @@ def process_data(directory_to_process, target_dir, job_name, slices, vars=("T2",
     EU_stack_list = [0] * (len(imageList))
     temporal_list = [0] * (len(imageList))
     nvars = len(vars)
-    #X = np.zeros((len(splits[split]),) + desired_im_sz + (3,), np.uint8)
-    #print(X)
-    #print('shape of X' + str(X.shape))
-    ##### TODO: iterate over split and read every .nc file, cut out array,
-    #####       overlay arrays for RGB like style.
-    #####       Save everything after for loop.
# ML 2020/04/06 S # Some inits stat_obj = Calc_data_stat(nvars) # ML 2020/04/06 E for j, im_file in enumerate(imageList): - #20200408,Bing try: im_path = os.path.join(directory_to_process, im_file) print('Open following dataset: '+im_path) - - - #20200408,Bing - - im = Dataset(im_path, mode = 'r') - times = im.variables['time'] - time = num2date(times[:],units=times.units,calendar=times.calendar) vars_list = [] with Dataset(im_path,'r') as data_file: + times = data_file.variables['time'] + time = num2date(times[:],units=times.units,calendar=times.calendar) for i in range(nvars): var1 = data_file.variables[vars[i]][0,slices["lat_s"]:slices["lat_e"], slices["lon_s"]:slices["lon_e"]] stat_obj.acc_stat_loc(i,var1) @@ -60,9 +55,6 @@ def process_data(directory_to_process, target_dir, job_name, slices, vars=("T2", #20200408,bing temporal_list[j] = list(time) - #print('Does ist work? ') - #print(EU_stack_list[i][:,:,0]==EU_t2) - #print(EU_stack[:,:,1]==EU_msl except Exception as err: im_path = os.path.join(directory_to_process, im_file) #im = Dataset(im_path, mode = 'r') @@ -72,8 +64,12 @@ def process_data(directory_to_process, target_dir, job_name, slices, vars=("T2", continue X = np.array(EU_stack_list) - target_file = os.path.join(target_dir, 'X_' + str(job_name) + '.hkl') - hkl.dump(X, target_file) #Not optimal! + # ML 2020/07/15: Make use of pickle-files only + target_file = os.path.join(target_dir, 'X_' + str(job_name) + '.pkl') + with open(target_file, "wb") as data_file: + pickle.dump(X,data_file) + #target_file = os.path.join(target_dir, 'X_' + str(job_name) + '.pkl') + #hkl.dump(X, target_file) #Not optimal! print(target_file, "is saved") # ML 2020/03/31: write json file with statistics stat_obj.finalize_stat_loc(vars) @@ -82,28 +78,9 @@ def process_data(directory_to_process, target_dir, job_name, slices, vars=("T2", temporal_info = np.array(temporal_list) temporal_file = os.path.join(target_dir, 'T_' + str(job_name) + '.pkl') cwd = os.getcwd() - pickle.dump(temporal_info, open( temporal_file, "wb" ) ) - #hkl.dump(temporal_info, temporal_file) - - #hkl.dump(source_list, os.path.join(target_dir, 'sources_' + str(job) + '.hkl')) - - #for category, folder in splits[split]: - # im_dir = os.path.join(DATA_DIR, 'raw/', category, folder, folder[:10], folder, 'image_03/data/') - # files = list(os.walk(im_dir, topdown=False))[-1][-1] - # im_list += [im_dir + f for f in sorted(files)] - # multiply path of respective recording with lengths of its files in order to ensure - # that each entry in X_train.hkl corresponds with an entry of source_list/ sources_train.hkl - # source_list += [category + '-' + folder] * len(files) - - #print( 'Creating ' + split + ' data: ' + str(len(im_list)) + ' images') - #X = np.zeros((len(im_list),) + desired_im_sz + (3,), np.uint8) - # enumerate allows us to loop over something and have an automatic counter - #for i, im_file in enumerate(im_list): - # im = imread(im_file) - # X[i] = process_im(im, desired_im_sz) - - #hkl.dump(X, os.path.join(DATA_DIR, 'X_' + split + '.hkl')) - #hkl.dump(source_list, os.path.join(DATA_DIR, 'sources_' + split + '.hkl')) + with open(temporal_file,"wb") as ftemp: + pickle.dump(temporal_info,ftemp) + #pickle.dump(temporal_info, open( temporal_file, "wb" ) ) def process_netCDF_in_dir(src_dir,**kwargs): target_dir = kwargs.get("target_dir") @@ -111,6 +88,8 @@ def process_netCDF_in_dir(src_dir,**kwargs): directory_to_process = os.path.join(src_dir, job_name) os.chdir(directory_to_process) if not os.path.exists(target_dir): os.mkdir(target_dir) 
+ #target_file = os.path.join(target_dir, 'X_' + str(job_name) + '.hkl') + # ML 2020/07/15: Make use of pickle-files only target_file = os.path.join(target_dir, 'X_' + str(job_name) + '.hkl') if os.path.exists(target_file): print(target_file," file exists in the directory ", target_dir) @@ -119,67 +98,6 @@ def process_netCDF_in_dir(src_dir,**kwargs): process_data(directory_to_process=directory_to_process, **kwargs) -def split_data(target_dir, partition= [0.6, 0.2, 0.2]): - split_dir = target_dir + "/splits" - if not os.path.exists(split_dir): os.mkdir(split_dir) - os.chdir(target_dir) - files = glob.glob("*.hkl") - filesList = sorted(files) - #Bing: 20200415 - temporal_files = glob.glob("*.pkl") - temporal_filesList = sorted(temporal_files) - - # determine correct indicesue - train_begin = 0 - train_end = round(partition[0] * len(filesList)) - 1 - val_begin = train_end + 1 - val_end = train_end + round(partition[1] * len(filesList)) - test_begin = val_end + 1 - - - # slightly adapting start and end because starts at the first index given and stops before(!) the last. - train_files = filesList[train_begin:val_begin] - val_files = filesList[val_begin:test_begin] - test_files = filesList[test_begin:] - #bing: 20200415 - train_temporal_files = temporal_filesList[train_begin:val_begin] - val_temporal_files = temporal_filesList[val_begin:test_begin] - test_temporal_files = temporal_filesList[test_begin:] - - - splits = {s: [] for s in ['train', 'test', 'val']} - splits['val'] = val_files - splits['test'] = test_files - splits['train'] = train_files - - - splits_temporal = {s: [] for s in ['train', 'test', 'val']} - splits_temporal["train"] = train_temporal_files - splits_temporal["val"] = val_temporal_files - splits_temporal["test"] = test_temporal_files - - for split in splits: - X = [] - X_temporal = [] - files = splits[split] - temporal_files = splits_temporal[split] - for file, temporal_file in zip(files, temporal_files): - data_file = os.path.join(target_dir,file) - temporal_file = os.path.join(target_dir,temporal_file) - #load data with hkl file - data = hkl.load(data_file) - temporal_data = pickle.load(open(temporal_file,"rb")) - X_temporal = X_temporal + list(temporal_data) - X = X + list(data) - X = np.array(X) - X_temporal = np.array(X_temporal) - print ("X_temporal",X_temporal) - #save training, val and test data into splits directoyr - hkl.dump(X, os.path.join(split_dir, 'X_' + split + '.hkl')) - hkl.dump(files, os.path.join(split_dir,'sources_' + split + '.hkl')) - pickle.dump(X_temporal,open(os.path.join(split_dir,"T_"+split + ".pkl"),"wb")) - print ("PICKLE FILE FOR SPLITS SAVED") - # ML 2020/05/15 S def get_unique_vars(varnames): vars_uni, varsind = np.unique(varnames,return_index = True) @@ -300,7 +218,7 @@ class Calc_data_stat: print("Statistics file '"+file_name+"' has already been processed. Thus, just pass here...") pass - def finalize_stat_master(self,path_out,vars_uni): + def finalize_stat_master(self,vars_uni): """ Performs final compuattion of statistics after accumulation from slave nodes. 
""" @@ -310,7 +228,6 @@ class Calc_data_stat: if len(vars_uni) > len(set(vars_uni)): raise ValueError("Input variable names are not unique.") - js_file = os.path.join(path_out,"statistics.json") nvars = len(vars_uni) n_jsfiles = len(self.nfiles) nfiles_all= np.sum(self.nfiles) @@ -417,53 +334,62 @@ class Calc_data_stat: # ML 2020/05/15 E -def split_data_multiple_years(target_dir,partition,varnames): - """ - Collect all the X_*.hkl data across years and split them to training, val and testing datatset - """ - #target_dirs = [os.path.join(target_dir,year) for year in years] - #os.chdir(target_dir) - splits_dir = os.path.join(target_dir,"splits") - os.makedirs(splits_dir, exist_ok=True) - splits = {s: [] for s in list(partition.keys())} - # ML 2020/05/19 S - vars_uni, varsind, nvars = get_unique_vars(varnames) - stat_obj = Calc_data_stat(nvars) +# ML 2020/08/03 Not used anymore! +#def split_data_multiple_years(target_dir,partition,varnames): + #""" + #Collect all the X_*.hkl data across years and split them to training, val and testing datatset + #""" + ##target_dirs = [os.path.join(target_dir,year) for year in years] + ##os.chdir(target_dir) + #splits_dir = os.path.join(target_dir,"splits") + #os.makedirs(splits_dir, exist_ok=True) + #splits = {s: [] for s in list(partition.keys())} + ## ML 2020/05/19 S + #vars_uni, varsind, nvars = get_unique_vars(varnames) + #stat_obj = Calc_data_stat(nvars) - for split in partition.keys(): - values = partition[split] - files = [] - X = [] - Temporal_X = [] - for year in values.keys(): - file_dir = os.path.join(target_dir,year) - for month in values[year]: - month = "{0:0=2d}".format(month) - hickle_file = "X_{}.hkl".format(month) - #20200408:bing - temporal_file = "T_{}.pkl".format(month) - data_file = os.path.join(file_dir,hickle_file) - temporal_data_file = os.path.join(file_dir,temporal_file) - files.append(data_file) - data = hkl.load(data_file) - temporal_data = pickle.load(open(temporal_data_file,"rb")) - X = X + list(data) - Temporal_X = Temporal_X + list(temporal_data) - # process stat-file: - stat_obj.acc_stat_master(file_dir,int(month)) - X = np.array(X) - Temporal_X = np.array(Temporal_X) - print("==================={}=====================".format(split)) - print ("Sources for {} dataset are {}".format(split,files)) - print("Number of images in {} dataset is {} ".format(split,len(X))) - print ("dataset shape is {}".format(np.array(X).shape)) - hkl.dump(X, os.path.join(splits_dir , 'X_' + split + '.hkl')) - pickle.dump(Temporal_X, open(os.path.join(splits_dir,"T_"+split + ".pkl"),"wb")) - hkl.dump(files, os.path.join(splits_dir,'sources_' + split + '.hkl')) + #for split in partition.keys(): + #values = partition[split] + #files = [] + #X = [] + #Temporal_X = [] + #for year in values.keys(): + #file_dir = os.path.join(target_dir,year) + #for month in values[year]: + #month = "{0:0=2d}".format(month) + #hickle_file = "X_{}.hkl".format(month) + ##20200408:bing + #temporal_file = "T_{}.pkl".format(month) + ##data_file = os.path.join(file_dir,hickle_file) + #data_file = os.path.join(file_dir,hickle_file) + #temporal_data_file = os.path.join(file_dir,temporal_file) + #files.append(data_file) + #data = hkl.load(data_file) + #with open(temporal_data_file,"rb") as ftemp: + #temporal_data = pickle.load(ftemp) + #X = X + list(data) + #Temporal_X = Temporal_X + list(temporal_data) + ## process stat-file: + #stat_obj.acc_stat_master(file_dir,int(month)) + #X = np.array(X) + #Temporal_X = np.array(Temporal_X) + 
#print("==================={}=====================".format(split)) + #print ("Sources for {} dataset are {}".format(split,files)) + #print("Number of images in {} dataset is {} ".format(split,len(X))) + #print ("dataset shape is {}".format(np.array(X).shape)) + ## ML 2020/07/15: Make use of pickle-files only + #with open(os.path.join(splits_dir , 'X_' + split + '.pkl'),"wb") as data_file: + #pickle.dump(X,data_file,protocol=4) + ##hkl.dump(X, os.path.join(splits_dir , 'X_' + split + '.hkl')) + + #with open(os.path.join(splits_dir,"T_"+split + ".pkl"),"wb") as temp_file: + #pickle.dump(Temporal_X, temp_file) + + #hkl.dump(files, os.path.join(splits_dir,'sources_' + split + '.hkl')) - # write final statistics json-file - stat_obj.finalize_stat_master(target_dir,vars_uni) - stat_obj.write_stat_json(splits_dir) + ## write final statistics json-file + #stat_obj.finalize_stat_master(target_dir,vars_uni) + #stat_obj.write_stat_json(splits_dir)