diff --git a/.gitignore b/.gitignore
index cc57ba756a2d218d4c1120e71b449e02e95eff18..880544c7da6d759eb4ce6c5d5370ed38a3355f51 100644
--- a/.gitignore
+++ b/.gitignore
@@ -124,10 +124,12 @@ virt_env*/
 # Ignore (Batch) runscripts
 video_prediction_tools/HPC_scripts/**
 !video_prediction_tools/HPC_scripts/*_template.sh
+!video_prediction_tools/HPC_scripts/config_train.py
 video_prediction_tools/Zam347_scripts/**
 !video_prediction_tools/Zam347_scripts/*_template.sh
 # Ignore datasplit config files
 video_prediction_tools/data_split/**
 !video_prediction_tools/data_split/datasplit_template.json
 !video_prediction_tools/data_split/cv_test.json
+!video_prediction_tools/Zam347_scripts/config_train.py
diff --git a/test/run_pytest.sh b/test/run_pytest.sh
index f5ff0312acf63ee7e7d728fd7c466e058f4b1e73..2f9fc13eb68a9e02a171834c11bd0efecac5fb7c 100644
--- a/test/run_pytest.sh
+++ b/test/run_pytest.sh
@@ -27,4 +27,5 @@ source ../video_prediction_tools/env_setup/modules_train.sh
 #First remove all the files in the test folder
 #rm /p/project/deepacf/deeprain/video_prediction_shared_folder/models/test/*
 #python -m pytest test_train_model_era5.py
-python -m pytest test_visualize_postprocess.py
+#python -m pytest test_visualize_postprocess.py
+python -m pytest test_meta_postprocess.py
diff --git a/test/test_meta_postprocess.py b/test/test_meta_postprocess.py
index 4bca02bd5494bc15d8e1d3ff453d2f6237700c08..1892b0c3c22d723eeba5e08cce238a5177e3860b 100644
--- a/test/test_meta_postprocess.py
+++ b/test/test_meta_postprocess.py
@@ -10,9 +10,13 @@ import pytest
 #Params
 analysis_config = "/p/home/jusers/gong1/juwels/ambs/video_prediction_tools/analysis_config/analysis_test.json"
 analysis_dir = "/p/home/jusers/gong1/juwels/video_prediction_shared_folder/analysis/bing_test1"
+
 test_nc_fl = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-64x64-3930N0000E-T2_MSL_gph500/convLSTM/20201221T181605_gong1_sunny/vfp_date_2017030118_sample_ind_13.nc"
 test_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-64x64-3930N0000E-T2_MSL_gph500/convLSTM/20201221T181605_gong1_sunny"
+
+
+
 #setup instance
 @pytest.fixture(scope="module")
 def analysis_inst():
diff --git a/video_prediction_tools/HPC_scripts/config_train.py b/video_prediction_tools/HPC_scripts/config_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..de4751668bbe28536ba0817b962e5e3d71ba5c20
--- /dev/null
+++ b/video_prediction_tools/HPC_scripts/config_train.py
@@ -0,0 +1,257 @@
+"""
+Basic task of the Python-script:
+
+Creates user-defined runscripts for training, sets up a user-defined target directory and allows for full control
+over the setting of hyperparameters.
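+
+Note: the script runs interactively. It asks for the dataset, the directory of the preprocessed data, the model and
+an experiment ID, lets the user edit the default hyperparameter file and finally generates the runscripts for the
+training and postprocessing steps.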
+""" + +__email__ = "b.gong@fz-juelich.de" +__authors__ = "Michael Langguth" +__date__ = "2020-11-19" + +# import modules +import sys, os, glob +import subprocess +import datetime as dt +import json as js +from os import path +if sys.version_info[0] < 3: + raise Exception("This script has to be run with Python 3!") +sys.path.append(os.path.dirname(sys.path[0])) +from model_modules.model_architectures import known_models +from data_preprocess.dataset_options import known_datasets + +# some auxiliary functions + +# robust check if script is running in virtual env from +# https://stackoverflow.com/questions/1871549/determine-if-python-is-running-inside-virtualenv/38939054 +def get_base_prefix_compat(): + """Get base/real prefix, or sys.prefix if there is none.""" + return getattr(sys, "base_prefix", None) or getattr(sys, "real_prefix", None) or sys.prefix +# +#-------------------------------------------------------------------------------------------------------- +# +def in_virtualenv(): + return get_base_prefix_compat() != sys.prefix +# +#-------------------------------------------------------------------------------------------------------- +# +def check_virtualenv(labort=False): + ''' + Checks if current script is running a virtual environment and returns the directory's name + :param labort: If True, the an Exception is raised. If False, only a Warning is given + :return: name of virtual environment + ''' + lvirt = in_virtualenv() + + if not lvirt: + if labort: + raise EnvironmentError("config_train.py has to run in an activated virtual environment!") + else: + raise Warning("config_train.py is not running in an activated virtual environment!") + return + else: + return os.path.basename(sys.prefix) +# +# -------------------------------------------------------------------------------------------------------- +# +def get_variable_from_runscript(runscript_file,script_variable): + ''' + Search for the declaration of variable in a Shell script and returns its value. + :param runscript_file: path to shell script/runscript + :param script_variable: name of variable which is declared in shell script at hand + :return: value of script_variable + ''' + script_variable = script_variable + "=" + found = False + + with open(runscript_file) as runscript: + # Skips text before the beginning of the interesting block: + for line in runscript: + if script_variable in line: + var_value = (line.strip(script_variable)).replace("\n", "") + found = True + break + + if not found: + raise Exception("Could not find declaration of '"+script_variable+"' in '"+runscript_file+"'.") + + return var_value +# +#-------------------------------------------------------------------------------------------------------- +# +def path_rec_split(full_path): + """ + :param full_path: input path to be splitted in its components + :return: list of all splitted components + """ + rest, tail = os.path.split(full_path) + if rest in ('', os.path.sep): return tail, + + return path_rec_split(rest) + (tail,) +# +#-------------------------------------------------------------------------------------------------------- +# +def keyboard_interaction(console_str,check_input,err,ntries=1): + """ + Function to check if the user has passed a proper input via keyboard interaction + :param console_str: Request printed to the console + :param check_input: function returning boolean which needs to be passed by input from keyboard interaction. + Must have two arguments with the latter being an optional bool called silent. 
+    :param ntries: maximum number of tries (default: 1)
+    :return: The approved input from keyboard interaction
+    """
+    # sanity checks
+    if not callable(check_input):
+        raise ValueError("check_input must be a function!")
+    else:
+        try:
+            if not type(check_input("xxx", silent=True)) is bool:
+                raise TypeError("check_input argument does not return a boolean.")
+            else:
+                pass
+        except:
+            raise Exception("Cannot approve check_input-argument to be proper.")
+    if not isinstance(err, BaseException):
+        raise ValueError("err-argument must be an instance of BaseException!")
+    if not isinstance(ntries, int) or ntries < 1:
+        raise ValueError("ntries-argument must be an integer greater than or equal to 1!")
+
+    attempt = 0
+    while attempt < ntries:
+        input_req = input(console_str)
+        if check_input(input_req):
+            break
+        else:
+            attempt += 1
+            if attempt < ntries:
+                print(err)
+                console_str = "Retry!\n"
+            else:
+                raise err
+
+    return input_req
+
+
+def main():
+
+    list_models = known_models().keys()
+    list_datasets = known_datasets().keys()
+
+    # sanity check (is Python running in a virtual environment)
+    venv_name = check_virtualenv(labort=True)
+
+    ## get required information from the user by keyboard interaction
+
+    # dataset used for training
+    def check_dataset(dataset_name, silent=False):
+        # NOTE: Generic template for training still has to be integrated!
+        # After this is done, the latter part of the if-clause can be removed
+        # and further adaptions for the target_dir and for retrieving base_dir (see below) are required
+        if not dataset_name in list_datasets or dataset_name != "era5":
+            if not silent:
+                print("The following datasets can be used for training:")
+                for dataset_avail in list_datasets: print("* " + dataset_avail)
+            return False
+        else:
+            return True
+
+    dataset_req_str = "Enter the name of the dataset for training:\n"
+    dataset_err = ValueError("Please select a dataset from the ones listed above.")
+
+    dataset = keyboard_interaction(dataset_req_str, check_dataset, dataset_err, ntries=2)
+
+    # path to preprocessed data
+    def check_expdir(exp_dir, silent=False):
+        status = False
+        if os.path.isdir(exp_dir):
+            file_list = glob.glob(os.path.join(exp_dir, "sequence*.tfrecords"))
+            if len(file_list) > 0:
+                status = True
+            else:
+                print("{0} does not contain any tfrecord-files.".format(exp_dir))
+        else:
+            if not silent: print("Passed directory does not exist!")
+        return status
+
+    expdir_req_str = "Enter the path to the preprocessed data (directory where tf-records files are located):\n"
+    expdir_err = FileNotFoundError("Could not find any tfrecords.")
+
+    exp_dir_full = keyboard_interaction(expdir_req_str, check_expdir, expdir_err, ntries=3)
+
+    # split up directory path
+    exp_dir_split = path_rec_split(exp_dir_full)
+    index = [idx for idx, s in enumerate(exp_dir_split) if dataset in s]
+    if index == []:
+        raise ValueError("tfrecords found under '{0}', but directory does not seem to reflect naming convention.".format(exp_dir_full))
+    exp_dir = exp_dir_split[index[0]]
+
+    # model
+    def check_model(model_name, silent=False):
+        if not model_name in list_models:
+            if not silent:
+                print("{0} is not a valid model!".format(model_name))
+                print("The following models are implemented in the workflow:")
+                for model_avail in list_models: print("* " + model_avail)
+            return False
+        else:
+            return True
+
+    model_req_str = "Enter the name of the model you want to train:\n"
+    model_err = ValueError("Please select a model from the ones listed above.")
+
+    model = keyboard_interaction(model_req_str, check_model, model_err,
+                                 ntries=2)
+
+    # experimental ID
+    # No need to call keyboard_interaction here, because the user can pass whatever they want
+    exp_id = input("Enter your desired experimental id (will be extended by timestamp and username):\n")
+
+    # also get current timestamp and user-name...
+    timestamp = dt.datetime.now().strftime("%Y%m%dT%H%M%S")
+    user_name = os.environ["USER"]
+    # ... to construct final target_dir and exp_dir_ext as well
+    exp_id = timestamp + "_" + user_name + "_" + exp_id  # by convention, exp_id is extended by timestamp and username
+    base_dir = get_variable_from_runscript('train_model_era5_template.sh', 'destination_dir')
+    exp_dir_ext = os.path.join(exp_dir, model, exp_id)
+    target_dir = os.path.join(base_dir, exp_dir, model, exp_id)
+
+    # sanity check (target_dir is unique):
+    if os.path.isdir(target_dir):
+        raise IsADirectoryError(target_dir + " already exists! Make sure that it is unique.")
+
+    # create destination directory...
+    os.makedirs(target_dir)
+    source_hparams = os.path.join("..", "hparams", dataset, model, "model_hparams.json")
+    # sanity check (default hyperparameter json-file exists)
+    if not os.path.isfile(source_hparams):
+        raise FileNotFoundError("Could not find default hyperparameter json-file '"+source_hparams+"'")
+    # ...copy over json-file for hyperparameters...
+    os.system("cp "+source_hparams+" "+target_dir)
+    # ...and open vim
+    cmd_vim = os.environ.get('EDITOR', 'vi') + ' ' + os.path.join(target_dir, "model_hparams.json")
+    subprocess.call(cmd_vim, shell=True)
+
+    # finally, create runscript for training...
+    cmd = "cd ../env_setup; ./generate_workflow_runscripts.sh ../HPC_scripts/train_model_era5 " + venv_name + \
+          " -exp_id="+exp_id+" -exp_dir="+exp_dir+" -exp_dir_ext="+exp_dir_ext+" -model="+model+" ; cd -"
+    os.system(cmd)
+    # ...and postprocessing as well
+    cmd = cmd.replace("train_model_era5", "visualize_postprocess_era5")
+    os.system(cmd)
+
+if __name__ == '__main__':
+    main()
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step2_template.sh b/video_prediction_tools/HPC_scripts/preprocess_data_era5_step2_template.sh
index 8539f7c0534a5e3e0ac5e4733635e569d495284a..7b72f3ebe8b7624f19f2ea35c7e0876e4528b83f 100644
--- a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step2_template.sh
+++ b/video_prediction_tools/HPC_scripts/preprocess_data_era5_step2_template.sh
@@ -39,6 +39,7 @@ datasplit_dir=../data_split/cv_test.json
 model=convLSTM
 hparams_dict_config=../hparams/era5/${model}/model_hparams.json
 sequences_per_file=10
+
 # run Preprocessing (step 2 where Tf-records are generated)
-srun python ../main_scripts/main_preprocess_data_step2.py -input_dir ${source_dir} -output_dir ${destination_dir}/tfrecords -datasplit_config ${datasplit_dir} -hparams_dict_config ${hparams_dict_config} -sequences_per_file ${sequences_per_file}
+srun python ../main_scripts/main_preprocess_data_step2.py -input_dir ${source_dir} -output_dir ${destination_dir} -datasplit_config ${datasplit_dir} -hparams_dict_config ${hparams_dict_config} -sequences_per_file ${sequences_per_file}
diff --git a/video_prediction_tools/HPC_scripts/train_model_era5_template.sh b/video_prediction_tools/HPC_scripts/train_model_era5_template.sh
index 0cf2aad686ebe41204968e89432940edf4d0167e..d351dceb7e7c369efa224423b951bbe25441fc27 100644
--- a/video_prediction_tools/HPC_scripts/train_model_era5_template.sh
+++ b/video_prediction_tools/HPC_scripts/train_model_era5_template.sh
@@ -41,10 +41,12 @@
 destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/model
 # valid identifiers for model-argument are: convLSTM, savp, mcnet and vae
 # the destination_dir_full cannot end up with "/", this will cause to save all the checkpoints issue in the results_dir
 model=convLSTM
-model_hparams=../hparams/era5/${model}/model_hparams.json
-destination_dir_full=${destination_dir}/${model}/"$(date +"%Y%m%dT%H%M")_"$USER""
+datasplit_dict=../data_split/cv_test.json
+model_hparams=${destination_dir}/model_hparams.json
 
 # run training
-srun python ../main_scripts/main_train_models.py --input_dir ${source_dir}/tfrecords/ --dataset era5 --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir_full}/
+srun python ../main_scripts/main_train_models.py --input_dir ${source_dir} --datasplit_dict ${datasplit_dict} --dataset era5 --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}
+
+
diff --git a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh b/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh
index b2531f5644891a3144134f8fc4632d926b943c92..ef78f05c0a233b5c9fca423c6d9dcd97aaf42c0c 100644
--- a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh
+++ b/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh
@@ -41,9 +41,10 @@ results_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/results/
 # name of model
 model=convLSTM
-exp=[specify experiment name]
+
 # run postprocessing/generation of model results including evaluation metrics
 srun python -u ../main_scripts/main_visualize_postprocess.py \
---input_dir ${source_dir} --dataset_hparams sequence_length=20 --checkpoint ${checkpoint_dir}/${model}/${exp} \
---mode test --model ${model} --results_dir ${results_dir}/${model}/${exp}/ --batch_size 2 --dataset era5 > generate_era5-out.out
-
+--input_dir ${source_dir} --checkpoint ${checkpoint_dir} \
+--mode test --results_dir ${results_dir} \
+--batch_size 2 --num_samples 20 --num_stochastic_samples 2 \
+ > generate_era5-out.out
diff --git a/video_prediction_tools/analysis_config/analysis_test.json b/video_prediction_tools/analysis_config/analysis_test.json
new file mode 100644
index 0000000000000000000000000000000000000000..613467fcdabf74a56a2236194cdf9d0c5dae6ff8
--- /dev/null
+++ b/video_prediction_tools/analysis_config/analysis_test.json
@@ -0,0 +1,11 @@
+
+{
+
+"results_dir": ["/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1",
+                "/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1"],
+
+"metric": ["mse"],
+
+"compare_by": ["model"]
+
+}
diff --git a/video_prediction_tools/data_preprocess/dataset_options.py b/video_prediction_tools/data_preprocess/dataset_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..28dffb6c8879bd934c6a8f7169ee0a6bcf679999
--- /dev/null
+++ b/video_prediction_tools/data_preprocess/dataset_options.py
@@ -0,0 +1,19 @@
+def known_datasets():
+    """
+    An auxiliary function
+    :return: dictionary of known datasets
+    """
+    dataset_mappings = {
+        'google_robot': 'GoogleRobotVideoDataset',
+        'sv2p': 'SV2PVideoDataset',
+        'softmotion': 'SoftmotionVideoDataset',
+        'bair': 'SoftmotionVideoDataset',  # alias of softmotion
+        'kth': 'KTHVideoDataset',
+        'ucf101': 'UCF101VideoDataset',
+        'cartgripper': 'CartgripperVideoDataset',
+        "era5": "ERA5Dataset",
+        "moving_mnist": "MovingMnist"
+#        "era5_anomaly": "ERA5Dataset_v2_anomaly",
+    }
+
+    return dataset_mappings
\ No newline at end of file
diff --git a/video_prediction_tools/data_preprocess/preprocess_data_step2.py b/video_prediction_tools/data_preprocess/preprocess_data_step2.py
index 085a47513f13bc141f7cbcaea6d0fe00fbd601f9..1d7e589193a52e480db151ad859c007794a691e6 100644
--- a/video_prediction_tools/data_preprocess/preprocess_data_step2.py
+++ b/video_prediction_tools/data_preprocess/preprocess_data_step2.py
@@ -14,7 +14,7 @@ import json
 import tensorflow as tf
 from normalization import Norm_data
 import datetime
-from video_prediction.datasets import ERA5Dataset
+from model_modules.video_prediction.datasets import ERA5Dataset
 
 class ERA5Pkl2Tfrecords(ERA5Dataset):
@@ -236,6 +236,10 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         ERA5Pkl2Tfrecords.save_tf_record(output_fname, list(sequences), t_start_points)
 
     def write_sequence_file(self):
+        """
+        Generate a txt file with the number of sequences in each tfrecords file.
+        This is mainly used for calculating the number of samples per epoch during training.
+        """
         with open(os.path.join(self.output_dir, 'sequence_lengths.txt'), 'w') as seq_file:
             seq_file.write("%d\n" % self.sequences_per_file)
diff --git a/video_prediction_tools/docs/discussion/20201112_AMBS_report_to_Martin.pptx b/video_prediction_tools/docs/discussion/20201112_AMBS_report_to_Martin.pptx
new file mode 100644
index 0000000000000000000000000000000000000000..844d31e7a358a970a2567b6826b7151a52a6971c
Binary files /dev/null and b/video_prediction_tools/docs/discussion/20201112_AMBS_report_to_Martin.pptx differ
diff --git a/video_prediction_tools/docs/discussion/discussion.md b/video_prediction_tools/docs/discussion/discussion.md
index ff1d00c4c064c72c13029e84fbd0c18c3e4ff59b..6134f377fb840d9ae7f22d5b45f1153b403a81d2 100644
--- a/video_prediction_tools/docs/discussion/discussion.md
+++ b/video_prediction_tools/docs/discussion/discussion.md
@@ -1,5 +1,24 @@
 This is the list of last-mins files for VP group
-## 2020-03-01 - 2020-03-31
+## 2020-03-01 - 2020-04-15 AMBS internal meeting
+
+- https://docs.google.com/document/d/1cQUEWrenIlW1zebZwSSHpfka2Bhb8u63kPM3x7nya_o/edit#heading=h.yjmq51s4fxnm
+
+## 2020-08-31 - 2020-11-04 AMBS internal meeting
+
+- https://docs.google.com/document/d/1mHKey_lcy6-UluVm-nrpOBoNgochOxWnnwZ4XaJ-d2c/edit?usp=sharing
+
+
+## 2020-11-12 AMBS update with Martin
+
+- https://docs.google.com/document/d/1rc-hImd_A0rdOTSem461vZCY_8-GZY1zvwCvl5J8BQ8/edit?usp=sharing
+- Presentation: https://gitlab.version.fz-juelich.de/toar/ambs/-/blob/bing_%2337_organize_last_mins_meeting/video_prediction_tools/docs/discussion/20201112_AMBS_report_to_Martin.pptx
+
+
+## 2020-09-11 - 2021-01-20 JUWELS Booster Early Access Program
+- Instruction: How to submit jobs in container on Booster.
+  https://docs.google.com/document/d/1t2cmjTDbNtzEYBQSfeJn11T5w-wLjgMaJm9vwlgtCMA/edit?usp=sharing
+- EA program Profile (German): ea-application-DeepACF-de_stadtler+michael.docx
+- EA program Profile (English): ea-application-DeepACF_Bing.docx
+- convLSTM training visu animation: https://fz-juelich.sciebo.de/s/2cSpnnEzPlqZufL
+
-- https://docs.google.com/document/d/1cQUEWrenIlW1zebZwSSHpfka2Bhb8u63kPM3x7nya_o/edit#heading=h.yjmq51s4fxnm
\ No newline at end of file
diff --git a/video_prediction_tools/docs/discussion/ea-application-DeepACF-de_stadtler+michael.docx b/video_prediction_tools/docs/discussion/ea-application-DeepACF-de_stadtler+michael.docx
new file mode 100644
index 0000000000000000000000000000000000000000..81ed732038ce0d9a0a2308b7e7aaa68fe84aa1e4
Binary files /dev/null and b/video_prediction_tools/docs/discussion/ea-application-DeepACF-de_stadtler+michael.docx differ
diff --git a/video_prediction_tools/docs/discussion/ea-application-DeepACF_Bing.DOCX b/video_prediction_tools/docs/discussion/ea-application-DeepACF_Bing.DOCX
new file mode 100644
index 0000000000000000000000000000000000000000..d14a0368f10f6c8aabf7db61c0492d6abd23147e
Binary files /dev/null and b/video_prediction_tools/docs/discussion/ea-application-DeepACF_Bing.DOCX differ
diff --git a/video_prediction_tools/env_setup/create_env.sh b/video_prediction_tools/env_setup/create_env.sh
index 319b8cfc5e471cce4708b8792d878e1a0606d46a..b0105a4a206f0e6e1380121f22214a8e204ff44e 100755
--- a/video_prediction_tools/env_setup/create_env.sh
+++ b/video_prediction_tools/env_setup/create_env.sh
@@ -31,8 +31,9 @@ else
 fi
 
 # list of (Batch) scripts used for the steps in the workflow
-# !!! Expects that a template named [script_name]_template.sh exists!!!
-workflow_scripts=(data_extraction_era5 preprocess_data_era5_step1 preprocess_data_era5_step2 train_model_era5 visualize_postprocess_era5 preprocess_data_moving_mnist train_model_moving_mnist visualize_postprocess_moving_mnist)
+# !!! Expects that a template named [script_name]_template.sh exists !!!
+# !!! Runscripts for training and postprocessing shall be created with config_train.py !!!
+workflow_scripts=(data_extraction_era5 preprocess_data_era5_step1 preprocess_data_era5_step2 preprocess_data_moving_mnist)
 
 HOST_NAME=`hostname`
 ENV_NAME=$1
@@ -104,6 +105,7 @@ if [[ "$ENV_EXIST" == 0 ]]; then
     export PYTHONPATH=${WORKING_DIR}/utils:$PYTHONPATH >> ${activate_virt_env}
     #export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:$PYTHONPATH
     export PYTHONPATH=${WORKING_DIR}/external_package/lpips-tensorflow:$PYTHONPATH >> ${activate_virt_env}
+    export PYTHONPATH=${WORKING_DIR}/model_modules:$PYTHONPATH >> ${activate_virt_env}
 
     if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
       export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:$PYTHONPATH >> ${activate_virt_env}
@@ -113,6 +115,7 @@ if [[ "$ENV_EXIST" == 0 ]]; then
     echo "# Expand PYTHONPATH..." >> ${activate_virt_env}
     echo "export PYTHONPATH=${WORKING_DIR}:\$PYTHONPATH" >> ${activate_virt_env}
     echo "export PYTHONPATH=${WORKING_DIR}/utils/:\$PYTHONPATH" >> ${activate_virt_env}
+    echo "export PYTHONPATH=${WORKING_DIR}/model_modules:\$PYTHONPATH" >> ${activate_virt_env}
     #export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:\$PYTHONPATH
     echo "export PYTHONPATH=${WORKING_DIR}/external_package/lpips-tensorflow:\$PYTHONPATH" >> ${activate_virt_env}
 
@@ -126,13 +129,12 @@ fi
 # Finish by creating runscripts
 # After checking and setting up the virt env, create user-specific runscripts for all steps of the workflow
 if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
-  echo "***** Creating Batch-scripts for running workflow... *****"
   script_dir=../HPC_scripts
 elif [[ "${HOST_NAME}" == "zam347" ]]; then
-  echo "***** Creating Batch-scripts for running workflow... *****"
  script_dir=../Zam347_scripts
 fi
 
+echo "***** Creating Batch-scripts for data extraction and preprocessing substeps... *****"
 for wf_script in "${workflow_scripts[@]}"; do
   curr_script=${script_dir}/${wf_script}
   if [[ -z "${exp_id}" ]]; then
@@ -141,4 +143,6 @@ for wf_script in "${workflow_scripts[@]}"; do
     ./generate_workflow_runscripts.sh ${curr_script} ${ENV_NAME} -exp_id=${exp_id}
   fi
 done
+echo "******************************************** NOTE ********************************************"
+echo "Runscripts for training and postprocessing can be generated with ../HPC_scripts/config_train.py"
diff --git a/video_prediction_tools/env_setup/generate_workflow_runscripts.sh b/video_prediction_tools/env_setup/generate_workflow_runscripts.sh
index 34aebf26fe5021575a93839866f80f9f69a54d8d..b8ce0b4a0f6ec7dfbbe3b91d6c1de6f1fe62e859 100755
--- a/video_prediction_tools/env_setup/generate_workflow_runscripts.sh
+++ b/video_prediction_tools/env_setup/generate_workflow_runscripts.sh
@@ -1,28 +1,34 @@
 #!/usr/bin/env bash
 #
 # __authors__ = Michael Langguth
-# __date__  = '2020_09_29'
+# __date__  = '2020_10_28'
 #
 # **************** Description ****************
 # Converts a given template workflow script (path/name has to be passed as first argument) to
-# an executable workflow (Batch) script.
+# an executable workflow (Batch) script. However, use 'config_train.py' for convenience when runscripts for the
+# training and postprocessing substeps should be generated.
 # Note, that the first argument has to be passed with "_template.sh" omitted!
 # The second argument denotes the name of the virtual environment to be used.
-# Additionally, -exp_id=[some_id] and -exp_dir=[some_dir] can be optionally passed as NON-POSITIONAL arguments.
-# -exp_id allows to set an experimental identifier explicitly (default is -exp_id=exp1) while
-# -exp_dir allows setting manually the experimental directory.
-# Note, that the latter is done during the preprocessing step in an end-to-end workflow.
-# However, if the preprocessing step can be skipped (i.e. preprocessed data already exists),
-# one may wish to set the experimental directory explicitly
+# Additionally, the following optional arguments can be passed as NON-POSITIONAL arguments:
+# -exp_id     : set an experimental identifier explicitly (default is -exp_id=exp1)
+# -exp_dir    : set manually the basic experimental directory
+# -exp_dir_ext: set manually the extended basic experimental directory (has to be passed in conjunction with
+#               -exp_dir!) following the naming convention for storing the trained models and their postprocessed data
+# -model      : set manually the model to be trained/postprocessed
+# Note, that -exp_dir is useful if the first preprocessing step can be skipped (i.e. preprocessed netCDF files already
+# exist).
+# The optional arguments -exp_dir_ext and -model are additionally used by config_train.py to create the runscripts for
+# the training and postprocessing step.
 #
 # Examples:
-#    ./generate_workflow_scripts.sh ../HPC_scripts/generate_era5 venv_hdfml -exp_id=exp5
-#    ... will convert generate_era5_template.sh to generate_era5_exp5.sh where
+#    ./generate_workflow_scripts.sh ../HPC_scripts/preprocess_data_era5_step2 venv_hdfml -exp_id=exp5
+#    ... will convert preprocess_data_era5_step2_template.sh to preprocess_data_era5_step2_exp5.sh where
 #    venv_hdfml is the virtual environment for operation.
+#    Note, that source_dir and destination_dir are not properly set in that case!
 #
-#    ./generate_workflow_scripts.sh ../HPC_scripts/generate_era5 venv_hdfml -exp_id=exp5 -exp_dir=testdata
-#    ... does the same as the previous example, but additionally extends source_dir=[...]/preprocessedData/,
-#    checkpoint_dir=[...]/models/ and results_dir=[...]/results/ by testdata/
+#    ./generate_workflow_scripts.sh ../HPC_scripts/preprocess_data_era5_step2 venv_hdfml -exp_id=exp5 -exp_dir=testdata
+#    ... does the same as the previous example, but additionally extends source_dir=[...]/preprocessedData/ and
+#    destination_dir=[...]/models/ properly
 # **************** Description ****************
 #
 # **************** Auxilary functions ****************
@@ -34,16 +40,21 @@ check_argin() {
                exp_id=${argin#"-exp_id="}
             elif [[ $argin == *"-exp_dir="* ]]; then
                exp_dir=${argin#"-exp_dir="}
+            elif [[ $argin == *"-exp_dir_ext"* ]]; then
+               exp_dir_ext=${argin#"-exp_dir_ext="}
+            elif [[ $argin == *"-model"* ]]; then
+               model=${argin#"-model="}
             fi
         done
 }
 
-add_exp_dir() {
-# Add exp_dir to paths in <target_script> which end with /<prefix>/
+extend_path() {
+# Add <extension> to paths in <target_script> which end with /<prefix>/
   prefix=$1
+  extension=$2
  if [[ `grep "/${prefix}/$" ${target_script}` ]]; then
-    echo "Add experimental directory after '${prefix}/' in runscript '${target_script}'"
-    sed -i "s|/${prefix}/$|/${prefix}/${exp_dir}/|g" ${target_script}
+    echo "Perform extension on path '${prefix}/' in runscript '${target_script}'"
+    sed -i "s|/${prefix}/$|/${prefix}/${extension}/|g" ${target_script}
     status=1
  fi
 }
@@ -68,10 +79,13 @@ if [[ "$#" -lt 2 ]]; then
 else
   curr_script=$1
   curr_script_loc="$(basename "$curr_script")"
-  curr_venv=$2
+  curr_venv=$2 #
  # check if any known non-positional argument is present...
  if [[ "$#" -gt 2 ]]; then
    check_argin ${@:3}
+    if [[ ! -z "${exp_dir_ext}" ]] && [[ -z "${exp_dir}" ]]; then
+      echo "WARNING: -exp_dir_ext is passed without passing -exp_dir and thus has no effect!"
+    fi
  fi
  #...and ensure that exp_id is always set
  if [[ -z "${exp_id}" ]]; then
@@ -142,7 +156,7 @@ if [[ `grep "exp_id=" ${target_script}` ]]; then
 fi
 
 # set correct e-mail address in Batch scripts on Juwels and HDF-ML
-if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
+if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == *juwels* ]]; then
  if ! [[ -z `command -v jutil` ]]; then
    USER_EMAIL=$(jutil user show -o json | grep email | cut -f2 -d':' | cut -f1 -d',' | cut -f2 -d'"')
  else
@@ -151,13 +165,26 @@ if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == juwels* ]]; then
  sed -i "s/--mail-user=.*/--mail-user=$USER_EMAIL/g" ${target_script}
 fi
 
+# set model if model was passed as optional argument
+if [[ ! -z "${model}" ]]; then
+  sed -i "s/model=.*/model=${model}/g" ${target_script}
+fi
+
 # finally set experimental directory if exp_dir is present
 if [[ ! -z "${exp_dir}" ]]; then
+  if [[ ! -z "${exp_dir_ext}" ]]; then
+    status=0                 # status to check if exp_dir_ext is added to the runscript at hand
+                             # -> will be set to one by extend_path if modification takes place
+    extend_path models ${exp_dir_ext}/
+    extend_path results ${exp_dir_ext}
+
+    if [[ ${status} == 0 ]]; then
+      echo "WARNING: -exp_dir_ext has been passed, but no addition to any path in runscript at hand done..."
+    fi
+  fi
  status=0                   # status to check if exp_dir is added to the runscript at hand
                             # -> will be set to one by add_exp_dir if modifictaion takes place
-  add_exp_dir preprocessedData
-  add_exp_dir models
-  add_exp_dir results
+  extend_path preprocessedData ${exp_dir}
 
  if [[ ${status} == 0 ]]; then
    echo "WARNING: -exp_dir has been passed, but no addition to any path in runscript at hand done..."
diff --git a/video_prediction_tools/hparams/era5/savp/model_hparams.json b/video_prediction_tools/hparams/era5/savp/model_hparams.json
index 641ffb36f764f5ae720a534d7d9eef0ebad644d8..d7058c6b2534d46cd6e08672d33a76bfcc4c7a35 100644
--- a/video_prediction_tools/hparams/era5/savp/model_hparams.json
+++ b/video_prediction_tools/hparams/era5/savp/model_hparams.json
@@ -11,8 +11,10 @@
     "vae_gan_feature_cdist_weight": 10.0,
     "gan_feature_cdist_weight": 0.0,
     "state_weight": 0.0,
-    "nz": 32,
-    "max_epochs":2
+    "nz": 16,
+    "max_epochs":2,
+    "context_frames":10,
+    "sequence_length":20
 }
diff --git a/video_prediction_tools/main_scripts/main_meta_postprocess.py b/video_prediction_tools/main_scripts/main_meta_postprocess.py
index 044dffe0b1cc3ef94a844158456077f19b88f719..51f6e19aaa3c0525a6304b38251474d0793a0a8f 100644
--- a/video_prediction_tools/main_scripts/main_meta_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_meta_postprocess.py
@@ -78,7 +78,7 @@ class MetaPostprocess(object):
 
     @staticmethod
     def read_values_by_var_from_nc(fl_nc, var="T2", stochastic_ind=0):
-        #if not var in ["T2","MSL","GPH500"]: raise ValueError ("var name is not correct, should be 'T2','MSL',or 'GPH500'")
+        if not var in ["T2","MSL","GPH500"]: raise ValueError("var name is not correct, should be 'T2', 'MSL' or 'GPH500'")
         with Dataset(fl_nc, mode = 'r') as fl:
             #load var prediction, real and persistent values
             real = fl["/analysis/reference/"].variables[var][:]
@@ -133,24 +133,6 @@ class MetaPostprocess(object):
         evals_forecast = xr.DataArray(eval_forecast_all_dirs, coords=[self.results_dirs, samples, times],
                                       dims=["results_dirs", "samples","time_forecast"])
         return evals_forecast
-
-    def save_metrics_all_dir_to_json(self):
-        with open("metrics_results.json","w") as f:
-            json.dump(self.eval_all,f)
-
-
-    def load_results_dir_parameters(self,compare_by="model"):
-        self.compare_by_values = []
-        for results_dir in self.results_dirs:
-            with open(os.path.join(results_dir, "options_checkpoints.json")) as f:
-                self.options = json.loads(f.read())
-                print("self.options:",self.options)
-                #if self.compare_by == "model":
-                self.compare_by_values.append(self.options[compare_by])
-
-
-
-
 
     def plot_results(self,one_persistent=True):
         """
diff --git a/video_prediction_tools/main_scripts/main_preprocess_data_step2.py b/video_prediction_tools/main_scripts/main_preprocess_data_step2.py
index cba0bbed917d383d07ed1ee4d4102c22f02816bd..3695b90b30465b6698223bfcf9415149533be6d2 100644
--- a/video_prediction_tools/main_scripts/main_preprocess_data_step2.py
+++ b/video_prediction_tools/main_scripts/main_preprocess_data_step2.py
@@ -11,7 +11,7 @@ import argparse
 from mpi4py import MPI
 from general_utils import get_unique_vars
 from statistics import Calc_data_stat
-from video_prediction.datasets.era5_dataset import *
+from data_preprocess.preprocess_data_step2 import *
 
 def main():
diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index bbb8d43133c45667f570d54cd0e27ebfeb8a702c..7db0513049dc76526966118a88473100b9b8c637 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -18,10 +18,10 @@ import random
 import time
 import numpy as np
 import tensorflow as tf
-from video_prediction import datasets, models
+from model_modules.video_prediction import datasets, models
 import matplotlib.pyplot as plt
 import pickle as pkl
-from video_prediction.utils import tf_utils
+from model_modules.video_prediction.utils import tf_utils
 
 class TrainModel(object):
@@ -58,7 +58,6 @@ class TrainModel(object):
         self.save_interval = save_interval
 
     def setup(self):
-        self.generate_output_dir()
         self.set_seed()
         self.get_model_hparams_dict()
         self.load_params_from_checkpoints_dir()
@@ -271,14 +270,14 @@ class TrainModel(object):
         """
         Start session and train the model
         """
-        global_step = tf.train.get_or_create_global_step()
+        self.global_step = tf.train.get_or_create_global_step()
         with tf.Session(config=self.config) as sess:
             print("parameter_count =", sess.run(self.parameter_count))
             sess.run(tf.global_variables_initializer())
             sess.run(tf.local_variables_initializer())
             self.restore(sess, self.checkpoint)
             #sess.graph.finalize()
-            self.start_step = sess.run(global_step)
+            self.start_step = sess.run(self.global_step)
             print("start_step", self.start_step)
             # start at one step earlier to log everything without doing any training
             # step is relative to the start_step
@@ -322,8 +321,7 @@ class TrainModel(object):
         #This is the base fetch that for all the models
         self.fetches = {"train_op": self.video_model.train_op}  # fetching the optimizer!
self.fetches["summary"] = self.video_model.summary_op - self.fetches["global_step"] = self.video_model.global_step - self.fetches["total_loss"] = self.video_model.total_loss + self.fetches["global_step"] = self.global_step if self.video_model.__class__.__name__ == "McNetVideoPredictionModel": self.fetches_for_train_mcnet() if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": self.fetches_for_train_convLSTM() if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": self.fetches_for_train_savp() @@ -334,7 +332,9 @@ class TrainModel(object): """ Fetch variables in the graph for convLSTM model, this can be custermized based on models and the needs of users """ - pass + self.fetches["total_loss"] = self.video_model.total_loss + + def fetches_for_train_savp(self): @@ -345,6 +345,9 @@ class TrainModel(object): self.fetches["d_losses"] = self.video_model.d_losses self.fetches["d_loss"] = self.video_model.d_loss self.fetches["g_loss"] = self.video_model.g_loss + self.fetches["total_loss"] = self.video_model.g_loss + + def fetches_for_train_mcnet(self): """ @@ -360,13 +363,17 @@ class TrainModel(object): """ self.fetches["latent_loss"] = self.video_model.latent_loss self.fetches["recon_loss"] = self.video_model.recon_loss - + self.fetches["total_loss"] = self.video_model.total_loss def create_fetches_for_val(self): """ Fetch variables in the graph for validation dataset, this can be custermized based on models and the needs of users """ - self.val_fetches = {"total_loss": self.video_model.total_loss} + if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": + self.val_fetches = {"total_loss": self.video_model.g_loss} + else: + self.val_fetches = {"total_loss": self.video_model.total_loss} + self.val_fetches["summary"] = self.video_model.summary_op def write_to_summary(self): diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 1901fdeff7c63303e66e107a859f1b47b9e88331..c78eecf87fa704bf06c700825910914f63110094 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -25,7 +25,7 @@ from metadata import MetaData as MetaData from main_scripts.main_train_models import * from data_preprocess.preprocess_data_step2 import * import shutil -from video_prediction import datasets, models +from model_modules.video_prediction import datasets, models class Postprocess(TrainModel,ERA5Pkl2Tfrecords): @@ -266,8 +266,9 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict)#return [batchsize,seq_len,lat,lon,channel] assert gen_images.shape[1] == self.sequence_length - 1 #The generate images seq_len should be sequence_len -1, since the last one is not used for comparing with groud truth gen_images_per_batch = [] - persistent_images_per_batch = [] - ts_batch = [] + if stochastic_sample_ind == 0: + persistent_images_per_batch = [] + ts_batch = [] for i in range(self.batch_size): # generate time stamps for sequences only once, since they are the same for all ensemble members if stochastic_sample_ind == 0: @@ -275,25 +276,25 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): init_date_str = self.ts[0].strftime("%Y%m%d%H") ts_batch.append(init_date_str) # get persistence_images - self.persistence_images, self.ts_persistence = Postprocess.get_persistence(self.ts, - self.input_dir_pkl) + 
print("self.ts:",self.ts) + self.persistence_images, self.ts_persistence = Postprocess.get_persistence(self.ts,self.input_dir_pkl) persistent_images_per_batch.append(self.persistence_images) self.plot_persistence_images() # Denormalized data for generate gen_images_ = gen_images[i] - self.gen_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl, gen_images_, - self.vars_in) + self.gen_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl, gen_images_, self.vars_in) gen_images_per_batch.append(self.gen_images_denorm) # only plot when the first stochastic ind otherwise too many plots would be created # only plot the stochastic results of user-defined ind self.plot_generate_images(stochastic_sample_ind, self.stochastic_plot_id) - gen_images_stochastic.append(gen_images_per_batch) + gen_images_stochastic.append(gen_images_per_batch) gen_images_stochastic = Postprocess.check_gen_images_stochastic_shape(gen_images_stochastic) # save input and stochastic generate images to netcdf file # For each prediction (either deterministic or ensemble) we create one netCDF file. + print("persistent_images_per_batch",len(np.array(persistent_images_per_batch))) for batch_id in range(self.batch_size): print("batch_id is here",batch_id) self.save_to_netcdf_for_stochastic_generate_images(self.input_images_denorm_all[batch_id], persistent_images_per_batch[batch_id], @@ -512,17 +513,17 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): ################ forecast group ##################### for stochastic_sample_ind in range(self.num_stochastic_samples): #Temperature: - t2 = nc_file.createVariable("/forecast/T2/stochastic/{}".format(stochastic_sample_ind),"f4",("time_forecast","lat","lon"), zlib = True) + t2 = nc_file.createVariable("/forecasts/T2/stochastic/{}".format(stochastic_sample_ind),"f4",("time_forecast","lat","lon"), zlib = True) t2.units = 'K' t2[:,:,:] = gen_images_[stochastic_sample_ind,self.context_frames-1:,:,:,0] #mean sea level pressure - msl = nc_file.createVariable("/forecast/MSL/stochastic/{}".format(stochastic_sample_ind),"f4",("time_forecast","lat","lon"), zlib = True) + msl = nc_file.createVariable("/forecasts/MSL/stochastic/{}".format(stochastic_sample_ind),"f4",("time_forecast","lat","lon"), zlib = True) msl.units = 'Pa' msl[:,:,:] = gen_images_[stochastic_sample_ind,self.context_frames-1:,:,:,1] #Geopotential at 500 - gph500 = nc_file.createVariable("/forecast/GPH500/stochastic/{}".format(stochastic_sample_ind),"f4",("time_forecast","lat","lon"), zlib = True) + gph500 = nc_file.createVariable("/forecasts/GPH500/stochastic/{}".format(stochastic_sample_ind),"f4",("time_forecast","lat","lon"), zlib = True) gph500.units = 'm' gph500[:,:,:] = gen_images_[stochastic_sample_ind,self.context_frames-1:,:,:,2] @@ -537,7 +538,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): if len(np.array(imgs).shape)!=3:raise("img dims should be four: (seq_len,lat,lon)") if np.array(imgs).shape[0]!= len(ts): raise("The len of timestamps should be equal the image seq_len") fig = plt.figure(figsize=(18,6)) - gs = gridspec.GridSpec(1, 10) + gs = gridspec.GridSpec(1, len(ts)) gs.update(wspace = 0., hspace = 0.) 
xlables = [round(i,2) for i in list(np.linspace(np.min(lons),np.max(lons),5))] ylabels = [round(i,2) for i in list(np.linspace(np.max(lats),np.min(lats),5))] @@ -553,6 +554,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): plt.ylabel(label, fontsize=10) plt.savefig(os.path.join(output_png_dir, label + "_TS_" + str(ts[0]) + ".jpg")) plt.clf() + plt.close() output_fname = label + "_TS_" + ts[0].strftime("%Y%m%d%H") + ".jpg" print("image {} saved".format(output_fname)) @@ -575,28 +577,36 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): in ts_persistence """ ts_persistence = [] + year_origin = ts[0].year for t in range(len(ts)): # Scarlet: this certainly can be made nicer with list comprehension ts_temp = ts[t] - datetime.timedelta(days=1) ts_persistence.append(ts_temp) t_persistence_start = ts_persistence[0] t_persistence_end = ts_persistence[-1] - year_start = t_persistence_start.year + year_start = t_persistence_start.year #Bing to address the issue #43 and Scarelet please confirm this change month_start = t_persistence_start.month month_end = t_persistence_end.month - + print("start year:",year_start) # only one pickle file is needed (all hours during the same month) if month_start == month_end: # Open files to search for the indizes of the corresponding time - time_pickle = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T') + time_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T')) # Open file to search for the correspoding meteorological fields - var_pickle = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X') - # Retrieve starting index + var_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X')) + + if year_origin != year_start: + time_origin_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 12, 'T')) + var_origin_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 12, 'X')) + time_pickle.extend(time_origin_pickle) + var_pickle.extend(var_origin_pickle) + + # Retrieve starting index ind = list(time_pickle).index(np.array(ts_persistence[0])) #print('Scarlet, Original', ts_persistence) #print('From Pickle', time_pickle[ind:ind+len(ts_persistence)]) - var_persistence = var_pickle[ind:ind+len(ts_persistence)] - time_persistence = time_pickle[ind:ind+len(ts_persistence)].ravel() + var_persistence = np.array(var_pickle)[ind:ind+len(ts_persistence)] + time_persistence = np.array(time_pickle)[ind:ind+len(ts_persistence)].ravel() #print(' Scarlet Shape of time persistence',time_persistence.shape) #print(' Scarlet Shape of var persistence',var_persistence.shape) @@ -612,24 +622,34 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): t_persistence_first_m.append(ts_persistence[t]) if m == month_end: t_persistence_second_m.append(ts_persistence[t]) + if year_origin == year_start: + # Open files to search for the indizes of the corresponding time + time_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_start, month_start, 'T') + time_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_start, month_end, 'T') - # Open files to search for the indizes of the corresponding time - time_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_start, month_start, 'T') - time_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_start, month_end, 'T') - - # Open file to search for the correspoding 
meteorological fields - var_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_start, month_start, 'X') - var_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_start, month_end, 'X') - + # Open file to search for the correspoding meteorological fields + var_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_start, month_start, 'X') + var_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_start, month_end, 'X') + + if year_origin != year_start: + # Open files to search for the indizes of the corresponding time + time_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_origin, 1, 'T') + time_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_start, 12, 'T') + + # Open file to search for the correspoding meteorological fields + var_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_origin, 1, 'X') + var_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl,year_start, 12, 'X') + + #print('Scarlet, Original', ts_persistence) + #print('From Pickle', time_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)], time_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)]) + #print(' Scarlet before', time_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)].shape, time_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)].shape) + # Retrieve starting index ind_first_m = list(time_pickle_first).index(np.array(t_persistence_first_m[0])) + print ("time_pickle_second:",time_pickle_second) ind_second_m = list(time_pickle_second).index(np.array(t_persistence_second_m[0])) - #print('Scarlet, Original', ts_persistence) - #print('From Pickle', time_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)], time_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)]) - #print(' Scarlet before', time_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)].shape, time_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)].shape) - - # append the sequence of the second month to the first month + # append the sequence of the second month to the first month var_persistence = np.concatenate((var_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)], var_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)]), axis=0) @@ -637,8 +657,13 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): time_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)]), axis=0).ravel() # ravel is needed to eliminate the unnecessary dimension (20,1) becomes (20,) #print(' Scarlet concatenate and ravel (time)', var_persistence.shape, time_persistence.shape) - - + + + + + + if len(time_persistence.tolist()) == 0 : raise ("The time_persistent is empty!") + if len(var_persistence) ==0 : raise ("The var persistence is empty!") # tolist() is needed for plotting return var_persistence, time_persistence.tolist() diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py new file mode 100644 index 0000000000000000000000000000000000000000..79e91888ef6b46eea88e5ace3496daf56f436259 --- /dev/null +++ b/video_prediction_tools/model_modules/model_architectures.py @@ -0,0 +1,18 @@ +def known_models(): + """ + An auxilary function + :return: dictionary of known model architectures + """ + model_mappings = { + 'ground_truth': 'GroundTruthVideoPredictionModel', + 'repeat': 
'RepeatVideoPredictionModel', + 'savp': 'SAVPVideoPredictionModel', + 'dna': 'DNAVideoPredictionModel', + 'sna': 'SNAVideoPredictionModel', + 'sv2p': 'SV2PVideoPredictionModel', + 'vae': 'VanillaVAEVideoPredictionModel', + 'convLSTM': 'VanillaConvLstmVideoPredictionModel', + 'mcnet': 'McNetVideoPredictionModel', + } + + return model_mappings diff --git a/video_prediction_tools/video_prediction/.DS_Store b/video_prediction_tools/model_modules/video_prediction/.DS_Store similarity index 100% rename from video_prediction_tools/video_prediction/.DS_Store rename to video_prediction_tools/model_modules/video_prediction/.DS_Store diff --git a/video_prediction_tools/video_prediction/__init__.py b/video_prediction_tools/model_modules/video_prediction/__init__.py similarity index 100% rename from video_prediction_tools/video_prediction/__init__.py rename to video_prediction_tools/model_modules/video_prediction/__init__.py diff --git a/video_prediction_tools/video_prediction/datasets/__init__.py b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py similarity index 73% rename from video_prediction_tools/video_prediction/datasets/__init__.py rename to video_prediction_tools/model_modules/video_prediction/datasets/__init__.py index 2556d86d55f43ecaba93e3a22be484e9a2af36f7..7a70351e7808103e9a3e02e65654f151213c45ec 100644 --- a/video_prediction_tools/video_prediction/datasets/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py @@ -8,21 +8,11 @@ from .ucf101_dataset import UCF101VideoDataset from .cartgripper_dataset import CartgripperVideoDataset from .era5_dataset import ERA5Dataset from .moving_mnist import MovingMnist +from data_preprocess.dataset_options import known_datasets #from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly def get_dataset_class(dataset): - dataset_mappings = { - 'google_robot': 'GoogleRobotVideoDataset', - 'sv2p': 'SV2PVideoDataset', - 'softmotion': 'SoftmotionVideoDataset', - 'bair': 'SoftmotionVideoDataset', # alias of softmotion - 'kth': 'KTHVideoDataset', - 'ucf101': 'UCF101VideoDataset', - 'cartgripper': 'CartgripperVideoDataset', - "era5":"ERA5Dataset", - "moving_mnist":"MovingMnist" -# "era5_anomaly":"ERA5Dataset_v2_anomaly", - } + dataset_mappings = known_datasets() dataset_class = dataset_mappings.get(dataset, dataset) print("datset_class",dataset_class) if dataset_class is None: diff --git a/video_prediction_tools/video_prediction/datasets/base_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py similarity index 100% rename from video_prediction_tools/video_prediction/datasets/base_dataset.py rename to video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py diff --git a/video_prediction_tools/video_prediction/datasets/cartgripper_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/cartgripper_dataset.py similarity index 100% rename from video_prediction_tools/video_prediction/datasets/cartgripper_dataset.py rename to video_prediction_tools/model_modules/video_prediction/datasets/cartgripper_dataset.py diff --git a/video_prediction_tools/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py similarity index 99% rename from video_prediction_tools/video_prediction/datasets/era5_dataset.py rename to video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py index 
25034ea8dd8f671101b148cfe0e27f24da481d29..993a528d52fac33e5a6523a4f321ca511d0fd074 100644 --- a/video_prediction_tools/video_prediction/datasets/era5_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py @@ -113,7 +113,7 @@ class ERA5Dataset(object): if self.filenames: self.filenames = sorted(self.filenames) # ensures order is the same across systems if not self.filenames: - raise FileNotFoundError('No tfrecords were found in %s.' % self.input_dir_tfrecords) + raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir_tfrecords) def get_example_info(self): diff --git a/video_prediction_tools/video_prediction/datasets/google_robot_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/google_robot_dataset.py similarity index 100% rename from video_prediction_tools/video_prediction/datasets/google_robot_dataset.py rename to video_prediction_tools/model_modules/video_prediction/datasets/google_robot_dataset.py diff --git a/video_prediction_tools/video_prediction/datasets/kth_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py similarity index 98% rename from video_prediction_tools/video_prediction/datasets/kth_dataset.py rename to video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py index e1e11d51968e706868fd89f26faa25d1999d3a9b..d0187304caef9b255907354cdef1415579d9a86f 100644 --- a/video_prediction_tools/video_prediction/datasets/kth_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py @@ -8,8 +8,7 @@ import re import tensorflow as tf import numpy as np import skimage.io - -from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset +from model_modules.video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset class KTHVideoDataset(VarLenFeatureVideoDataset): diff --git a/video_prediction_tools/video_prediction/datasets/moving_mnist.py b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py similarity index 99% rename from video_prediction_tools/video_prediction/datasets/moving_mnist.py rename to video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py index 8d1beddcc9c1b54a2cd899326b881a4ba8f53874..2334cc2b7cf02dcd0a9b99c8cbd5f3cf2b2e2900 100644 --- a/video_prediction_tools/video_prediction/datasets/moving_mnist.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py @@ -14,7 +14,7 @@ from mpi4py import MPI from collections import OrderedDict import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec -from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset +from model_modules.video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset import data_preprocess.process_netCDF_v2 from general_utils import get_unique_vars from statistics import Calc_data_stat diff --git a/video_prediction_tools/video_prediction/datasets/softmotion_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/softmotion_dataset.py similarity index 98% rename from video_prediction_tools/video_prediction/datasets/softmotion_dataset.py rename to video_prediction_tools/model_modules/video_prediction/datasets/softmotion_dataset.py index 106869ca46d857d56fdd8d0f0d20ae5e2b69c3b5..bd248984cd7a5aa5957b65ec44fffaca38bdc851 100644 --- a/video_prediction_tools/video_prediction/datasets/softmotion_dataset.py +++ 
b/video_prediction_tools/model_modules/video_prediction/datasets/softmotion_dataset.py @@ -1,10 +1,8 @@ import itertools import os import re - import tensorflow as tf - -from video_prediction.utils import tf_utils +from model_modules.video_prediction.utils import tf_utils from .base_dataset import VideoDataset diff --git a/video_prediction_tools/video_prediction/datasets/sv2p_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/sv2p_dataset.py similarity index 100% rename from video_prediction_tools/video_prediction/datasets/sv2p_dataset.py rename to video_prediction_tools/model_modules/video_prediction/datasets/sv2p_dataset.py diff --git a/video_prediction_tools/video_prediction/datasets/ucf101_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/ucf101_dataset.py similarity index 99% rename from video_prediction_tools/video_prediction/datasets/ucf101_dataset.py rename to video_prediction_tools/model_modules/video_prediction/datasets/ucf101_dataset.py index cba078ab4f66acb3fd80546a016fb9dd94b1d551..4273728e257207318aa91a1c0b1673f14ff7f159 100644 --- a/video_prediction_tools/video_prediction/datasets/ucf101_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/ucf101_dataset.py @@ -7,8 +7,7 @@ import re from multiprocessing import Pool import cv2 import tensorflow as tf - -from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset +from model_modules.video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset class UCF101VideoDataset(VarLenFeatureVideoDataset): diff --git a/video_prediction_tools/video_prediction/flow_ops.py b/video_prediction_tools/model_modules/video_prediction/flow_ops.py similarity index 100% rename from video_prediction_tools/video_prediction/flow_ops.py rename to video_prediction_tools/model_modules/video_prediction/flow_ops.py diff --git a/video_prediction_tools/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py similarity index 100% rename from video_prediction_tools/video_prediction/layers/BasicConvLSTMCell.py rename to video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py diff --git a/video_prediction_tools/video_prediction/layers/__init__.py b/video_prediction_tools/model_modules/video_prediction/layers/__init__.py similarity index 100% rename from video_prediction_tools/video_prediction/layers/__init__.py rename to video_prediction_tools/model_modules/video_prediction/layers/__init__.py diff --git a/video_prediction_tools/video_prediction/layers/layer_def.py b/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py similarity index 100% rename from video_prediction_tools/video_prediction/layers/layer_def.py rename to video_prediction_tools/model_modules/video_prediction/layers/layer_def.py diff --git a/video_prediction_tools/video_prediction/layers/mcnet_ops.py b/video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py similarity index 98% rename from video_prediction_tools/video_prediction/layers/mcnet_ops.py rename to video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py index 656f66c0df1cf199fff319f7b81b01594f96332c..fbe0a3d6260366f07971d0198fc73062a73452de 100644 --- a/video_prediction_tools/video_prediction/layers/mcnet_ops.py +++ b/video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py @@ -3,7 +3,7 @@ import numpy as np import tensorflow as tf from tensorflow.python.framework 
diff --git a/video_prediction_tools/video_prediction/layers/mcnet_ops.py b/video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py
similarity index 98%
rename from video_prediction_tools/video_prediction/layers/mcnet_ops.py
rename to video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py
index 656f66c0df1cf199fff319f7b81b01594f96332c..fbe0a3d6260366f07971d0198fc73062a73452de 100644
--- a/video_prediction_tools/video_prediction/layers/mcnet_ops.py
+++ b/video_prediction_tools/model_modules/video_prediction/layers/mcnet_ops.py
@@ -3,7 +3,7 @@
 import numpy as np
 import tensorflow as tf
 from tensorflow.python.framework import ops
-from video_prediction.utils.mcnet_utils import *
+from model_modules.video_prediction.utils.mcnet_utils import *
 
 
 def batch_norm(inputs, name, train=True, reuse=False):
diff --git a/video_prediction_tools/video_prediction/layers/normalization.py b/video_prediction_tools/model_modules/video_prediction/layers/normalization.py
similarity index 100%
rename from video_prediction_tools/video_prediction/layers/normalization.py
rename to video_prediction_tools/model_modules/video_prediction/layers/normalization.py
diff --git a/video_prediction_tools/video_prediction/losses.py b/video_prediction_tools/model_modules/video_prediction/losses.py
similarity index 97%
rename from video_prediction_tools/video_prediction/losses.py
rename to video_prediction_tools/model_modules/video_prediction/losses.py
index 662da29dbadf091bc59fcd0e7ed62fd71bcf0f81..0f5c07d1ab521cd21ac4bc63ba208217e4b1f493 100644
--- a/video_prediction_tools/video_prediction/losses.py
+++ b/video_prediction_tools/model_modules/video_prediction/losses.py
@@ -1,6 +1,6 @@
 import tensorflow as tf
 
-from video_prediction.ops import sigmoid_kl_with_logits
+from model_modules.video_prediction.ops import sigmoid_kl_with_logits
 
 
 def l1_loss(pred, target):
diff --git a/video_prediction_tools/video_prediction/metrics.py b/video_prediction_tools/model_modules/video_prediction/metrics.py
similarity index 85%
rename from video_prediction_tools/video_prediction/metrics.py
rename to video_prediction_tools/model_modules/video_prediction/metrics.py
index 6b8a8d1381138c01d13fdf44a3134d1c973e7e06..6ba7464cca2a9b7724aa9be302bb6fb5240950fd 100644
--- a/video_prediction_tools/video_prediction/metrics.py
+++ b/video_prediction_tools/model_modules/video_prediction/metrics.py
@@ -1,7 +1,7 @@
 import tensorflow as tf
 #import lpips_tf
-
-
+import numpy as np
+import math
 
 def mse(a, b):
     return tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1])
@@ -21,7 +21,9 @@ def psnr_imgs(img1, img2):
     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
 
 
-
+def mse_imgs(image1,image2):
+    mse = ((image1 - image2)**2).mean(axis=None)
+    return mse
 # def lpips(input0, input1):
 #     if input0.shape[-1].value == 1:
 #         input0 = tf.tile(input0, [1] * (input0.shape.ndims - 1) + [3])
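The new mse_imgs complements the graph-mode metric mse above: it works on plain NumPy arrays and averages over all axes, whereas mse reduces only over the last three. A quick sanity check of the added function (hypothetical usage, mirroring the definition in the hunk):

import numpy as np

def mse_imgs(image1, image2):
    # numpy-based MSE over the full arrays, as added to metrics.py above
    return ((image1 - image2) ** 2).mean(axis=None)

a = np.zeros((64, 64))
b = np.full((64, 64), 2.0)
assert mse_imgs(a, b) == 4.0  # (0 - 2)^2 averaged over all pixels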
diff --git a/video_prediction_tools/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py
similarity index 64%
rename from video_prediction_tools/video_prediction/models/__init__.py
rename to video_prediction_tools/model_modules/video_prediction/models/__init__.py
index 2ccc20e10df5a0492ce502ebdb30dfa09ccbe1d5..960f608deed07e715190cdecb38efeb2eb4c5ace 100644
--- a/video_prediction_tools/video_prediction/models/__init__.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py
@@ -11,20 +11,10 @@ from .vanilla_vae_model import VanillaVAEVideoPredictionModel
 from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel
 from .mcnet_model import McNetVideoPredictionModel
 from .test_model import TestModelVideoPredictionModel
+from model_modules.model_architectures import known_models
 
 def get_model_class(model):
-    model_mappings = {
-        'ground_truth': 'GroundTruthVideoPredictionModel',
-        'repeat': 'RepeatVideoPredictionModel',
-        'savp': 'SAVPVideoPredictionModel',
-        'dna': 'DNAVideoPredictionModel',
-        'sna': 'SNAVideoPredictionModel',
-        'sv2p': 'SV2PVideoPredictionModel',
-        'vae': 'VanillaVAEVideoPredictionModel',
-        'convLSTM': 'VanillaConvLstmVideoPredictionModel',
-        'mcnet': 'McNetVideoPredictionModel',
-        'test_model': 'TestModelVideoPredictionModel'
-    }
+    model_mappings = known_models()
     model_class = model_mappings.get(model, model)
     model_class = globals().get(model_class)
     if model_class is None:
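This hunk deduplicates the model registry: the hard-coded dictionary is replaced by known_models() from model_modules.model_architectures, so runscript generation and model lookup share one source of truth. A sketch of what that helper presumably returns (assumed to be exactly the name-to-class-name mapping deleted above; not quoted from the repository):

def known_models():
    # Assumed single source of truth for available model architectures.
    model_mappings = {
        'ground_truth': 'GroundTruthVideoPredictionModel',
        'savp': 'SAVPVideoPredictionModel',
        'convLSTM': 'VanillaConvLstmVideoPredictionModel',
        'mcnet': 'McNetVideoPredictionModel',
        # ... remaining entries as in the dictionary removed above
    }
    return model_mappings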
diff --git a/video_prediction_tools/video_prediction/models/base_model.py b/video_prediction_tools/model_modules/video_prediction/models/base_model.py
similarity index 96%
rename from video_prediction_tools/video_prediction/models/base_model.py
rename to video_prediction_tools/model_modules/video_prediction/models/base_model.py
index 0d3bf6e4b554c70671d4678b530688c44f999b77..011dedbd7bf681541c77a85ca6708774654792b6 100644
--- a/video_prediction_tools/video_prediction/models/base_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/base_model.py
@@ -7,9 +7,9 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.contrib.training import HParams
 from tensorflow.python.util import nest
-import video_prediction as vp
-from video_prediction.utils import tf_utils
-from video_prediction.utils.tf_utils import compute_averaged_gradients, reduce_tensors, local_device_setter, \
+import model_modules.video_prediction as vp
+from model_modules.video_prediction.utils import tf_utils
+from model_modules.video_prediction.utils.tf_utils import compute_averaged_gradients, reduce_tensors, local_device_setter, \
     replace_read_ops, print_loss_info, transpose_batch_time, add_gif_summaries, add_scalar_summaries, \
     add_plot_and_scalar_summaries, add_summaries
 
@@ -371,7 +371,7 @@ class VideoPredictionModel(BaseVideoPredictionModel):
             beta2=0.999,
             context_frames=-1,
             sequence_length=-1,
-            clip_length=10, #Bing: TODO What is the clip_length, original is 10,
+            clip_length=10,
             l1_weight=0.0,
             l2_weight=1.0,
             vgg_cdist_weight=0.0,
@@ -474,20 +474,7 @@ class VideoPredictionModel(BaseVideoPredictionModel):
         # be captured here.
         original_global_variables = tf.global_variables()
-
-        # ########Bing: fine-tune#######
-        # variables_to_restore = tf.contrib.framework.get_variables_to_restore(
-        #     exclude = ["discriminator/video/sn_fc4/dense/bias"])
-        # init_fn = tf.contrib.framework.assign_from_checkpoint_fn(checkpoint)
-        # restore_variables = tf.contrib.framework.get_variables("discriminator/video/sn_fc4/dense/bias")
-        # restore_init = tf.variables_initializer(restore_variables)
-        # restore_optimizer = tf.train.GradientDescentOptimizer(
-        #     learning_rate = 0.001) # TODO: need to change the learning rate
-        # ###Bing: fine-tune#######
-        # skip_vars = {" discriminator_encoder/video_sn_fc4/dense/bias"}
-
         if self.num_gpus <= 1:  # cpu or 1 gpu
-            print("self.inputs:>20200822",self.inputs)
             outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(self.inputs)
             self.outputs, self.eval_outputs = outputs_tuple
             self.d_losses, self.g_losses, g_losses_post = losses_tuple
@@ -498,14 +485,7 @@ class VideoPredictionModel(BaseVideoPredictionModel):
             self.g_vars = tf.trainable_variables(self.generator_scope)
             g_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2)
             d_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2)
-
-            if finetune:
-                ##Bing: fine-tune
-                #self.g_vars = tf.contrib.framework.get_variables("discriminator/video/sn_fc4/dense/bias")#generator/encoder/layer_3/conv2d/kernel/Adam_1
-                self.g_vars = tf.contrib.framework.get_variables("discriminator/encoder/video/sn_conv3_0/conv3d/kernel")
-                self.g_vars_init = tf.variables_initializer(self.g_vars)
-                g_optimizer = tf.train.AdamOptimizer(0.00001)
-                print("############Bing: Fine Tune here##########")
+
 
         if self.mode == 'train' and (self.d_losses or self.g_losses):
             with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
@@ -525,11 +505,6 @@ class VideoPredictionModel(BaseVideoPredictionModel):
                     g_gradvars = g_optimizer.compute_gradients(g_loss_post, var_list=self.g_vars)
                     with tf.name_scope('g_apply_gradients'):
                         g_train_op = g_optimizer.apply_gradients(g_gradvars)
-                    # #######Bing; finetune########
-                    # with tf.name_scope("finetune_gradients"):
-                    #     finetune_grads_vars = finetune_vars_optimizer.compute_gradients(self.d_loss, var_list = self.finetune_vars)
-                    # with tf.name_scope("finetune_apply_gradients"):
-                    #     finetune_train_op = finetune_vars_optimizer.apply_gradients(finetune_grads_vars)
                 else:
                     g_train_op = tf.no_op()
                 with tf.control_dependencies([g_train_op]):
diff --git a/video_prediction_tools/video_prediction/models/dna_model.py b/video_prediction_tools/model_modules/video_prediction/models/dna_model.py
similarity index 99%
rename from video_prediction_tools/video_prediction/models/dna_model.py
rename to video_prediction_tools/model_modules/video_prediction/models/dna_model.py
index c4fa8b97bc523382adfa14f564aa30193920ed48..8badf600f62c21d71cd81d8c2bfcde2f75e91d34 100644
--- a/video_prediction_tools/video_prediction/models/dna_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/dna_model.py
@@ -21,8 +21,7 @@ import numpy as np
 import tensorflow as tf
 import tensorflow.contrib.slim as slim
 from tensorflow.contrib.layers.python import layers as tf_layers
-
-from video_prediction.models import VideoPredictionModel
+from model_modules.video_prediction.models import VideoPredictionModel
 from .sna_model import basic_conv_lstm_cell
 
 
diff --git a/video_prediction_tools/video_prediction/models/mcnet_model.py b/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py
similarity index 96%
rename from video_prediction_tools/video_prediction/models/mcnet_model.py
rename to video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py
index d9d7842d025de29adf07adc471edad497189b61d..a946bd555a603fd9be14306929e0a8e722a24673 100644
--- a/video_prediction_tools/video_prediction/models/mcnet_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py
@@ -10,18 +10,18 @@ from collections import OrderedDict
 import numpy as np
 import tensorflow as tf
 from tensorflow.python.util import nest
-from video_prediction import ops, flow_ops
-from video_prediction.models import BaseVideoPredictionModel
-from video_prediction.models import networks
-from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
-from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
-from video_prediction.utils import tf_utils
+from model_modules.video_prediction import ops, flow_ops
+from model_modules.video_prediction.models import BaseVideoPredictionModel
+from model_modules.video_prediction.models import networks
+from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from model_modules.video_prediction.utils import tf_utils
 from datetime import datetime
 from pathlib import Path
-from video_prediction.layers import layer_def as ld
-from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
-from video_prediction.layers.mcnet_ops import *
-from video_prediction.utils.mcnet_utils import *
+from model_modules.video_prediction.layers import layer_def as ld
+from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
+from model_modules.video_prediction.layers.mcnet_ops import *
+from model_modules.video_prediction.utils.mcnet_utils import *
 import os
 
 class McNetVideoPredictionModel(BaseVideoPredictionModel):
diff --git a/video_prediction_tools/video_prediction/models/networks.py b/video_prediction_tools/model_modules/video_prediction/models/networks.py
similarity index 92%
rename from video_prediction_tools/video_prediction/models/networks.py
rename to video_prediction_tools/model_modules/video_prediction/models/networks.py
index 844c28295cfea3597bf6a1ce52c9b0f3891370a9..db372ea396daed58e995fd6f7f947ef79670a26a 100644
--- a/video_prediction_tools/video_prediction/models/networks.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/networks.py
@@ -1,12 +1,12 @@
 import tensorflow as tf
 from tensorflow.python.util import nest
 
-from video_prediction import ops
-from video_prediction.ops import conv2d
-from video_prediction.ops import dense
-from video_prediction.ops import lrelu
-from video_prediction.ops import pool2d
-from video_prediction.utils import tf_utils
+from model_modules.video_prediction import ops
+from model_modules.video_prediction.ops import conv2d
+from model_modules.video_prediction.ops import dense
+from model_modules.video_prediction.ops import lrelu
+from model_modules.video_prediction.ops import pool2d
+from model_modules.video_prediction.utils import tf_utils
 
 
 def encoder(inputs, nef=64, n_layers=3, norm_layer='instance'):
diff --git a/video_prediction_tools/video_prediction/models/non_trainable_model.py b/video_prediction_tools/model_modules/video_prediction/models/non_trainable_model.py
similarity index 97%
rename from video_prediction_tools/video_prediction/models/non_trainable_model.py
rename to video_prediction_tools/model_modules/video_prediction/models/non_trainable_model.py
index a5082b5c0357d37262b3236185be89e146d92a84..cdab65c145cc99d7df261f685c65ed2588e989c0 100644
--- a/video_prediction_tools/video_prediction/models/non_trainable_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/non_trainable_model.py
@@ -1,6 +1,6 @@
 from collections import OrderedDict
 from tensorflow.python.util import nest
-from video_prediction.utils.tf_utils import transpose_batch_time
+from model_modules.video_prediction.utils.tf_utils import transpose_batch_time
 import tensorflow as tf
 
 
diff --git a/video_prediction_tools/video_prediction/models/savp_model.py b/video_prediction_tools/model_modules/video_prediction/models/savp_model.py
similarity index 99%
rename from video_prediction_tools/video_prediction/models/savp_model.py
rename to video_prediction_tools/model_modules/video_prediction/models/savp_model.py
index 039533864f34d1608c5a10d4a664d40ce73594a7..48c7900aaaa5859e932f4c289dac0ee7b91fe627 100644
--- a/video_prediction_tools/video_prediction/models/savp_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/savp_model.py
@@ -2,17 +2,15 @@ import collections
 import functools
 import itertools
 from collections import OrderedDict
-
 import numpy as np
 import tensorflow as tf
 from tensorflow.python.util import nest
-
-from video_prediction import ops, flow_ops
-from video_prediction.models import VideoPredictionModel
-from video_prediction.models import networks
-from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
-from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
-from video_prediction.utils import tf_utils
+from model_modules.video_prediction import ops, flow_ops
+from model_modules.video_prediction.models import VideoPredictionModel
+from model_modules.video_prediction.models import networks
+from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from model_modules.video_prediction.utils import tf_utils
 
 # Amount to use when lower bounding tensors
 RELU_SHIFT = 1e-12
diff --git a/video_prediction_tools/video_prediction/models/sna_model.py b/video_prediction_tools/model_modules/video_prediction/models/sna_model.py
similarity index 99%
rename from video_prediction_tools/video_prediction/models/sna_model.py
rename to video_prediction_tools/model_modules/video_prediction/models/sna_model.py
index ddb04deafc73f49d0466acd74ac4a43d94ac72f0..033f2de90a123f6cda6c2616e5115825182f5386 100644
--- a/video_prediction_tools/video_prediction/models/sna_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/sna_model.py
@@ -24,7 +24,7 @@ from tensorflow.contrib.layers.python import layers as tf_layers
 from tensorflow.contrib.slim import add_arg_scope
 from tensorflow.contrib.slim import layers
 
-from video_prediction.models import VideoPredictionModel
+from model_modules.video_prediction.models import VideoPredictionModel
 
 
 # Amount to use when lower bounding tensors
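The RELU_SHIFT = 1e-12 constant visible in the SAVP and SNA context lines is described there as the "amount to use when lower bounding tensors". A minimal sketch of that pattern (assumed illustration of the idiom, not code from the patch):

import tensorflow as tf

RELU_SHIFT = 1e-12  # amount to use when lower bounding tensors

def lower_bounded(x):
    # Shifting a ReLU output away from exactly zero keeps downstream
    # operations such as log() or division numerically stable.
    return tf.nn.relu(x - RELU_SHIFT) + RELU_SHIFT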
diff --git a/video_prediction_tools/video_prediction/models/sv2p_model.py b/video_prediction_tools/model_modules/video_prediction/models/sv2p_model.py
similarity index 99%
rename from video_prediction_tools/video_prediction/models/sv2p_model.py
rename to video_prediction_tools/model_modules/video_prediction/models/sv2p_model.py
index e7a06364178dc380456e47569787ab693ad121de..f0ddd99cecec43348ef00f87162d6dbf51ed95aa 100644
--- a/video_prediction_tools/video_prediction/models/sv2p_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/sv2p_model.py
@@ -22,8 +22,7 @@ import tensorflow.contrib.slim as slim
 from tensorflow.contrib.layers.python import layers as tf_layers
 from tensorflow.contrib.slim import add_arg_scope
 from tensorflow.contrib.slim import layers
-
-from video_prediction.models import VideoPredictionModel
+from model_modules.video_prediction.models import VideoPredictionModel
 
 
 # Amount to use when lower bounding tensors
diff --git a/video_prediction_tools/video_prediction/models/test_model.py b/video_prediction_tools/model_modules/video_prediction/models/test_model.py
similarity index 86%
rename from video_prediction_tools/video_prediction/models/test_model.py
rename to video_prediction_tools/model_modules/video_prediction/models/test_model.py
index 919293c970408686dbfec7e848dcca7f16e0ecf7..0f1770e84c2a181f8a753ddcbf046765a2bf1784 100644
--- a/video_prediction_tools/video_prediction/models/test_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/test_model.py
@@ -11,16 +11,16 @@ from collections import OrderedDict
 import numpy as np
 import tensorflow as tf
 from tensorflow.python.util import nest
-from video_prediction import ops, flow_ops
-from video_prediction.models import BaseVideoPredictionModel
-from video_prediction.models import networks
-from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
-from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
-from video_prediction.utils import tf_utils
+from model_modules.video_prediction import ops, flow_ops
+from model_modules.video_prediction.models import BaseVideoPredictionModel
+from model_modules.video_prediction.models import networks
+from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from model_modules.video_prediction.utils import tf_utils
 from datetime import datetime
 from pathlib import Path
-from video_prediction.layers import layer_def as ld
-from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
+from model_modules.video_prediction.layers import layer_def as ld
+from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
 from tensorflow.contrib.training import HParams
 
 class TestModelVideoPredictionModel(object):
diff --git a/video_prediction_tools/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
similarity index 91%
rename from video_prediction_tools/video_prediction/models/vanilla_convLSTM_model.py
rename to video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
index 00dfc6b1aa87c1b8b40291aa402b6512757a6954..ece024616c1f18058ad19cc88497637de15752ac 100644
--- a/video_prediction_tools/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
@@ -9,16 +9,16 @@ from collections import OrderedDict
 import numpy as np
 import tensorflow as tf
 from tensorflow.python.util import nest
-from video_prediction import ops, flow_ops
-from video_prediction.models import BaseVideoPredictionModel
-from video_prediction.models import networks
-from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
-from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
-from video_prediction.utils import tf_utils
+from model_modules.video_prediction import ops, flow_ops
+from model_modules.video_prediction.models import BaseVideoPredictionModel
+from model_modules.video_prediction.models import networks
+from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from model_modules.video_prediction.utils import tf_utils
 from datetime import datetime
 from pathlib import Path
-from video_prediction.layers import layer_def as ld
-from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
+from model_modules.video_prediction.layers import layer_def as ld
+from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
 from tensorflow.contrib.training import HParams
 
 class VanillaConvLstmVideoPredictionModel(object):
diff --git a/video_prediction_tools/video_prediction/models/vanilla_vae_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py
similarity index 92%
rename from video_prediction_tools/video_prediction/models/vanilla_vae_model.py
rename to video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py
index 8f0812acf89bc732c2f5302da90ac4b95f5045d9..98b8bc144dd84a017256cf3cd03406c01a5bd76d 100644
--- a/video_prediction_tools/video_prediction/models/vanilla_vae_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py
@@ -11,15 +11,15 @@ from collections import OrderedDict
 import numpy as np
 import tensorflow as tf
 from tensorflow.python.util import nest
-from video_prediction import ops, flow_ops
-from video_prediction.models import BaseVideoPredictionModel
-from video_prediction.models import networks
-from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
-from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
-from video_prediction.utils import tf_utils
+from model_modules.video_prediction import ops, flow_ops
+from model_modules.video_prediction.models import BaseVideoPredictionModel
+from model_modules.video_prediction.models import networks
+from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from model_modules.video_prediction.utils import tf_utils
 from datetime import datetime
 from pathlib import Path
-from video_prediction.layers import layer_def as ld
+from model_modules.video_prediction.layers import layer_def as ld
 
 class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
     def __init__(self, mode='train', aggregate_nccl=None,hparams_dict=None,
diff --git a/video_prediction_tools/video_prediction/ops.py b/video_prediction_tools/model_modules/video_prediction/ops.py
similarity index 99%
rename from video_prediction_tools/video_prediction/ops.py
rename to video_prediction_tools/model_modules/video_prediction/ops.py
index 1d2e9f2eac608d2a9ec61e24a354882b3acce2de..7e41af95c2941f610235c7fa5744d326d333418d 100644
--- a/video_prediction_tools/video_prediction/ops.py
+++ b/video_prediction_tools/model_modules/video_prediction/ops.py
@@ -1065,7 +1065,7 @@ def get_norm_layer(layer_type):
     elif layer_type == 'layer':
         layer = tf.contrib.layers.layer_norm
     elif layer_type == 'instance':
-        from video_prediction.layers import fused_instance_norm
+        from model_modules.video_prediction.layers import fused_instance_norm
         layer = fused_instance_norm
     elif layer_type == 'none':
         layer = tf.identity
diff --git a/video_prediction_tools/video_prediction/rnn_ops.py b/video_prediction_tools/model_modules/video_prediction/rnn_ops.py
similarity index 100%
rename from video_prediction_tools/video_prediction/rnn_ops.py
rename to video_prediction_tools/model_modules/video_prediction/rnn_ops.py
diff --git a/video_prediction_tools/video_prediction/utils/README.md b/video_prediction_tools/model_modules/video_prediction/utils/README.md
similarity index 100%
rename from video_prediction_tools/video_prediction/utils/README.md
rename to video_prediction_tools/model_modules/video_prediction/utils/README.md
diff --git a/video_prediction_tools/video_prediction/utils/__init__.py b/video_prediction_tools/model_modules/video_prediction/utils/__init__.py
similarity index 100%
rename from video_prediction_tools/video_prediction/utils/__init__.py
rename to video_prediction_tools/model_modules/video_prediction/utils/__init__.py
diff --git a/video_prediction_tools/video_prediction/utils/ffmpeg_gif.py b/video_prediction_tools/model_modules/video_prediction/utils/ffmpeg_gif.py
similarity index 100%
rename from video_prediction_tools/video_prediction/utils/ffmpeg_gif.py
rename to video_prediction_tools/model_modules/video_prediction/utils/ffmpeg_gif.py
diff --git a/video_prediction_tools/video_prediction/utils/gif_summary.py b/video_prediction_tools/model_modules/video_prediction/utils/gif_summary.py
similarity index 98%
rename from video_prediction_tools/video_prediction/utils/gif_summary.py
rename to video_prediction_tools/model_modules/video_prediction/utils/gif_summary.py
index 55f89987855c0288de827326107c9d72abc8ba6c..7f9ce616951a66dee17c1081a4961b4af1d6a57f 100644
--- a/video_prediction_tools/video_prediction/utils/gif_summary.py
+++ b/video_prediction_tools/model_modules/video_prediction/utils/gif_summary.py
@@ -21,7 +21,7 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.python.ops import summary_op_util
 #from tensorflow.python.distribute.summary_op_util import skip_summary TODO: IMPORT ERRORS IN juwels
-from video_prediction.utils import ffmpeg_gif
+from model_modules.video_prediction.utils import ffmpeg_gif
 
 
 def py_gif_summary(tag, images, max_outputs, fps):
diff --git a/video_prediction_tools/video_prediction/utils/html.py b/video_prediction_tools/model_modules/video_prediction/utils/html.py
similarity index 100%
rename from video_prediction_tools/video_prediction/utils/html.py
rename to video_prediction_tools/model_modules/video_prediction/utils/html.py
diff --git a/video_prediction_tools/video_prediction/utils/mcnet_utils.py b/video_prediction_tools/model_modules/video_prediction/utils/mcnet_utils.py
similarity index 100%
rename from video_prediction_tools/video_prediction/utils/mcnet_utils.py
rename to video_prediction_tools/model_modules/video_prediction/utils/mcnet_utils.py
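Note that the get_norm_layer hunk in ops.py rewrites a deferred import: fused_instance_norm is only imported when instance normalization is actually requested, so a stale module path would fail at call time rather than at module load. A sketch of why such imports are easy to miss in a rename (hypothetical code, not from the patch):

def get_layer(layer_type):
    if layer_type == 'instance':
        # Deferred import: resolved only on first use, so this line is not
        # exercised by merely importing ops.py after the package move.
        from model_modules.video_prediction.layers import fused_instance_norm
        return fused_instance_norm
    raise ValueError('Invalid layer type %s' % layer_type)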
diff --git a/video_prediction_tools/video_prediction/utils/tf_utils.py b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
similarity index 99%
rename from video_prediction_tools/video_prediction/utils/tf_utils.py
rename to video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
index 51e49a54b11fb07833f0ef33278d2e894b905afc..7a1da880defb61dbd018c6f11ee14c34cf0ce43e 100644
--- a/video_prediction_tools/video_prediction/utils/tf_utils.py
+++ b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
@@ -10,9 +10,8 @@ from tensorflow.core.framework import node_def_pb2
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.training import device_setter
 from tensorflow.python.util import nest
-
-from video_prediction.utils import ffmpeg_gif
-from video_prediction.utils import gif_summary
+from model_modules.video_prediction.utils import ffmpeg_gif
+from model_modules.video_prediction.utils import gif_summary
 
 IMAGE_SUMMARIES = "image_summaries"
 EVAL_SUMMARIES = "eval_summaries"
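Since almost every hunk in this part of the patch only touches import paths, a cheap guard against missed rewrites is an import smoke test over the moved modules. A minimal sketch (assumed to run from video_prediction_tools/ with the package root on sys.path and TensorFlow installed):

import importlib

# Modules renamed above; importing them surfaces any leftover
# "from video_prediction..." statement as an ImportError.
for name in [
    "model_modules.video_prediction.ops",
    "model_modules.video_prediction.losses",
    "model_modules.video_prediction.metrics",
    "model_modules.video_prediction.models",
    "model_modules.video_prediction.utils.tf_utils",
]:
    importlib.import_module(name)
    print("ok:", name)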