Skip to content
Snippets Groups Projects
Commit a03e812b authored by b.gong's avatar b.gong
Browse files

Integrated vanilla VAE to workflow

parent ea4037bf
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
source_dir=/home/$USER/extractedData source_dir=/home/$USER/extractedData
destination_dir=/home/$USER/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle destination_dir=/home/$USER/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/hickle
declare -a years=("2017") declare -a years=("2017")
for year in "${years[@]}"; for year in "${years[@]}";
...@@ -11,7 +11,7 @@ for year in "${years[@]}"; ...@@ -11,7 +11,7 @@ for year in "${years[@]}";
echo "source_dir ${source_dir}/${year}" echo "source_dir ${source_dir}/${year}"
mpirun -np 2 python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \ mpirun -np 2 python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_stager_v2_process_netCDF.py \
--source_dir ${source_dir}/${year}/ \ --source_dir ${source_dir}/${year}/ \
--destination_dir ${destination_dir}/${year}/ --vars T2 MSL gph500 --lat_s 74 --lat_e 202 --lon_s 550 --lon_e 710 --destination_dir ${destination_dir}/${year}/ --vars T2 MSL gph500 --lat_s 138 --lat_e 202 --lon_s 646 --lon_e 710
done done
python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir} python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir}
......
#!/bin/bash -x #!/bin/bash -x
python ../video_prediction/datasets/era5_dataset_v2.py /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle/splits/ /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords/ -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20 python ../video_prediction/datasets/era5_dataset_v2.py /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/hickle/splits/ /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/ -vars T2 MSL gph500 -height 64 -width 64 -seq_length 20
temporal_dir: /home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/hickle/splits/
loading options from checkpoint /home/b.gong/models/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/vae
----------------------------------- Options ------------------------------------
batch_size = 2
checkpoint = /home/b.gong/models/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/vae
dataset = era5
dataset_hparams = sequence_length=20
fps = 4
gif_length = None
gpu_mem_frac = 0
input_dir = /home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords
mode = test
model = vae
model_hparams = None
num_epochs = 1
num_samples = None
num_stochastic_samples = 1
output_gif_dir = /home/b.gong/results/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/vae
output_png_dir = /home/b.gong/results/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/vae
results_dir = /home/b.gong/results/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500
results_gif_dir = /home/b.gong/results/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500
results_png_dir = /home/b.gong/results/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500
seed = 7
------------------------------------- End --------------------------------------
datset_class ERA5Dataset_v2
FILENAMES ['/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_0_to_1.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_10_to_11.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_12_to_13.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_2_to_3.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_4_to_5.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_6_to_7.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_8_to_9.tfrecords']
files ['/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_0_to_1.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_10_to_11.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_12_to_13.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_2_to_3.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_4_to_5.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_6_to_7.tfrecords', '/home/b.gong/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords/test/sequence_8_to_9.tfrecords']
mode test
Parse features {'images/encoded': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f71880c6b00>, 'channels': <tf.Tensor 'ParseSingleExample/ParseSingleExample:3' shape=() dtype=int64>, 'height': <tf.Tensor 'ParseSingleExample/ParseSingleExample:4' shape=() dtype=int64>, 'sequence_length': <tf.Tensor 'ParseSingleExample/ParseSingleExample:5' shape=() dtype=int64>, 'width': <tf.Tensor 'ParseSingleExample/ParseSingleExample:6' shape=() dtype=int64>}
Image shape 20, 64,64,3
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
python -u ../scripts/generate_transfer_learning_finetune.py \ python -u ../scripts/generate_transfer_learning_finetune.py \
--input_dir /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords \ --input_dir /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords \
--dataset_hparams sequence_length=20 --checkpoint /home/${USER}/models/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/ours_savp \ --dataset_hparams sequence_length=20 --checkpoint /home/${USER}/models/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/vae \
--mode test --results_dir /home/${USER}/results/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500 \ --mode test --results_dir /home/${USER}/results/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500 \
--batch_size 2 --dataset era5 > generate_era5-out.out --batch_size 2 --dataset era5 > generate_era5-out.out
#srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp #srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
#!/bin/bash -x
# Launch training of the `vae` model on the 64x64 ERA5 tfrecords and write the
# model to the matching models/.../vae directory.
# NOTE(review): the hyper-parameter file is taken from the ours_savp directory
# although the model is vae -- confirm this is intended.
python ../scripts/train_dummy.py \
    --input_dir /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/tfrecords \
    --dataset era5 \
    --model vae \
    --model_hparams_dict ../hparams/kth/ours_savp/model_hparams.json \
    --output_dir /home/${USER}/models/era5-Y2015toY2017M01to12-64x64-74d00N71d00E-T_MSL_gph500/vae
#srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
...@@ -56,7 +56,7 @@ echo "=============================================================" ...@@ -56,7 +56,7 @@ echo "============================================================="
# --input_dir ${DATA_EXTRA_DIR} --destination_dir ${DATA_PREPROCESS_DIR} # --input_dir ${DATA_EXTRA_DIR} --destination_dir ${DATA_PREPROCESS_DIR}
#fi #fi
#Change the .hkl data to .tfrecords files ####Change the .hkl data to .tfrecords files
if [ -d "$DATA_PREPROCESS_TF_DIR" ] if [ -d "$DATA_PREPROCESS_TF_DIR" ]
then then
echo "Step2: The Preprocessed Data (tf.records) exist" echo "Step2: The Preprocessed Data (tf.records) exist"
......
#!/usr/bin/env bash
# Driver for the ERA5 video-prediction workflow: tfrecords conversion,
# (optional) training, and result generation.
#
# Positional arguments:
#   $1  MODEL       model name (savp, gan or vae)
#   $2  TRAIN_MODE  end_to_end or pre_trained
#   $3  EXP_NAME    experiment name; used as sub-directory for data, logs and results
#
# Example values that used to be hard-coded here:
#   MODEL=savp  TRAIN_MODE=end_to_end  EXP_NAME=era5_size_64_64_3_3t_norm
set -e

MODEL=$1
TRAIN_MODE=$2
EXP_NAME=$3

# RETRAIN=1      -> reuse the existing end-to-end model and skip training.
# any other value -> run the training step.
# NOTE(review): the variable name suggests the opposite meaning; confirm intent.
RETRAIN=1

# Root directory of the extracted / preprocessed data.
# The earlier assignment to /home/${USER}/ was dropped: it was overwritten
# by this line before ever being read.
DATA_ETL_DIR=/p/scratch/deepacf/${USER}/

DATA_EXTRA_DIR=${DATA_ETL_DIR}/extractedData/${EXP_NAME}
DATA_PREPROCESS_DIR=${DATA_ETL_DIR}/preprocessedData/${EXP_NAME}
DATA_PREPROCESS_TF_DIR=./data/${EXP_NAME}
RESULTS_OUTPUT_DIR=./results_test_samples/${EXP_NAME}/${TRAIN_MODE}/
# Map the model name to the sub-directory that holds its hyper-parameters
# and checkpoints.
# A `case` statement replaces the former `[ $MODEL==savp ]` tests: without
# spaces around `==` the test receives a single non-empty word and is always
# true, so every model was silently mapped to ours_savp.
case "${MODEL}" in
    savp) method_dir=ours_savp ;;
    gan)  method_dir=ours_gan ;;
    vae)  method_dir=ours_vae ;;
    *)
        # Send the error to stderr so it is not lost when stdout is redirected.
        echo "model does not exist" >&2
        exit 1
        ;;
esac
# Choose the training output directory: the kth pre-trained model directory
# for TRAIN_MODE=pre_trained, otherwise a per-experiment log directory.
TRAIN_OUTPUT_DIR=./logs/${EXP_NAME}/${TRAIN_MODE}
if [[ "${TRAIN_MODE}" == "pre_trained" ]]; then
    TRAIN_OUTPUT_DIR=./pretrained_models/kth/${method_dir}
fi

# The checkpoint is looked up in a method-named sub-directory of the output directory.
CHECKPOINT_DIR=${TRAIN_OUTPUT_DIR}/${method_dir}

# Print the resolved configuration, one entry per line.
printf '%s\n' \
    "===========================WORKFLOW SETUP====================" \
    "Model ${MODEL}" \
    "TRAIN MODE ${TRAIN_MODE}" \
    "Method_dir ${method_dir}" \
    "DATA_ETL_DIR ${DATA_ETL_DIR}" \
    "DATA_EXTRA_DIR ${DATA_EXTRA_DIR}" \
    "DATA_PREPROCESS_DIR ${DATA_PREPROCESS_DIR}" \
    "DATA_PREPROCESS_TF_DIR ${DATA_PREPROCESS_TF_DIR}" \
    "TRAIN_OUTPUT_DIR ${TRAIN_OUTPUT_DIR}" \
    "============================================================="
############## Data Preprocessing ################
# Step 1 (currently disabled): convert the extracted data to .hkl files.
#if [ -d "$DATA_PREPROCESS_DIR" ]; then
#    echo "The Preprocessed Data (.hkl ) exist"
#else
#    python ../workflow_video_prediction/DataPreprocess/benchmark/mpi_stager_v2_process_netCDF.py \
#        --input_dir ${DATA_EXTRA_DIR} --destination_dir ${DATA_PREPROCESS_DIR}
#fi

# Step 2: convert the .hkl splits to .tfrecords files.
# The conversion is skipped whenever the output directory already exists;
# its contents are not inspected.
if [ -d "${DATA_PREPROCESS_TF_DIR}" ]; then
    echo "Step2: The Preprocessed Data (tf.records) exist"
