diff --git a/.gitignore b/.gitignore index 8e02b5e01fdb0fda7f265b63f53199362c2f8617..5d7d8d4f0ec66e7d19e91b726d39e1d75141e308 100644 --- a/.gitignore +++ b/.gitignore @@ -87,6 +87,7 @@ celerybeat-schedule .venv venv/ ENV/ +virtual_env*/ # Spyder project settings .spyderproject diff --git a/Zam347_scripts/generate_era5.sh b/Zam347_scripts/generate_era5.sh index 5428d697a8224237a1953f037ca7881d01a13d86..d9d710e5c4f3cc2d2825bf67bf2b668f6f9ddbd8 100755 --- a/Zam347_scripts/generate_era5.sh +++ b/Zam347_scripts/generate_era5.sh @@ -5,11 +5,13 @@ source_dir=/home/${USER}/preprocessedData/ checkpoint_dir=/home/${USER}/models/ results_dir=/home/${USER}/results/ -model=savp +# for choosing the model +model=mcnet +# execute respective Python-script python -u ../scripts/generate_transfer_learning_finetune.py \ --input_dir ${source_dir}/tfrecords \ ---dataset_hparams sequence_length=20 --checkpoint ${checkpoint_dir}/${model}/ours_savp \ +--dataset_hparams sequence_length=20 --checkpoint ${checkpoint_dir}/${model} \ --mode test --results_dir ${results_dir} \ --batch_size 2 --dataset era5 > generate_era5-out.out diff --git a/Zam347_scripts/train_era5.sh b/Zam347_scripts/train_era5.sh index d2d30725a0b0985332994876072c430149d606d3..aadb25997e2715ac719457c969a6f54982ec93a6 100755 --- a/Zam347_scripts/train_era5.sh +++ b/Zam347_scripts/train_era5.sh @@ -4,7 +4,10 @@ source_dir=/home/${USER}/preprocessedData/ destination_dir=/home/${USER}/models/ -model=savp +# for choosing the model +model=mcnet +model_hparams=../hparams/era5/model_hparams.json -python ../scripts/train_v2.py --input_dir ${source_dir}/tfrecords/ --dataset era5 --model ${model} --model_hparams_dict ../hparams/kth/ours_savp/model_hparams.json --output_dir ${destination_dir}/${model}/ +# execute respective Python-script +python ../scripts/train_dummy.py --input_dir ${source_dir}/tfrecords/ --dataset era5 --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}/ #srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp diff --git a/bash/workflow_era5_macOS.sh b/bash/workflow_era5_macOS.sh index 78b5101d810d9818cfb720154e9c42cc321f44a3..1a6ebef38df877b8ee20f628d4e375a20e7c8bd5 100755 --- a/bash/workflow_era5_macOS.sh +++ b/bash/workflow_era5_macOS.sh @@ -56,7 +56,7 @@ echo "=============================================================" # --input_dir ${DATA_EXTRA_DIR} --destination_dir ${DATA_PREPROCESS_DIR} #fi -#Change the .hkl data to .tfrecords files +####Change the .hkl data to .tfrecords files if [ -d "$DATA_PREPROCESS_TF_DIR" ] then echo "Step2: The Preprocessed Data (tf.records) exist" @@ -90,4 +90,4 @@ fi echo "Step4: Postprocessing start" python ./scripts/generate_transfer_learning_finetune.py --input_dir ${DATA_PREPROCESS_TF_DIR} \ --dataset_hparams sequence_length=20 --checkpoint ${CHECKPOINT_DIR} --mode test --results_dir ${RESULTS_OUTPUT_DIR} \ ---batch_size 4 --dataset era5 \ No newline at end of file +--batch_size 4 --dataset era5 diff --git a/bash/workflow_era5_zam347.sh b/bash/workflow_era5_zam347.sh new file mode 100755 index 0000000000000000000000000000000000000000..ffe7209b6099f4ad9f57b4e90247a7d7acaf009d --- /dev/null +++ b/bash/workflow_era5_zam347.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash +set -e +# +#MODEL=savp +##train_mode: end_to_end, pre_trained +#TRAIN_MODE=end_to_end +#EXP_NAME=era5_size_64_64_3_3t_norm + +MODEL=$1 +TRAIN_MODE=$2 +EXP_NAME=$3 +RETRAIN=1 #if we 
continue training the model or use the existing end-to-end model: 0 means continue training, 1 means use the existing one
+DATA_ETL_DIR=/home/${USER}/
+DATA_ETL_DIR=/p/scratch/deepacf/${USER}/
+DATA_EXTRA_DIR=${DATA_ETL_DIR}/extractedData/${EXP_NAME}
+DATA_PREPROCESS_DIR=${DATA_ETL_DIR}/preprocessedData/${EXP_NAME}
+DATA_PREPROCESS_TF_DIR=./data/${EXP_NAME}
+RESULTS_OUTPUT_DIR=./results_test_samples/${EXP_NAME}/${TRAIN_MODE}/
+
+if [ "$MODEL" == savp ]; then
+    method_dir=ours_savp
+elif [ "$MODEL" == gan ]; then
+    method_dir=ours_gan
+elif [ "$MODEL" == vae ]; then
+    method_dir=ours_vae
+else
+    echo "model does not exist" >&2
+    exit 1
+fi
+
+if [ "$TRAIN_MODE" == pre_trained ]; then
+    TRAIN_OUTPUT_DIR=./pretrained_models/kth/${method_dir}
+else
+    TRAIN_OUTPUT_DIR=./logs/${EXP_NAME}/${TRAIN_MODE}
+fi
+
+CHECKPOINT_DIR=${TRAIN_OUTPUT_DIR}/${method_dir}
+
+echo "===========================WORKFLOW SETUP===================="
+echo "Model ${MODEL}"
+echo "TRAIN MODE ${TRAIN_MODE}"
+echo "Method_dir ${method_dir}"
+echo "DATA_ETL_DIR ${DATA_ETL_DIR}"
+echo "DATA_EXTRA_DIR ${DATA_EXTRA_DIR}"
+echo "DATA_PREPROCESS_DIR ${DATA_PREPROCESS_DIR}"
+echo "DATA_PREPROCESS_TF_DIR ${DATA_PREPROCESS_TF_DIR}"
+echo "TRAIN_OUTPUT_DIR ${TRAIN_OUTPUT_DIR}"
+echo "============================================================="
+
+##############Data Preprocessing################
+#To hkl data
+#if [ -d "$DATA_PREPROCESS_DIR" ]; then
+#  echo "The Preprocessed Data (.hkl) exist"
+#else
+#  python ../workflow_video_prediction/DataPreprocess/benchmark/mpi_stager_v2_process_netCDF.py \
+#  --input_dir ${DATA_EXTRA_DIR} --destination_dir ${DATA_PREPROCESS_DIR}
+#fi
+
+####Change the .hkl data to .tfrecords files
+if [ -d "$DATA_PREPROCESS_TF_DIR" ]
+then
+  echo "Step2: The Preprocessed Data (tf.records) exist"
+else
+  echo "Step2: start, converting .hkl 
files to tf.records" + python ./video_prediction/datasets/era5_dataset_v2.py --source_dir ${DATA_PREPROCESS_DIR}/splits \ + --destination_dir ${DATA_PREPROCESS_TF_DIR} + echo "Step2: finish" +fi + +#########Train########################## +if [ "$TRAIN_MODE" == "pre_trained" ]; then + echo "step3: Using kth pre_trained model" +elif [ "$TRAIN_MODE" == "end_to_end" ]; then + echo "step3: End-to-end training" + if [ "$RETRAIN" == 1 ]; then + echo "Using the existing end-to-end model" + else + echo "Training Starts " + python ./scripts/train_v2.py --input_dir $DATA_PREPROCESS_TF_DIR --dataset era5 \ + --model ${MODEL} --model_hparams_dict hparams/kth/${method_dir}/model_hparams.json \ + --output_dir ${TRAIN_OUTPUT_DIR} --checkpoint ${CHECKPOINT_DIR} + echo "Training ends " + fi +else + echo "TRAIN_MODE is end_to_end or pre_trained" + exit 1 +fi + +#########Generate results################# +echo "Step4: Postprocessing start" +python ./scripts/generate_transfer_learning_finetune.py --input_dir ${DATA_PREPROCESS_TF_DIR} \ +--dataset_hparams sequence_length=20 --checkpoint ${CHECKPOINT_DIR} --mode test --results_dir ${RESULTS_OUTPUT_DIR} \ +--batch_size 4 --dataset era5 diff --git a/env_setup/create_env_zam347.sh b/env_setup/create_env_zam347.sh index 04873d26ebc0d307d960ca3ebc134d76b193eba6..95da5f2a7ed86183916d58a3c266846e6f0ca42b 100755 --- a/env_setup/create_env_zam347.sh +++ b/env_setup/create_env_zam347.sh @@ -22,7 +22,7 @@ pip3 install mpi4py pip3 install netCDF4 pip3 install numpy pip3 install h5py -pip3 install tensorflow==1.13.1 +pip3 install tensorflow-gpu==1.13.1 #Copy the hickle package from bing's account #cp -r /p/project/deepacf/deeprain/bing/hickle ${WORKING_DIR} diff --git a/hparams/era5/model_hparams.json b/hparams/era5/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..b121ee2f005b6db753b2536deb804204dd41b78d --- /dev/null +++ b/hparams/era5/model_hparams.json @@ -0,0 +1,11 @@ +{ + "batch_size": 8, + "lr": 0.001, + "nz": 16, + "max_steps":500, + "context_frames":10, + "sequence_length":20 + +} + + diff --git a/hparams/kth/ours_savp/model_hparams.json b/hparams/kth/ours_savp/model_hparams.json index 65c36514efbcc39ab4c3c1a15a58fc4895dc1744..66b41f87e3c0f417b492314060121a0bfd01c8f9 100644 --- a/hparams/kth/ours_savp/model_hparams.json +++ b/hparams/kth/ours_savp/model_hparams.json @@ -11,6 +11,8 @@ "vae_gan_feature_cdist_weight": 10.0, "gan_feature_cdist_weight": 0.0, "state_weight": 0.0, - "nz": 32 + "nz": 32, + "max_steps":20 } + diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py index c4fa831594910b3389a873cc9f8d4dd87944d66e..331559f6287a4f24c1c19ee9f7f4b03309a22abf 100644 --- a/scripts/generate_transfer_learning_finetune.py +++ b/scripts/generate_transfer_learning_finetune.py @@ -31,6 +31,12 @@ from matplotlib.colors import LinearSegmentedColormap #from video_prediction.utils.ffmpeg_gif import save_gif from skimage.metrics import structural_similarity as ssim import datetime +# Scarlet 2020/05/28: access to statistical values in json file +from os import path +import sys +sys.path.append(path.abspath('../video_prediction/datasets/')) +from era5_dataset_v2 import Norm_data +from os.path import dirname with open("../geo_info.json","r") as json_file: geo = json.load(json_file) @@ -82,7 +88,7 @@ def main(): parser.add_argument("--gif_length", type = int, help = "default is sequence_length") parser.add_argument("--fps", type = int, default = 4) - parser.add_argument("--gpu_mem_frac", 
type = float, default = 0, help = "fraction of gpu memory to use") + parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use") parser.add_argument("--seed", type = int, default = 7) args = parser.parse_args() @@ -208,8 +214,13 @@ def main(): #X_val = hickle.load("/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/hickle/splits/X_val.hkl") X_test = hickle.load(os.path.join(temporal_dir,"X_test.hkl")) is_first=True - + #+++Scarlet:20200528 + norm_cls = Norm_data('T2') + norm = 'minmax' + with open(os.path.join(dirname(input_dir),"hickle/splits/statistics.json")) as js_file: + norm_cls.check_and_set_norm(json.load(js_file),norm) + #---Scarlet:20200528 while True: print("Sample id", sample_ind) if sample_ind <= 24: @@ -265,9 +276,11 @@ def main(): input_images_ = input_images[i, :] #Bing:20200417 #persistent_images = ? - input_gen_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (gen_images_[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - persistent_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (persistent_X[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - + #+++Scarlet:20200528 + #print('Scarlet1') + input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm) + persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm) + #---Scarlet:20200528 gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in range(sequence_length)] # return the list with 10 (sequence) mse persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in @@ -284,7 +297,10 @@ def main(): #if t==0 : ax1=plt.subplot(gs[t]) ax1 = plt.subplot(gs[ts.index(t)]) - input_image = input_images_[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + #+++Scarlet:20200528 + #print('Scarlet2') + input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm) + #---Scarlet:20200528 plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300) ax1.title.set_text("t = " + str(t+1-10)) plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) @@ -301,7 +317,10 @@ def main(): for t in ts: #if t==0 : ax1=plt.subplot(gs[t]) ax1 = plt.subplot(gs[ts.index(t)]) - gen_image = gen_images_[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + #+++Scarlet:20200528 + #print('Scarlet3') + gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm) + #---Scarlet:20200528 plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300) ax1.title.set_text("t = " + str(t+1-10)) plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) @@ -538,13 +557,20 @@ def main(): with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files: persistent_images_all = pickle.load(gen_files) + #+++Scarlet:20200528 + #print('Scarlet4') input_images_all = np.array(input_images_all) - input_images_all = np.array(input_images_all) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm) + #---Scarlet:20200528 persistent_images_all = np.array(persistent_images_all) if len(np.array(gen_images_all).shape) == 6: for i in range(len(gen_images_all)): + 
#+++Scarlet:20200528 + #print('Scarlet5') gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:] - gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm) + #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + #---Scarlet:20200528 mse_all = [] psnr_all = [] ssim_all = [] @@ -574,7 +600,11 @@ def main(): f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape)) else: - gen_images_all = np.array(gen_images_all) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + #+++Scarlet:20200528 + #print('Scarlet6') + gen_images_all = np.array(gen_images_all) + gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm) + #---Scarlet:20200528 # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2) # look at all timesteps except the first # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2) diff --git a/scripts/train_dummy.py b/scripts/train_dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..2f892f69c901f1eaa0a7ce2e57a3d0f6f131a7f9 --- /dev/null +++ b/scripts/train_dummy.py @@ -0,0 +1,274 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import errno +import json +import os +import random +import time +import numpy as np +import tensorflow as tf +from video_prediction import datasets, models + + +def add_tag_suffix(summary, tag_suffix): + summary_proto = tf.Summary() + summary_proto.ParseFromString(summary) + summary = summary_proto + + for value in summary.value: + tag_split = value.tag.split('/') + value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:]) + return summary.SerializeToString() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " + "train, val, test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir") + parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified") + parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. " + "default is logs_dir/model_fname, where model_fname consists of " + "information from model and model_hparams") + parser.add_argument("--output_dir_postfix", default="") + parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. 
checkpoint_dir/model-200000)") + parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.') + + parser.add_argument("--dataset", type=str, help="dataset class name") + parser.add_argument("--model", type=str, help="model class name") + parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") + parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters") + + # parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training") + parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") + parser.add_argument("--seed", type=int) + + args = parser.parse_args() + + if args.seed is not None: + tf.set_random_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + if args.output_dir is None: + list_depth = 0 + model_fname = '' + for t in ('model=%s,%s' % (args.model, args.model_hparams)): + if t == '[': + list_depth += 1 + if t == ']': + list_depth -= 1 + if list_depth and t == ',': + t = '..' + if t in '=,': + t = '.' + if t in '[]': + t = '' + model_fname += t + args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix + + if args.resume: + if args.checkpoint: + raise ValueError('resume and checkpoint cannot both be specified') + args.checkpoint = args.output_dir + + + model_hparams_dict = {} + if args.model_hparams_dict: + with open(args.model_hparams_dict) as f: + model_hparams_dict.update(json.loads(f.read())) + if args.checkpoint: + checkpoint_dir = os.path.normpath(args.checkpoint) + if not os.path.isdir(args.checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % args.checkpoint) + options = json.loads(f.read()) + args.dataset = args.dataset or options['dataset'] + args.model = args.model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict.update(json.loads(f.read())) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + VideoDataset = datasets.get_dataset_class(args.dataset) + train_dataset = VideoDataset( + args.input_dir, + mode='train') + val_dataset = VideoDataset( + args.val_input_dir or args.input_dir, + mode='val') + + variable_scope = tf.get_variable_scope() + variable_scope.set_use_resource(True) + + VideoPredictionModel = models.get_model_class(args.model) + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': train_dataset.hparams.context_frames, + 'sequence_length': train_dataset.hparams.sequence_length, + 'repeat': train_dataset.hparams.time_shift, + }) + model = VideoPredictionModel( + hparams_dict=hparams_dict, + hparams=args.model_hparams) + + batch_size = model.hparams.batch_size + train_tf_dataset = train_dataset.make_dataset_v2(batch_size) + train_iterator = train_tf_dataset.make_one_shot_iterator() + # The `Iterator.string_handle()` method returns a 
tensor that can be evaluated + # and used to feed the `handle` placeholder. + train_handle = train_iterator.string_handle() + val_tf_dataset = val_dataset.make_dataset_v2(batch_size) + val_iterator = val_tf_dataset.make_one_shot_iterator() + val_handle = val_iterator.string_handle() + #iterator = tf.data.Iterator.from_string_handle( + # train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes) + inputs = train_iterator.get_next() + val = val_iterator.get_next() + + model.build_graph(inputs) + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + with open(os.path.join(args.output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys=True, indent=4)) + with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4)) + with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) + + with tf.name_scope("parameter_count"): + # exclude trainable variables that are replicas (used in multi-gpu setting) + trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables) + parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables]) + + saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2) + + # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero + summary_writer = tf.summary.FileWriter(args.output_dir) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) + + + max_steps = model.hparams.max_steps + print ("max_steps",max_steps) + with tf.Session(config=config) as sess: + print("parameter_count =", sess.run(parameter_count)) + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + #coord = tf.train.Coordinator() + #threads = tf.train.start_queue_runners(sess = sess, coord = coord) + print("Init done: {sess.run(tf.local_variables_initializer())}%") + model.restore(sess, args.checkpoint) + + #sess.run(model.post_init_ops) + + #val_handle_eval = sess.run(val_handle) + #print ("val_handle_val",val_handle_eval) + #print("val handle done") + sess.graph.finalize() + start_step = sess.run(model.global_step) + + + # start at one step earlier to log everything without doing any training + # step is relative to the start_step + for step in range(-1, max_steps - start_step): + global_step = sess.run(model.global_step) + print ("global_step:", global_step) + val_handle_eval = sess.run(val_handle) + + if step == 1: + # skip step -1 and 0 for timing purposes (for warmstarting) + start_time = time.time() + + fetches = {"global_step":model.global_step} + fetches["train_op"] = model.train_op + + # fetches["latent_loss"] = model.latent_loss + fetches["total_loss"] = model.total_loss + if model.__class__.__name__ == "McNetVideoPredictionModel": + fetches["L_p"] = model.L_p + fetches["L_gdl"] = model.L_gdl + fetches["L_GAN"] =model.L_GAN + + + + fetches["summary"] = model.summary_op + + run_start_time = time.time() + #Run training results + #X = inputs["images"].eval(session=sess) + + results = sess.run(fetches) + + run_elapsed_time = time.time() - run_start_time + if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}: + print('running train_op took too long (%0.1fs)' % 
run_elapsed_time) + + #Run testing results + #val_fetches = {"global_step":global_step} + val_fetches = {} + #val_fetches["latent_loss"] = model.latent_loss + #val_fetches["total_loss"] = model.total_loss + val_fetches["summary"] = model.summary_op + val_results = sess.run(val_fetches,feed_dict={train_handle: val_handle_eval}) + + summary_writer.add_summary(results["summary"]) + summary_writer.add_summary(val_results["summary"]) + + + + + val_datasets = [val_dataset] + val_models = [model] + + # for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)): + # sess.run(val_model.accum_eval_metrics_reset_op) + # # traverse (roughly up to rounding based on the batch size) all the validation dataset + # accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size + # val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op} + # for update_step in range(accum_eval_summary_num_updates): + # print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates)) + # val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval}) + # accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1)) + # print("recording accum eval summary") + # summary_writer.add_summary(accum_eval_summary, val_results["global_step"]) + summary_writer.flush() + + # global_step will have the correct step count if we resume from a checkpoint + # global step is read before it's incremented + steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size + #train_epoch = results["global_step"] / steps_per_epoch + train_epoch = global_step/steps_per_epoch + print("progress global step %d epoch %0.1f" % (global_step + 1, train_epoch)) + if step > 0: + elapsed_time = time.time() - start_time + average_time = elapsed_time / step + images_per_sec = batch_size / average_time + remaining_time = (max_steps - (start_step + step + 1)) * average_time + print("image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" % + (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24)) + + + print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"])) + + print("saving model to", args.output_dir) + saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)##Bing: cheat here a little bit because of the global step issue + print("done") + +if __name__ == '__main__': + main() diff --git a/video_prediction/datasets/era5_dataset_v2.py b/video_prediction/datasets/era5_dataset_v2.py index 00111565556014aaa868ba4f8d0c98d8c25ac732..9e32e0c638b4b39c588389e906ba29be5144ee35 100644 --- a/video_prediction/datasets/era5_dataset_v2.py +++ b/video_prediction/datasets/era5_dataset_v2.py @@ -246,7 +246,7 @@ class Norm_data: # do the denormalization and return if norm == "minmax": - return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"max")) + return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"min")) elif norm == "znorm": return(data[...] 
* getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg")) diff --git a/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction/layers/BasicConvLSTMCell.py new file mode 100644 index 0000000000000000000000000000000000000000..321f6cc7e05320cf83e1173d8004429edf07ec24 --- /dev/null +++ b/video_prediction/layers/BasicConvLSTMCell.py @@ -0,0 +1,148 @@ + +import tensorflow as tf +from .layer_def import * + +class ConvRNNCell(object): + """Abstract object representing an Convolutional RNN cell. + """ + + def __call__(self, inputs, state, scope=None): + """Run this RNN cell on inputs, starting from the given state. + """ + raise NotImplementedError("Abstract method") + + @property + def state_size(self): + """size(s) of state(s) used by this cell. + """ + raise NotImplementedError("Abstract method") + + @property + def output_size(self): + """Integer or TensorShape: size of outputs produced by this cell.""" + raise NotImplementedError("Abstract method") + + def zero_state(self,input, dtype): + """Return zero-filled state tensor(s). + Args: + batch_size: int, float, or unit Tensor representing the batch size. + dtype: the data type to use for the state. + Returns: + tensor of shape '[batch_size x shape[0] x shape[1] x num_features] + filled with zeros + """ + + shape = self.shape + num_features = self.num_features + #x= tf.placeholder(tf.float32, shape=[input.shape[0], shape[0], shape[1], num_features * 2])#Bing: add this to + zeros = tf.zeros([tf.shape(input)[0], shape[0], shape[1], num_features * 2]) + #zeros = tf.zeros_like(x) + return zeros + + +class BasicConvLSTMCell(ConvRNNCell): + """Basic Conv LSTM recurrent network cell. The + """ + + def __init__(self, shape, filter_size, num_features, forget_bias=1.0, input_size=None, + state_is_tuple=False, activation=tf.nn.tanh): + """Initialize the basic Conv LSTM cell. + Args: + shape: int tuple thats the height and width of the cell + filter_size: int tuple thats the height and width of the filter + num_features: int thats the depth of the cell + forget_bias: float, The bias added to forget gates (see above). + input_size: Deprecated and unused. + state_is_tuple: If True, accepted and returned states are 2-tuples of + the `c_state` and `m_state`. If False, they are concatenated + along the column axis. The latter behavior will soon be deprecated. + activation: Activation function of the inner states. + """ + # if not state_is_tuple: + # logging.warn("%s: Using a concatenated state is slower and will soon be " + # "deprecated. Use state_is_tuple=True.", self) + if input_size is not None: + logging.warn("%s: The input_size parameter is deprecated.", self) + self.shape = shape + self.filter_size = filter_size + self.num_features = num_features + self._forget_bias = forget_bias + self._state_is_tuple = state_is_tuple + self._activation = activation + + @property + def state_size(self): + return (LSTMStateTuple(self._num_units, self._num_units) + if self._state_is_tuple else 2 * self._num_units) + + @property + def output_size(self): + return self._num_units + + def __call__(self, inputs, state, scope=None,reuse=None): + """Long short-term memory cell (LSTM).""" + with tf.variable_scope(scope or type(self).__name__,reuse=reuse): # "BasicLSTMCell" + # Parameters of gates are concatenated into one multiply for efficiency. 
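+            # The (concatenated) state carries the cell state c and hidden state h along the
+            # channel axis. A single convolution over [inputs, h] produces all four gates
+            # (i, j, f, o) at once; they are split below and combined with the usual LSTM update:
+            # new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * activation(j) and
+            # new_h = activation(new_c) * sigmoid(o), where activation is tanh by default.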
+ if self._state_is_tuple: + c, h = state + else: + c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state) + concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat) + + new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * + self._activation(j)) + new_h = self._activation(new_c) * tf.nn.sigmoid(o) + + if self._state_is_tuple: + new_state = LSTMStateTuple(new_c, new_h) + else: + new_state = tf.concat(axis = 3, values = [new_c, new_h]) + return new_h, new_state + + +def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=None): + """convolution: + Args: + args: a 4D Tensor or a list of 4D, batch x n, Tensors. + filter_size: int tuple of filter height and width. + num_features: int, number of features. + bias_start: starting value to initialize the bias; 0 by default. + scope: VariableScope for the created subgraph; defaults to "Linear". + Returns: + A 4D Tensor with shape [batch h w num_features] + Raises: + ValueError: if some of the arguments has unspecified or wrong shape. + """ + + # Calculate the total size of arguments on dimension 1. + total_arg_size_depth = 0 + shapes = [a.get_shape().as_list() for a in args] + for shape in shapes: + if len(shape) != 4: + raise ValueError("Linear is expecting 4D arguments: %s" % str(shapes)) + if not shape[3]: + raise ValueError("Linear expects shape[4] of arguments: %s" % str(shapes)) + else: + total_arg_size_depth += shape[3] + + dtype = [a.dtype for a in args][0] + + # Now the computation. + with tf.variable_scope(scope or "Conv"): + matrix = tf.get_variable( + "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype = dtype) + if len(args) == 1: + res = tf.nn.conv2d(args[0], matrix, strides = [1, 1, 1, 1], padding = 'SAME') + else: + res = tf.nn.conv2d(tf.concat(axis = 3, values = args), matrix, strides = [1, 1, 1, 1], padding = 'SAME') + if not bias: + return res + bias_term = tf.get_variable( + "Bias", [num_features], + dtype = dtype, + initializer = tf.constant_initializer( + bias_start, dtype = dtype)) + return res + bias_term diff --git a/video_prediction/layers/layer_def.py b/video_prediction/layers/layer_def.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7f4387001c9318507ad809d7176071312742d0 --- /dev/null +++ b/video_prediction/layers/layer_def.py @@ -0,0 +1,160 @@ +"""functions used to construct different architectures +""" + +import tensorflow as tf +import numpy as np + +weight_decay = 0.0005 +def _activation_summary(x): + """Helper to create summaries for activations. + Creates a summary that provides a histogram of activations. + Creates a summary that measure the sparsity of activations. + Args: + x: Tensor + Returns: + nothing + """ + tensor_name = x.op.name + tf.summary.histogram(tensor_name + '/activations', x) + tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x)) + +def _variable_on_cpu(name, shape, initializer): + """Helper to create a Variable stored on CPU memory. 
+ Args: + name: name of the variable + shape: list of ints + initializer: initializer for Variable + Returns: + Variable Tensor + """ + with tf.device('/cpu:0'): + var = tf.get_variable(name, shape, initializer=initializer) + return var + + +def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.layers.xavier_initializer()): + """Helper to create an initialized Variable with weight decay. + Note that the Variable is initialized with a truncated normal distribution. + A weight decay is added only if one is specified. + Args: + name: name of the variable + shape: list of ints + stddev: standard deviation of a truncated Gaussian + wd: add L2Loss weight decay multiplied by this float. If None, weight + decay is not added for this Variable. + Returns: + Variable Tensor + """ + #var = _variable_on_cpu(name, shape,tf.truncated_normal_initializer(stddev = stddev)) + var = _variable_on_cpu(name, shape, initializer) + if wd: + weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name = 'weight_loss') + weight_decay.set_shape([]) + tf.add_to_collection('losses', weight_decay) + return var + + +def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer() , activate="relu"): + print("conv_layer activation function",activate) + + with tf.variable_scope('{0}_conv'.format(idx)) as scope: + + input_channels = inputs.get_shape()[-1] + weights = _variable_with_weight_decay('weights',shape = [kernel_size, kernel_size, + input_channels, num_features], + stddev = 0.01, wd = weight_decay) + biases = _variable_on_cpu('biases', [num_features], initializer) + conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding = 'SAME') + conv_biased = tf.nn.bias_add(conv, biases) + if activate == "linear": + return conv_biased + elif activate == "relu": + conv_rect = tf.nn.relu(conv_biased, name = '{0}_conv'.format(idx)) + elif activate == "elu": + conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx)) + elif activate == "leaky_relu": + conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx)) + else: + raise ("activation function is not correct") + return conv_rect + + +def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer(),activate="relu"): + with tf.variable_scope('{0}_trans_conv'.format(idx)) as scope: + input_channels = inputs.get_shape()[3] + input_shape = inputs.get_shape().as_list() + + + weights = _variable_with_weight_decay('weights', + shape = [kernel_size, kernel_size, num_features, input_channels], + stddev = 0.1, wd = weight_decay) + biases = _variable_on_cpu('biases', [num_features],initializer) + batch_size = tf.shape(inputs)[0] + + output_shape = tf.stack( + [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features]) + print ("output_shape",output_shape) + conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides = [1, stride, stride, 1], padding = 'SAME') + conv_biased = tf.nn.bias_add(conv, biases) + if activate == "linear": + return conv_biased + elif activate == "elu": + return tf.nn.elu(conv_biased, name = '{0}_transpose_conv'.format(idx)) + elif activate == "relu": + return tf.nn.relu(conv_biased, name = '{0}_transpose_conv'.format(idx)) + elif activate == "leaky_relu": + return tf.nn.leaky_relu(conv_biased, name = '{0}_transpose_conv'.format(idx)) + elif activate == "sigmoid": + return tf.nn.sigmoid(conv_biased, name ='sigmoid') + else: + return conv_biased + + +def 
fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01,initializer=tf.contrib.layers.xavier_initializer()): + with tf.variable_scope('{0}_fc'.format(idx)) as scope: + input_shape = inputs.get_shape().as_list() + if flat: + dim = input_shape[1] * input_shape[2] * input_shape[3] + inputs_processed = tf.reshape(inputs, [-1, dim]) + else: + dim = input_shape[1] + inputs_processed = inputs + + weights = _variable_with_weight_decay('weights', shape = [dim, hiddens], stddev = weight_init, + wd = weight_decay) + biases = _variable_on_cpu('biases', [hiddens],initializer) + if activate == "linear": + return tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc') + elif activate == "sigmoid": + return tf.nn.sigmoid(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + elif activate == "softmax": + return tf.nn.softmax(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + elif activate == "relu": + return tf.nn.relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + elif activate == "leaky_relu": + return tf.nn.leaky_relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + else: + ip = tf.add(tf.matmul(inputs_processed, weights), biases) + return tf.nn.elu(ip, name = str(idx) + '_fc') + +def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None): + with tf.variable_scope('{0}_bn'.format(idx)) as scope: + #Calculate batch mean and variance + shape = inputs.get_shape().as_list() + scale = tf.get_variable("gamma", shape[-1], initializer=tf.constant_initializer(1.0), trainable=is_training) + beta = tf.get_variable("beta", shape[-1], initializer=tf.constant_initializer(0.0), trainable=is_training) + pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False) + pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False) + + if is_training: + batch_mean, batch_var = tf.nn.moments(inputs,[0]) + train_mean = tf.assign(pop_mean,pop_mean * decay + batch_mean * (1 - decay)) + train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay)) + with tf.control_dependencies([train_mean,train_var]): + return tf.nn.batch_normalization(inputs,batch_mean,batch_var,beta,scale,epsilon) + else: + return tf.nn.batch_normalization(inputs,pop_mean,pop_var,beta,scale,epsilon) + +def bn_layers_wrapper(inputs, is_training): + pass + \ No newline at end of file diff --git a/video_prediction/layers/mcnet_ops.py b/video_prediction/layers/mcnet_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..656f66c0df1cf199fff319f7b81b01594f96332c --- /dev/null +++ b/video_prediction/layers/mcnet_ops.py @@ -0,0 +1,178 @@ +import math +import numpy as np +import tensorflow as tf + +from tensorflow.python.framework import ops +from video_prediction.utils.mcnet_utils import * + + +def batch_norm(inputs, name, train=True, reuse=False): + return tf.contrib.layers.batch_norm(inputs=inputs,is_training=train, + reuse=reuse,scope=name,scale=True) + + +def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name="conv2d", reuse=False, padding='SAME'): + with tf.variable_scope(name, reuse=reuse): + w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], + initializer=tf.contrib.layers.xavier_initializer()) + conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding=padding) + + biases = tf.get_variable('biases', [output_dim], + initializer=tf.constant_initializer(0.0)) + conv = 
tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + + return conv + + +def deconv2d(input_, output_shape,k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, + name="deconv2d", reuse=False, with_w=False, padding='SAME'): + with tf.variable_scope(name, reuse=reuse): + # filter : [height, width, output_channels, in_channels] + w = tf.get_variable('w', [k_h, k_h, output_shape[-1], + input_.get_shape()[-1]], + initializer=tf.contrib.layers.xavier_initializer()) + + try: + deconv = tf.nn.conv2d_transpose(input_, w, + output_shape=output_shape, + strides=[1, d_h, d_w, 1], + padding=padding) + + # Support for verisons of TensorFlow before 0.7.0 + except AttributeError: + deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, + strides=[1, d_h, d_w, 1]) + biases = tf.get_variable('biases', [output_shape[-1]], + initializer=tf.constant_initializer(0.0)) + deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) + + if with_w: + return deconv, w, biases + else: + return deconv + + +def lrelu(x, leak=0.2, name="lrelu"): + with tf.variable_scope(name): + f1 = 0.5 * (1 + leak) + f2 = 0.5 * (1 - leak) + return f1 * x + f2 * abs(x) + + +def relu(x): + return tf.nn.relu(x) + + +def tanh(x): + return tf.nn.tanh(x) + + +def shape2d(a): + """ + a: a int or tuple/list of length 2 + """ + if type(a) == int: + return [a, a] + if isinstance(a, (list, tuple)): + assert len(a) == 2 + return list(a) + raise RuntimeError("Illegal shape: {}".format(a)) + + +def shape4d(a): + # for use with tensorflow + return [1] + shape2d(a) + [1] + + +def UnPooling2x2ZeroFilled(x): + out = tf.concat(axis=3, values=[x, tf.zeros_like(x)]) + out = tf.concat(axis=2, values=[out, tf.zeros_like(out)]) + + sh = x.get_shape().as_list() + if None not in sh[1:]: + out_size = [-1, sh[1] * 2, sh[2] * 2, sh[3]] + return tf.reshape(out, out_size) + else: + sh = tf.shape(x) + return tf.reshape(out, [-1, sh[1] * 2, sh[2] * 2, sh[3]]) + + +def MaxPooling(x, shape, stride=None, padding='VALID'): + """ + MaxPooling on images. + :param input: NHWC tensor. + :param shape: int or [h, w] + :param stride: int or [h, w]. default to be shape. + :param padding: 'valid' or 'same'. default to 'valid' + :returns: NHWC tensor. + """ + padding = padding.upper() + shape = shape4d(shape) + if stride is None: + stride = shape + else: + stride = shape4d(stride) + + return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding) + + +#@layer_register() +def FixedUnPooling(x, shape): + """ + Unpool the input with a fixed mat to perform kronecker product with. + :param input: NHWC tensor + :param shape: int or [h, w] + :returns: NHWC tensor + """ + shape = shape2d(shape) + + # a faster implementation for this special case + return UnPooling2x2ZeroFilled(x) + + +def gdl(gen_frames, gt_frames, alpha): + """ + Calculates the sum of GDL losses between the predicted and gt frames. + @param gen_frames: The predicted frames at each scale. + @param gt_frames: The ground truth frames at each scale + @param alpha: The power to which each gradient term is raised. + @return: The GDL loss. + """ + # create filters [-1, 1] and [[1],[-1]] + # for diffing to the left and down respectively. 
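+    # The filters are built from 3x3 identity matrices, so each image channel is
+    # differenced independently. The value returned below is
+    # mean(|d_x(gt) - d_x(gen)|**alpha + |d_y(gt) - d_y(gen)|**alpha),
+    # i.e. a gradient difference loss that penalizes mismatched spatial gradients.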
+ pos = tf.constant(np.identity(3), dtype=tf.float32) + neg = -1 * pos + # [-1, 1] + filter_x = tf.expand_dims(tf.stack([neg, pos]), 0) + # [[1],[-1]] + filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) + strides = [1, 1, 1, 1] # stride of (1, 1) + padding = 'SAME' + + gen_dx = tf.abs(tf.nn.conv2d(gen_frames, filter_x, strides, padding=padding)) + gen_dy = tf.abs(tf.nn.conv2d(gen_frames, filter_y, strides, padding=padding)) + gt_dx = tf.abs(tf.nn.conv2d(gt_frames, filter_x, strides, padding=padding)) + gt_dy = tf.abs(tf.nn.conv2d(gt_frames, filter_y, strides, padding=padding)) + + grad_diff_x = tf.abs(gt_dx - gen_dx) + grad_diff_y = tf.abs(gt_dy - gen_dy) + + gdl_loss = tf.reduce_mean((grad_diff_x ** alpha + grad_diff_y ** alpha)) + + # condense into one tensor and avg + return gdl_loss + + +def linear(input_, output_size, name, stddev=0.02, bias_start=0.0, + reuse=False, with_w=False): + shape = input_.get_shape().as_list() + + with tf.variable_scope(name, reuse=reuse): + matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, + tf.random_normal_initializer(stddev=stddev)) + bias = tf.get_variable("bias", [output_size], + initializer=tf.constant_initializer(bias_start)) + if with_w: + return tf.matmul(input_, matrix) + bias, matrix, bias + else: + return tf.matmul(input_, matrix) + bias diff --git a/video_prediction/models/__init__.py b/video_prediction/models/__init__.py index d88b573127a1956b0532c8b3a8b1abc15010eb30..6d7323f3750949b0ddb411d4a98934928537bc53 100644 --- a/video_prediction/models/__init__.py +++ b/video_prediction/models/__init__.py @@ -7,8 +7,9 @@ from .savp_model import SAVPVideoPredictionModel from .dna_model import DNAVideoPredictionModel from .sna_model import SNAVideoPredictionModel from .sv2p_model import SV2PVideoPredictionModel - - +from .vanilla_vae_model import VanillaVAEVideoPredictionModel +from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel +from .mcnet_model import McNetVideoPredictionModel def get_model_class(model): model_mappings = { 'ground_truth': 'GroundTruthVideoPredictionModel', @@ -17,7 +18,11 @@ def get_model_class(model): 'dna': 'DNAVideoPredictionModel', 'sna': 'SNAVideoPredictionModel', 'sv2p': 'SV2PVideoPredictionModel', - } + 'vae': 'VanillaVAEVideoPredictionModel', + 'convLSTM': 'VanillaConvLstmVideoPredictionModel', + 'mcnet': 'McNetVideoPredictionModel', + + } model_class = model_mappings.get(model, model) model_class = globals().get(model_class) if model_class is None or not issubclass(model_class, BaseVideoPredictionModel): diff --git a/video_prediction/models/mcnet_model.py b/video_prediction/models/mcnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..725ce4f46a301b6aa07f3d50ef811584d5b502db --- /dev/null +++ b/video_prediction/models/mcnet_model.py @@ -0,0 +1,467 @@ +import collections +import functools +import itertools +from collections import OrderedDict +import numpy as np +import tensorflow as tf +from tensorflow.python.util import nest +from video_prediction import ops, flow_ops +from video_prediction.models import BaseVideoPredictionModel +from video_prediction.models import networks +from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat +from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell +from video_prediction.utils import tf_utils +from datetime import datetime +from pathlib import Path +from video_prediction.layers import layer_def as ld +from video_prediction.layers.BasicConvLSTMCell import 
BasicConvLSTMCell +from video_prediction.layers.mcnet_ops import * +from video_prediction.utils.mcnet_utils import * +import os + +class McNetVideoPredictionModel(BaseVideoPredictionModel): + def __init__(self, mode='train', hparams_dict=None, + hparams=None, **kwargs): + super(McNetVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) + self.mode = mode + self.lr = self.hparams.lr + self.context_frames = self.hparams.context_frames + self.sequence_length = self.hparams.sequence_length + self.predict_frames = self.sequence_length - self.context_frames + self.df_dim = self.hparams.df_dim + self.gf_dim = self.hparams.gf_dim + self.alpha = self.hparams.alpha + self.beta = self.hparams.beta + self.gen_images_enc = None + self.recon_loss = None + self.latent_loss = None + self.total_loss = None + + def get_default_hparams_dict(self): + """ + The keys of this dict define valid hyperparameters for instances of + this class. A class inheriting from this one should override this + method if it has a different set of hyperparameters. + + Returns: + A dict with the following hyperparameters. + + batch_size: batch size for training. + lr: learning rate. if decay steps is non-zero, this is the + learning rate for steps <= decay_step. + + + + + max_steps: number of training steps. + + + context_frames: the number of ground-truth frames to pass in at + start. Must be specified during instantiation. + sequence_length: the number of frames in the video sequence, + including the context frames, so this model predicts + `sequence_length - context_frames` future frames. Must be + specified during instantiation. + df_dim: specific parameters for mcnet + gf_dim: specific parameters for menet + alpha: specific parameters for mcnet + beta: specific paramters for mcnet + + """ + default_hparams = super(McNetVideoPredictionModel, self).get_default_hparams_dict() + hparams = dict( + batch_size=16, + lr=0.001, + max_steps=350000, + context_frames = 10, + sequence_length = 20, + nz = 16, + gf_dim = 64, + df_dim = 64, + alpha = 1, + beta = 0.0 + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + def build_graph(self, x): + + self.x = x["images"] + self.x_shape = self.x.get_shape().as_list() + self.batch_size = self.x_shape[0] + self.image_size = [self.x_shape[2],self.x_shape[3]] + self.c_dim = self.x_shape[4] + self.diff_shape = [self.batch_size, self.context_frames-1, self.image_size[0], + self.image_size[1], self.c_dim] + self.xt_shape = [self.batch_size, self.image_size[0], self.image_size[1],self.c_dim] + self.is_train = True + + + self.global_step = tf.Variable(0, name='global_step', trainable=False) + original_global_variables = tf.global_variables() + + # self.xt = tf.placeholder(tf.float32, self.xt_shape, name='xt') + self.xt = self.x[:, self.context_frames - 1, :, :, :] + + self.diff_in = tf.placeholder(tf.float32, self.diff_shape, name='diff_in') + diff_in_all = [] + for t in range(1, self.context_frames): + prev = self.x[:, t-1:t, :, :, :] + next = self.x[:, t:t+1, :, :, :] + #diff_in = tf.reshape(next - prev, [self.batch_size, 1, self.image_size[0], self.image_size[1], -1]) + print("prev:",prev) + print("next:",next) + diff_in = tf.subtract(next,prev) + print("diff_in:",diff_in) + diff_in_all.append(diff_in) + + self.diff_in = tf.concat(axis = 1, values = diff_in_all) + + cell = BasicConvLSTMCell([self.image_size[0] / 8, self.image_size[1] / 8], [3, 3], 256) + + pred = self.forward(self.diff_in, self.xt, cell) + + + self.G = tf.concat(axis=1, 
values=pred)#[batch_size,context_frames,image1,image2,channels] + print ("1:self.G:",self.G) + if self.is_train: + + true_sim = self.x[:, self.context_frames:, :, :, :] + + # Bing: the following make sure the channel is three dimension, if the channle is 3 then will be duplicated + if self.c_dim == 1: true_sim = tf.tile(true_sim, [1, 1, 1, 1, 3]) + + # Bing: the raw inputs shape is [batch_size, image_size[0],self.image_size[1], num_seq, channel]. tf.transpose will transpoe the shape into + # [batch size*num_seq, image_size0, image_size1, channels], for our era5 case, we do not need transpose + # true_sim = tf.reshape(tf.transpose(true_sim,[0,3,1,2,4]), + # [-1, self.image_size[0], + # self.image_size[1], 3]) + true_sim = tf.reshape(true_sim, [-1, self.image_size[0], self.image_size[1], 3]) + + + + + gen_sim = self.G + + #combine groud truth and predict frames + self.x_hat = tf.concat([self.x[:, :self.context_frames, :, :, :], self.G], 1) + print ("self.x_hat:",self.x_hat) + if self.c_dim == 1: gen_sim = tf.tile(gen_sim, [1, 1, 1, 1, 3]) + # gen_sim = tf.reshape(tf.transpose(gen_sim,[0,3,1,2,4]), + # [-1, self.image_size[0], + # self.image_size[1], 3]) + + gen_sim = tf.reshape(gen_sim, [-1, self.image_size[0], self.image_size[1], 3]) + + + binput = tf.reshape(tf.transpose(self.x[:, :self.context_frames, :, :, :], [0, 1, 2, 3, 4]), + [self.batch_size, self.image_size[0], + self.image_size[1], -1]) + + btarget = tf.reshape(tf.transpose(self.x[:, self.context_frames:, :, :, :], [0, 1, 2, 3, 4]), + [self.batch_size, self.image_size[0], + self.image_size[1], -1]) + bgen = tf.reshape(self.G, [self.batch_size, + self.image_size[0], + self.image_size[1], -1]) + + print ("binput:",binput) + print("btarget:",btarget) + print("bgen:",bgen) + + good_data = tf.concat(axis=3, values=[binput, btarget]) + gen_data = tf.concat(axis=3, values=[binput, bgen]) + self.gen_data = gen_data + print ("2:self.gen_data:", self.gen_data) + with tf.variable_scope("DIS", reuse=False): + self.D, self.D_logits = self.discriminator(good_data) + + with tf.variable_scope("DIS", reuse=True): + self.D_, self.D_logits_ = self.discriminator(gen_data) + + self.L_p = tf.reduce_mean( + tf.square(self.G - self.x[:, self.context_frames:, :, :, :])) + + self.L_gdl = gdl(gen_sim, true_sim, 1.) 
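+            # The image loss combines the pixel-wise MSE over the predicted frames (L_p)
+            # with the gradient difference loss (L_gdl); the adversarial term L_GAN enters
+            # the generator objective separately, weighted by alpha and beta in total_loss below.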
+ self.L_img = self.L_p + self.L_gdl + + self.d_loss_real = tf.reduce_mean( + tf.nn.sigmoid_cross_entropy_with_logits( + logits = self.D_logits, labels = tf.ones_like(self.D) + )) + self.d_loss_fake = tf.reduce_mean( + tf.nn.sigmoid_cross_entropy_with_logits( + logits = self.D_logits_, labels = tf.zeros_like(self.D_) + )) + self.d_loss = self.d_loss_real + self.d_loss_fake + self.L_GAN = tf.reduce_mean( + tf.nn.sigmoid_cross_entropy_with_logits( + logits = self.D_logits_, labels = tf.ones_like(self.D_) + )) + + self.loss_sum = tf.summary.scalar("L_img", self.L_img) + self.L_p_sum = tf.summary.scalar("L_p", self.L_p) + self.L_gdl_sum = tf.summary.scalar("L_gdl", self.L_gdl) + self.L_GAN_sum = tf.summary.scalar("L_GAN", self.L_GAN) + self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss) + self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real) + self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake) + + self.total_loss = self.alpha * self.L_img + self.beta * self.L_GAN + self._loss_sum = tf.summary.scalar("total_loss", self.total_loss) + self.g_sum = tf.summary.merge([self.L_p_sum, + self.L_gdl_sum, self.loss_sum, + self.L_GAN_sum]) + self.d_sum = tf.summary.merge([self.d_loss_real_sum, self.d_loss_sum, + self.d_loss_fake_sum]) + + + self.t_vars = tf.trainable_variables() + self.g_vars = [var for var in self.t_vars if 'DIS' not in var.name] + self.d_vars = [var for var in self.t_vars if 'DIS' in var.name] + num_param = 0.0 + for var in self.g_vars: + num_param += int(np.prod(var.get_shape())); + print("Number of parameters: %d" % num_param) + + # Training + self.d_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5).minimize( + self.d_loss, var_list = self.d_vars) + self.g_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5).minimize( + self.alpha * self.L_img + self.beta * self.L_GAN, var_list = self.g_vars, global_step=self.global_step) + + self.train_op = [self.d_optim,self.g_optim] + self.outputs = {} + self.outputs["gen_images"] = self.x_hat + + + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + return + + + def forward(self, diff_in, xt, cell): + # Initial state + state = tf.zeros([self.batch_size, self.image_size[0] / 8, + self.image_size[1] / 8, 512]) + reuse = False + # Encoder + for t in range(self.context_frames - 1): + enc_h, res_m = self.motion_enc(diff_in[:, t, :, :, :], reuse = reuse) + h_dyn, state = cell(enc_h, state, scope = 'lstm', reuse = reuse) + reuse = True + pred = [] + # Decoder + for t in range(self.predict_frames): + if t == 0: + h_cont, res_c = self.content_enc(xt, reuse = False) + h_tp1 = self.comb_layers(h_dyn, h_cont, reuse = False) + res_connect = self.residual(res_m, res_c, reuse = False) + x_hat = self.dec_cnn(h_tp1, res_connect, reuse = False) + + else: + + enc_h, res_m = self.motion_enc(diff_in, reuse = True) + h_dyn, state = cell(enc_h, state, scope = 'lstm', reuse = True) + h_cont, res_c = self.content_enc(xt, reuse = reuse) + h_tp1 = self.comb_layers(h_dyn, h_cont, reuse = True) + res_connect = self.residual(res_m, res_c, reuse = True) + x_hat = self.dec_cnn(h_tp1, res_connect, reuse = True) + print ("x_hat :",x_hat) + if self.c_dim == 3: + # Network outputs are BGR so they need to be reversed to use + # rgb_to_grayscale + #x_hat_gray = tf.concat(axis=3,values=[x_hat[:,:,:,2:3], x_hat[:,:,:,1:2],x_hat[:,:,:,0:1]]) + #xt_gray = 
tf.concat(axis=3,values=[xt[:,:,:,2:3], xt[:,:,:,1:2],xt[:,:,:,0:1]]) + + # x_hat_gray = 1./255.*tf.image.rgb_to_grayscale( + # inverse_transform(x_hat_rgb)*255. + # ) + # xt_gray = 1./255.*tf.image.rgb_to_grayscale( + # inverse_transform(xt_rgb)*255. + # ) + + x_hat_gray = x_hat + xt_gray = xt + else: + x_hat_gray = inverse_transform(x_hat) + xt_gray = inverse_transform(xt) + + diff_in = x_hat_gray - xt_gray + xt = x_hat + + + pred.append(tf.reshape(x_hat, [self.batch_size, 1, self.image_size[0], + self.image_size[1], self.c_dim])) + + return pred + + def motion_enc(self, diff_in, reuse): + res_in = [] + + conv1 = relu(conv2d(diff_in, output_dim = self.gf_dim, k_h = 5, k_w = 5, + d_h = 1, d_w = 1, name = 'dyn1_conv1', reuse = reuse)) + res_in.append(conv1) + pool1 = MaxPooling(conv1, [2, 2]) + + conv2 = relu(conv2d(pool1, output_dim = self.gf_dim * 2, k_h = 5, k_w = 5, + d_h = 1, d_w = 1, name = 'dyn_conv2', reuse = reuse)) + res_in.append(conv2) + pool2 = MaxPooling(conv2, [2, 2]) + + conv3 = relu(conv2d(pool2, output_dim = self.gf_dim * 4, k_h = 7, k_w = 7, + d_h = 1, d_w = 1, name = 'dyn_conv3', reuse = reuse)) + res_in.append(conv3) + pool3 = MaxPooling(conv3, [2, 2]) + return pool3, res_in + + def content_enc(self, xt, reuse): + res_in = [] + conv1_1 = relu(conv2d(xt, output_dim = self.gf_dim, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv1_1', reuse = reuse)) + conv1_2 = relu(conv2d(conv1_1, output_dim = self.gf_dim, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv1_2', reuse = reuse)) + res_in.append(conv1_2) + pool1 = MaxPooling(conv1_2, [2, 2]) + + conv2_1 = relu(conv2d(pool1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv2_1', reuse = reuse)) + conv2_2 = relu(conv2d(conv2_1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv2_2', reuse = reuse)) + res_in.append(conv2_2) + pool2 = MaxPooling(conv2_2, [2, 2]) + + conv3_1 = relu(conv2d(pool2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv3_1', reuse = reuse)) + conv3_2 = relu(conv2d(conv3_1, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv3_2', reuse = reuse)) + conv3_3 = relu(conv2d(conv3_2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'cont_conv3_3', reuse = reuse)) + res_in.append(conv3_3) + pool3 = MaxPooling(conv3_3, [2, 2]) + return pool3, res_in + + def comb_layers(self, h_dyn, h_cont, reuse=False): + comb1 = relu(conv2d(tf.concat(axis = 3, values = [h_dyn, h_cont]), + output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'comb1', reuse = reuse)) + comb2 = relu(conv2d(comb1, output_dim = self.gf_dim * 2, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'comb2', reuse = reuse)) + h_comb = relu(conv2d(comb2, output_dim = self.gf_dim * 4, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'h_comb', reuse = reuse)) + return h_comb + + def residual(self, input_dyn, input_cont, reuse=False): + n_layers = len(input_dyn) + res_out = [] + for l in range(n_layers): + input_ = tf.concat(axis = 3, values = [input_dyn[l], input_cont[l]]) + out_dim = input_cont[l].get_shape()[3] + res1 = relu(conv2d(input_, output_dim = out_dim, + k_h = 3, k_w = 3, d_h = 1, d_w = 1, + name = 'res' + str(l) + '_1', reuse = reuse)) + res2 = conv2d(res1, output_dim = out_dim, k_h = 3, k_w = 3, + d_h = 1, d_w = 1, name = 'res' + str(l) + '_2', reuse = reuse) + res_out.append(res2) + return res_out + + def dec_cnn(self, h_comb, res_connect, reuse=False): + 
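+        # Decoder: upsample from 1/8 resolution back to full size in three FixedUnPooling
+        # (x2) stages, adding the residual skip connections res_connect[2], [1], [0] from
+        # the encoders at the matching scales; the final deconv maps to c_dim channels via tanh.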
+    def dec_cnn(self, h_comb, res_connect, reuse=False):
+        shapel3 = [self.batch_size, int(self.image_size[0] / 4),
+                   int(self.image_size[1] / 4), self.gf_dim * 4]
+        shapeout3 = [self.batch_size, int(self.image_size[0] / 4),
+                     int(self.image_size[1] / 4), self.gf_dim * 2]
+        depool3 = FixedUnPooling(h_comb, [2, 2])
+        deconv3_3 = relu(deconv2d(relu(tf.add(depool3, res_connect[2])),
+                                  output_shape = shapel3, k_h = 3, k_w = 3,
+                                  d_h = 1, d_w = 1, name = 'dec_deconv3_3', reuse = reuse))
+        deconv3_2 = relu(deconv2d(deconv3_3, output_shape = shapel3, k_h = 3, k_w = 3,
+                                  d_h = 1, d_w = 1, name = 'dec_deconv3_2', reuse = reuse))
+        deconv3_1 = relu(deconv2d(deconv3_2, output_shape = shapeout3, k_h = 3, k_w = 3,
+                                  d_h = 1, d_w = 1, name = 'dec_deconv3_1', reuse = reuse))
+
+        shapel2 = [self.batch_size, int(self.image_size[0] / 2),
+                   int(self.image_size[1] / 2), self.gf_dim * 2]
+        shapeout3 = [self.batch_size, int(self.image_size[0] / 2),
+                     int(self.image_size[1] / 2), self.gf_dim]
+        depool2 = FixedUnPooling(deconv3_1, [2, 2])
+        deconv2_2 = relu(deconv2d(relu(tf.add(depool2, res_connect[1])),
+                                  output_shape = shapel2, k_h = 3, k_w = 3,
+                                  d_h = 1, d_w = 1, name = 'dec_deconv2_2', reuse = reuse))
+        deconv2_1 = relu(deconv2d(deconv2_2, output_shape = shapeout3, k_h = 3, k_w = 3,
+                                  d_h = 1, d_w = 1, name = 'dec_deconv2_1', reuse = reuse))
+
+        shapel1 = [self.batch_size, self.image_size[0],
+                   self.image_size[1], self.gf_dim]
+        shapeout1 = [self.batch_size, self.image_size[0],
+                     self.image_size[1], self.c_dim]
+        depool1 = FixedUnPooling(deconv2_1, [2, 2])
+        deconv1_2 = relu(deconv2d(relu(tf.add(depool1, res_connect[0])),
+                                  output_shape = shapel1, k_h = 3, k_w = 3, d_h = 1, d_w = 1,
+                                  name = 'dec_deconv1_2', reuse = reuse))
+        xtp1 = tanh(deconv2d(deconv1_2, output_shape = shapeout1, k_h = 3, k_w = 3,
+                             d_h = 1, d_w = 1, name = 'dec_deconv1_1', reuse = reuse))
+        return xtp1
+
+    def discriminator(self, image):
+        h0 = lrelu(conv2d(image, self.df_dim, name = 'dis_h0_conv'))
+        h1 = lrelu(batch_norm(conv2d(h0, self.df_dim * 2, name = 'dis_h1_conv'),
+                              "bn1"))
+        h2 = lrelu(batch_norm(conv2d(h1, self.df_dim * 4, name = 'dis_h2_conv'),
+                              "bn2"))
+        h3 = lrelu(batch_norm(conv2d(h2, self.df_dim * 8, name = 'dis_h3_conv'),
+                              "bn3"))
+        h = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'dis_h3_lin')
+
+        return tf.nn.sigmoid(h), h
+
+    def save(self, sess, checkpoint_dir, step):
+        model_name = "MCNET.model"
+
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir)
+
+        self.saver.save(sess,
+                        os.path.join(checkpoint_dir, model_name),
+                        global_step = step)
+
+    def load(self, sess, checkpoint_dir, model_name=None):
+        print(" [*] Reading checkpoints...")
+        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
+        if ckpt and ckpt.model_checkpoint_path:
+            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
+            if model_name is None: model_name = ckpt_name
+            self.saver.restore(sess, os.path.join(checkpoint_dir, model_name))
+            print(" Loaded model: " + str(model_name))
+            return True, model_name
+        else:
+            return False, None
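run_single_step below alternates a generator and a discriminator update through self.g_optim and self.d_optim, which are wired up elsewhere in the class and are not part of this hunk. For orientation only, a conventional way to build those two losses from the (probability, logits) pair returned by discriminator() above -- a sketch with hypothetical tensors D_logits_real and D_logits_fake, not the repository's exact graph:

import tensorflow as tf

def sigmoid_gan_losses(D_logits_real, D_logits_fake):
    # discriminator: push real frames towards 1 and generated frames towards 0
    d_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits = D_logits_real, labels = tf.ones_like(D_logits_real))) + \
        tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits = D_logits_fake, labels = tf.zeros_like(D_logits_fake)))
    # generator: make the discriminator predict 1 for generated frames
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits = D_logits_fake, labels = tf.ones_like(D_logits_fake)))
    return d_loss, g_loss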
+
+    # Execute the forward and the backward pass
+
+    def run_single_step(self, global_step):
+        print("global_step:", global_step)
+        try:
+            train_batch = self.sess.run(self.train_iterator.get_next())
+            # z=np.random.uniform(-1,1,size=(self.batch_size,self.nz))
+            x = self.sess.run([self.x], feed_dict = {self.x: train_batch["images"]})
+            _, g_sum = self.sess.run([self.g_optim, self.g_sum], feed_dict = {self.x: train_batch["images"]})
+            _, d_sum = self.sess.run([self.d_optim, self.d_sum], feed_dict = {self.x: train_batch["images"]})
+
+            gen_data, train_loss = self.sess.run([self.gen_data, self.total_loss],
+                                                 feed_dict = {self.x: train_batch["images"]})
+
+        except tf.errors.OutOfRangeError:
+            print("train out of range error")
+
+        try:
+            val_batch = self.sess.run(self.val_iterator.get_next())
+            val_loss = self.sess.run([self.total_loss], feed_dict = {self.x: val_batch["images"]})
+            # self.val_writer.add_summary(val_summary, global_step)
+        except tf.errors.OutOfRangeError:
+            print("val out of range error")
+
+        return train_loss, val_loss
+
+
diff --git a/video_prediction/models/savp_model.py b/video_prediction/models/savp_model.py
index 63776e1cac0d6cbb4efa9247647f7d8e557a74c3..ca8acd3f32a5ea1772c9fbf36003149acfdcb950 100644
--- a/video_prediction/models/savp_model.py
+++ b/video_prediction/models/savp_model.py
@@ -990,4 +990,4 @@ def _maybe_tile_concat_layer(conv2d_layer):
         outputs = conv2d_layer(inputs, out_channels, *args, **kwargs)
         return outputs
 
-    return layer
\ No newline at end of file
+    return layer
diff --git a/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction/models/vanilla_convLSTM_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7753004348ae0ae60057a469de1e2d1421c3869
--- /dev/null
+++ b/video_prediction/models/vanilla_convLSTM_model.py
@@ -0,0 +1,162 @@
+import collections
+import functools
+import itertools
+from collections import OrderedDict
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.util import nest
+from video_prediction import ops, flow_ops
+from video_prediction.models import BaseVideoPredictionModel
+from video_prediction.models import networks
+from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from video_prediction.utils import tf_utils
+from datetime import datetime
+from pathlib import Path
+from video_prediction.layers import layer_def as ld
+from video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
+
+class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
+    def __init__(self, mode='train', aggregate_nccl=None, hparams_dict=None,
+                 hparams=None, **kwargs):
+        super(VanillaConvLstmVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
+        print("Hparams_dict", self.hparams)
+        self.mode = mode
+        self.learning_rate = self.hparams.lr
+        self.gen_images_enc = None
+        self.recon_loss = None
+        self.latent_loss = None
+        self.total_loss = None
+        self.context_frames = 10
+        self.sequence_length = 20
+        self.predict_frames = self.sequence_length - self.context_frames
+        self.aggregate_nccl = aggregate_nccl
+
+    def get_default_hparams_dict(self):
+        """
+        The keys of this dict define valid hyperparameters for instances of
+        this class. A class inheriting from this one should override this
+        method if it has a different set of hyperparameters.
+
+        Returns:
+            A dict with the following hyperparameters.
+
+            batch_size: batch size for training.
+            lr: learning rate. if decay steps is non-zero, this is the
+                learning rate for steps <= decay_step.
+            max_steps: number of training steps.
+            context_frames: the number of ground-truth frames to pass in at
+                start. Must be specified during instantiation.
+            sequence_length: the number of frames in the video sequence,
+                including the context frames, so this model predicts
+                `sequence_length - context_frames` future frames. Must be
+                specified during instantiation.
+        """
+        default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict()
+        print("default hparams", default_hparams)
+        hparams = dict(
+            batch_size=16,
+            lr=0.001,
+            end_lr=0.0,
+            nz=16,
+            decay_steps=(200000, 300000),
+            max_steps=350000,
+        )
+
+        return dict(itertools.chain(default_hparams.items(), hparams.items()))
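The dict returned above is merged with the parent defaults via itertools.chain, so every key the subclass repeats overrides the inherited value while the remaining parent keys keep their defaults. A small standalone illustration of that override order (the parent values here are invented for the example):

import itertools

parent_defaults = dict(batch_size=32, lr=0.0001, context_frames=-1)
subclass_hparams = dict(batch_size=16, lr=0.001, max_steps=350000)

# dict() keeps the last occurrence of a duplicate key, so subclass values win
merged = dict(itertools.chain(parent_defaults.items(), subclass_hparams.items()))
print(merged)
# {'batch_size': 16, 'lr': 0.001, 'context_frames': -1, 'max_steps': 350000}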
+ """ + default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict() + print ("default hparams",default_hparams) + hparams = dict( + batch_size=16, + lr=0.001, + end_lr=0.0, + nz=16, + decay_steps=(200000, 300000), + max_steps=350000, + ) + + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + def build_graph(self, x): + self.x = x["images"] + + self.global_step = tf.Variable(0, name = 'global_step', trainable = False) + original_global_variables = tf.global_variables() + # ARCHITECTURE + self.x_hat_context_frames, self.x_hat_predict_frames = self.convLSTM_network() + self.x_hat = tf.concat([self.x_hat_context_frames, self.x_hat_predict_frames], 1) + + + self.context_frames_loss = tf.reduce_mean( + tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0])) + self.predict_frames_loss = tf.reduce_mean( + tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_predict_frames[:, :, :, :, 0])) + self.total_loss = self.context_frames_loss + self.predict_frames_loss + + self.train_op = tf.train.AdamOptimizer( + learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) + self.outputs = {} + self.outputs["gen_images"] = self.x_hat + # Summary op + self.loss_summary = tf.summary.scalar("recon_loss", self.context_frames_loss) + self.loss_summary = tf.summary.scalar("latent_loss", self.predict_frames_loss) + self.loss_summary = tf.summary.scalar("total_loss", self.total_loss) + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + return + + + @staticmethod + def convLSTM_cell(inputs, hidden, nz=16): + + conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1", activate = "leaky_relu") + + conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2", activate = "leaky_relu") + + conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3", activate = "leaky_relu") + + y_0 = conv3 + # conv lstm cell + cell_shape = y_0.get_shape().as_list() + with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)): + cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size = [3, 3], num_features = 8) + if hidden is None: + hidden = cell.zero_state(y_0, tf.float32) + + output, hidden = cell(y_0, hidden) + + + output_shape = output.get_shape().as_list() + + + z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) + + conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, "decode_5", activate = "leaky_relu") + + + conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, "decode_6", activate = "leaky_relu") + + + x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, "decode_7", activate = "sigmoid") # set activation to linear + + return x_hat, hidden + + def convLSTM_network(self): + network_template = tf.make_template('network', + VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables + # create network + x_hat_context = [] + x_hat_predict = [] + seq_start = 1 + hidden = None + for i in range(self.context_frames): + if i < seq_start: + x_1, hidden = network_template(self.x[:, i, :, :, :], hidden) + else: + x_1, hidden = network_template(x_1, hidden) + x_hat_context.append(x_1) + + for i in range(self.predict_frames): + x_1, hidden = network_template(x_1, hidden) + x_hat_predict.append(x_1) + + # pack them all together + x_hat_context = tf.stack(x_hat_context) + 
+    def convLSTM_network(self):
+        network_template = tf.make_template('network',
+                                            VanillaConvLstmVideoPredictionModel.convLSTM_cell)  # make the template to share the variables
+        # create network
+        x_hat_context = []
+        x_hat_predict = []
+        seq_start = 1
+        hidden = None
+        for i in range(self.context_frames):
+            if i < seq_start:
+                x_1, hidden = network_template(self.x[:, i, :, :, :], hidden)
+            else:
+                x_1, hidden = network_template(x_1, hidden)
+            x_hat_context.append(x_1)
+
+        for i in range(self.predict_frames):
+            x_1, hidden = network_template(x_1, hidden)
+            x_hat_predict.append(x_1)
+
+        # pack them all together
+        x_hat_context = tf.stack(x_hat_context)
+        x_hat_predict = tf.stack(x_hat_predict)
+        self.x_hat_context = tf.transpose(x_hat_context, [1, 0, 2, 3, 4])  # swap time and batch dimensions
+        self.x_hat_predict = tf.transpose(x_hat_predict, [1, 0, 2, 3, 4])  # swap time and batch dimensions
+        return self.x_hat_context, self.x_hat_predict
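The stack/transpose step at the end of convLSTM_network converts the per-time-step lists into the batch-major [batch, time, height, width, channels] layout that the losses in build_graph index into. A short NumPy sketch of that reshaping (the shapes are illustrative; the real batch and image sizes come from the hparams and the dataset):

import numpy as np

steps = [np.zeros((4, 64, 64, 3)) for _ in range(10)]  # 10 time steps, batch of 4
time_major = np.stack(steps)                           # (10, 4, 64, 64, 3)
batch_major = np.transpose(time_major, (1, 0, 2, 3, 4))
print(time_major.shape, batch_major.shape)             # (10, 4, 64, 64, 3) (4, 10, 64, 64, 3)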
+ """ + default_hparams = super(VanillaVAEVideoPredictionModel, self).get_default_hparams_dict() + print ("default hparams",default_hparams) + hparams = dict( + batch_size=16, + lr=0.001, + end_lr=0.0, + decay_steps=(200000, 300000), + lr_boundaries=(0,), + max_steps=350000, + nz=10, + context_frames=-1, + sequence_length=-1, + clip_length=10, #Bing: TODO What is the clip_length, original is 10, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + def build_graph(self,x): + + + + + + + tf.set_random_seed(12345) + self.x = x["images"] + + + self.global_step = tf.Variable(0, name = 'global_step', trainable = False) + original_global_variables = tf.global_variables() + self.increment_global_step = tf.assign_add(self.global_step, 1, name = 'increment_global_step') + + self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all() + + + + + + + + + + + + self.recon_loss = tf.reduce_mean(tf.square(self.x[:, 1:, :, :, 0] - self.x_hat[:, :-1, :, :, 0])) + + + + + + latent_loss = -0.5 * tf.reduce_sum( + 1 + self.z_log_sigma_sq - tf.square(self.z_mu) - + tf.exp(self.z_log_sigma_sq), axis = 1) + self.latent_loss = tf.reduce_mean(latent_loss) + self.total_loss = self.recon_loss + self.latent_loss + self.train_op = tf.train.AdamOptimizer( + learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) + # Build a saver + + self.losses = { + 'recon_loss': self.recon_loss, + 'latent_loss': self.latent_loss, + 'total_loss': self.total_loss, + } + + # Summary op + self.loss_summary = tf.summary.scalar("recon_loss", self.recon_loss) + self.loss_summary = tf.summary.scalar("latent_loss", self.latent_loss) + self.loss_summary = tf.summary.scalar("total_loss", self.latent_loss) + self.summary_op = tf.summary.merge_all() + + + + self.outputs = {} + self.outputs["gen_images"] = self.x_hat + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + + return + + + @staticmethod + def vae_arc3(x,l_name=0,nz=16): + seq_name = "sq_" + str(l_name) + "_" + + conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1") + + + conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2") + + + conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3") + + + conv4 = tf.layers.Flatten()(conv3) + + conv3_shape = conv3.get_shape().as_list() + + + z_mu = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m") + z_log_sigma_sq = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m"'enc_fc4_sigma') + eps = tf.random_normal(shape = tf.shape(z_log_sigma_sq), mean = 0, stddev = 1, dtype = tf.float32) + z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps + + z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1") + + + z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]]) + + conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, + seq_name + "decode_5") + + conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, + seq_name + "decode_6") + + + x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8") + + return x_hat, z_mu, z_log_sigma_sq, z + + def vae_arc_all(self): + X = [] + z_log_sigma_sq_all = [] + z_mu_all = [] + for i in range(20): + q, z_mu, z_log_sigma_sq, z = VanillaVAEVideoPredictionModel.vae_arc3(self.x[:, i, :, :, :], l_name=i, nz=self.nz) + X.append(q) + z_log_sigma_sq_all.append(z_log_sigma_sq) + z_mu_all.append(z_mu) + x_hat = tf.stack(X, axis = 1) + z_log_sigma_sq_all = 
+    def vae_arc_all(self):
+        X = []
+        z_log_sigma_sq_all = []
+        z_mu_all = []
+        for i in range(20):  # one per-frame sub-network; assumes the 20-frame sequences used here
+            q, z_mu, z_log_sigma_sq, z = VanillaVAEVideoPredictionModel.vae_arc3(self.x[:, i, :, :, :], l_name=i, nz=self.nz)
+            X.append(q)
+            z_log_sigma_sq_all.append(z_log_sigma_sq)
+            z_mu_all.append(z_mu)
+        x_hat = tf.stack(X, axis = 1)
+        z_log_sigma_sq_all = tf.stack(z_log_sigma_sq_all, axis = 1)
+        z_mu_all = tf.stack(z_mu_all, axis = 1)
+
+        return x_hat, z_log_sigma_sq_all, z_mu_all
diff --git a/video_prediction/utils/mcnet_utils.py b/video_prediction/utils/mcnet_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bad0131218f02e96d0e39132bc0d677547041da
--- /dev/null
+++ b/video_prediction/utils/mcnet_utils.py
@@ -0,0 +1,156 @@
+"""
+Some codes from https://github.com/Newmu/dcgan_code
+"""
+
+import cv2
+import random
+import imageio
+import scipy.misc
+import numpy as np
+
+
+def transform(image):
+    return image/127.5 - 1.
+
+
+def inverse_transform(images):
+    return (images+1.)/2.
+
+
+def save_images(images, size, image_path):
+    return imsave(inverse_transform(images)*255., size, image_path)
+
+
+def merge(images, size):
+    h, w = images.shape[1], images.shape[2]
+    img = np.zeros((h * size[0], w * size[1], 3))
+
+    for idx, image in enumerate(images):
+        i = idx % size[1]
+        j = idx // size[1]
+        img[j*h:j*h+h, i*w:i*w+w, :] = image
+
+    return img
+
+
+def imsave(images, size, path):
+    # note: scipy.misc.imsave was removed in SciPy >= 1.2
+    return scipy.misc.imsave(path, merge(images, size))
+
+
+def get_minibatches_idx(n, minibatch_size, shuffle=False):
+    """
+    Used to shuffle the dataset at each iteration.
+    """
+    idx_list = np.arange(n, dtype="int32")
+
+    if shuffle:
+        random.shuffle(idx_list)
+
+    minibatches = []
+    minibatch_start = 0
+    for i in range(n // minibatch_size):
+        minibatches.append(idx_list[minibatch_start:minibatch_start + minibatch_size])
+        minibatch_start += minibatch_size
+
+    if (minibatch_start != n):
+        # Make a minibatch out of what is left
+        minibatches.append(idx_list[minibatch_start:])
+
+    return zip(range(len(minibatches)), minibatches)
+
+
+def draw_frame(img, is_input):
+    if img.shape[2] == 1:
+        img = np.repeat(img, [3], axis=2)
+    if is_input:
+        img[:2,:,0] = img[:2,:,2] = 0
+        img[:,:2,0] = img[:,:2,2] = 0
+        img[-2:,:,0] = img[-2:,:,2] = 0
+        img[:,-2:,0] = img[:,-2:,2] = 0
+        img[:2,:,1] = 255
+        img[:,:2,1] = 255
+        img[-2:,:,1] = 255
+        img[:,-2:,1] = 255
+    else:
+        img[:2,:,0] = img[:2,:,1] = 0
+        img[:,:2,0] = img[:,:2,1] = 0
+        img[-2:,:,0] = img[-2:,:,1] = 0
+        img[:,-2:,0] = img[:,-2:,1] = 0
+        img[:2,:,2] = 255
+        img[:,:2,2] = 255
+        img[-2:,:,2] = 255
+        img[:,-2:,2] = 255
+
+    return img
+
+
+def load_kth_data(f_name, data_path, image_size, K, T):
+    flip = np.random.binomial(1,.5,1)[0]
+    tokens = f_name.split()
+    vid_path = data_path + tokens[0] + "_uncomp.avi"
+    vid = imageio.get_reader(vid_path,"ffmpeg")
+    low = int(tokens[1])
+    high = np.min([int(tokens[2]),vid.get_length()])-K-T+1
+    if low == high:
+        stidx = 0
+    else:
+        if low >= high: print(vid_path)
+        stidx = np.random.randint(low=low, high=high)
+    seq = np.zeros((image_size, image_size, K+T, 1), dtype="float32")
+    for t in range(K+T):
+        img = cv2.cvtColor(cv2.resize(vid.get_data(stidx+t),
+                                      (image_size,image_size)),
+                           cv2.COLOR_RGB2GRAY)
+        seq[:,:,t] = transform(img[:,:,None])
+
+    if flip == 1:
+        seq = seq[:,::-1]
+
+    diff = np.zeros((image_size, image_size, K-1, 1), dtype="float32")
+    for t in range(1,K):
+        prev = inverse_transform(seq[:,:,t-1])
+        next = inverse_transform(seq[:,:,t])
+        diff[:,:,t-1] = next.astype("float32")-prev.astype("float32")
+
+    return seq, diff
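load_kth_data above returns two arrays: seq holds the K+T grayscale frames rescaled to [-1, 1] by transform(), and diff holds the K-1 frame-to-frame differences of their [0, 1] versions, which is what the motion encoder consumes. A toy NumPy illustration of that pairing on a single-pixel "video" (the two helpers mirror the module functions above):

import numpy as np

def transform(image):
    return image / 127.5 - 1.

def inverse_transform(images):
    return (images + 1.) / 2.

raw = np.array([0., 51., 102.])           # three 1-pixel frames in [0, 255]
seq = transform(raw)                      # [-1.0, -0.6, -0.2]
diff = np.diff(inverse_transform(seq))    # [0.2, 0.2]: frame-to-frame change in [0, 1] units
print(seq, diff)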
+
+
+def load_s1m_data(f_name, data_path, trainlist, K, T):
+    flip = np.random.binomial(1,.5,1)[0]
+    vid_path = data_path + f_name
+    img_size = [240,320]
+
+    while True:
+        try:
+            vid = imageio.get_reader(vid_path,"ffmpeg")
+            low = 1
+            high = vid.get_length()-K-T+1
+            if low == high:
+                stidx = 0
+            else:
+                stidx = np.random.randint(low=low, high=high)
+            seq = np.zeros((img_size[0], img_size[1], K+T, 3),
+                           dtype="float32")
+            for t in range(K+T):
+                img = cv2.resize(vid.get_data(stidx+t),
+                                 (img_size[1],img_size[0]))[:,:,::-1]
+                seq[:,:,t] = transform(img)
+
+            if flip == 1:
+                seq = seq[:,::-1]
+
+            diff = np.zeros((img_size[0], img_size[1], K-1, 1),
+                            dtype="float32")
+            for t in range(1,K):
+                prev = inverse_transform(seq[:,:,t-1])*255
+                prev = cv2.cvtColor(prev.astype("uint8"),cv2.COLOR_BGR2GRAY)
+                next = inverse_transform(seq[:,:,t])*255
+                next = cv2.cvtColor(next.astype("uint8"),cv2.COLOR_BGR2GRAY)
+                diff[:,:,t-1,0] = (next.astype("float32")-prev.astype("float32"))/255.
+            break
+        except Exception:
+            # In case the current video is bad load a random one
+            rep_idx = np.random.randint(low=0, high=len(trainlist))
+            f_name = trainlist[rep_idx]
+            vid_path = data_path + f_name
+
+    return seq, diff
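For completeness, get_minibatches_idx above yields (batch index, array of row indices) pairs and keeps any remainder as a short final batch. A quick usage sketch, assuming the repository root is on PYTHONPATH:

from video_prediction.utils.mcnet_utils import get_minibatches_idx

for i, idx in get_minibatches_idx(10, minibatch_size=4, shuffle=False):
    print(i, idx)
# 0 [0 1 2 3]
# 1 [4 5 6 7]
# 2 [8 9]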