else
    echo "Step2: start, hkl. files to tf.records"
    # Expansions are quoted so a path containing spaces stays one argument.
    # NOTE(review): confirm era5_dataset_v2.py accepts --source_dir /
    # --destination_dir; verify against the script's argument parser.
    python ./video_prediction/datasets/era5_dataset_v2.py \
        --source_dir "${DATA_PREPROCESS_DIR}/splits" \
        --destination_dir "${DATA_PREPROCESS_TF_DIR}"
    echo "Step2: finish"
fi
######### Train ##########################
# pre_trained : nothing to train, the kth pre-trained model is used.
# end_to_end  : train on the tfrecords unless RETRAIN is 1, in which case the
#               existing end-to-end model is reused.
# anything else is rejected.
if [ "${TRAIN_MODE}" == "pre_trained" ]; then
    echo "step3: Using kth pre_trained model"
elif [ "${TRAIN_MODE}" == "end_to_end" ]; then
    echo "step3: End-to-end training"
    if [ "${RETRAIN}" == 1 ]; then
        echo "Using the existing end-to-end model"
    else
        echo "Training Starts "
        # Expansions are quoted so paths containing spaces stay single arguments.
        python ./scripts/train_v2.py --input_dir "${DATA_PREPROCESS_TF_DIR}" --dataset era5 \
            --model "${MODEL}" --model_hparams_dict "hparams/kth/${method_dir}/model_hparams.json" \
            --output_dir "${TRAIN_OUTPUT_DIR}" --checkpoint "${CHECKPOINT_DIR}"
        echo "Training ends "
    fi
else
    # Invalid mode: say what is expected and what was received, on stderr.
    echo "TRAIN_MODE must be end_to_end or pre_trained (got '${TRAIN_MODE}')" >&2
    exit 1
fi
######### Generate results #################
# Step 4: run the generation script in test mode on the tfrecords, using the
# selected checkpoint, and write the output below RESULTS_OUTPUT_DIR.
echo "Step4: Postprocessing start"
# Expansions are quoted so paths containing spaces stay single arguments.
python ./scripts/generate_transfer_learning_finetune.py --input_dir "${DATA_PREPROCESS_TF_DIR}" \
    --dataset_hparams sequence_length=20 --checkpoint "${CHECKPOINT_DIR}" --mode test --results_dir "${RESULTS_OUTPUT_DIR}" \
    --batch_size 4 --dataset era5
...@@ -22,7 +22,7 @@ pip3 install mpi4py ...@@ -22,7 +22,7 @@ pip3 install mpi4py
pip3 install netCDF4 pip3 install netCDF4
pip3 install numpy pip3 install numpy
pip3 install h5py pip3 install h5py
pip3 install tensorflow==1.13.1 pip3 install tensorflow-gpu==1.14.0
#Copy the hickle package from bing's account #Copy the hickle package from bing's account
#cp -r /p/project/deepacf/deeprain/bing/hickle ${WORKING_DIR} #cp -r /p/project/deepacf/deeprain/bing/hickle ${WORKING_DIR}
......
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
"vae_gan_feature_cdist_weight": 10.0, "vae_gan_feature_cdist_weight": 10.0,
"gan_feature_cdist_weight": 0.0, "gan_feature_cdist_weight": 0.0,
"state_weight": 0.0, "state_weight": 0.0,
"nz": 32 "nz": 32,
"max_steps":20
} }
...@@ -7,7 +7,7 @@ from .savp_model import SAVPVideoPredictionModel ...@@ -7,7 +7,7 @@ from .savp_model import SAVPVideoPredictionModel
from .dna_model import DNAVideoPredictionModel from .dna_model import DNAVideoPredictionModel
from .sna_model import SNAVideoPredictionModel from .sna_model import SNAVideoPredictionModel
from .sv2p_model import SV2PVideoPredictionModel from .sv2p_model import SV2PVideoPredictionModel
from .vanilla_vae_model import VanillaVAEVideoPredictionModel
def get_model_class(model): def get_model_class(model):
model_mappings = { model_mappings = {
...@@ -17,6 +17,7 @@ def get_model_class(model): ...@@ -17,6 +17,7 @@ def get_model_class(model):
'dna': 'DNAVideoPredictionModel', 'dna': 'DNAVideoPredictionModel',
'sna': 'SNAVideoPredictionModel', 'sna': 'SNAVideoPredictionModel',
'sv2p': 'SV2PVideoPredictionModel', 'sv2p': 'SV2PVideoPredictionModel',
'vae': 'VanillaVAEVideoPredictionModel',
} }
model_class = model_mappings.get(model, model) model_class = model_mappings.get(model, model)
model_class = globals().get(model_class) model_class = globals().get(model_class)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment