from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import errno
import itertools
import json
import math
import os
import random
import time
from collections import OrderedDict

import numpy as np
import tensorflow as tf

from video_prediction import datasets, models
from video_prediction.utils import ffmpeg_gif, tf_utils


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                             "train, val, test, etc, or a directory containing "
                             "the tfrecords")
    parser.add_argument("--val_input_dirs", type=str, nargs='+',
                        help="directories containing the tfrecords. default: [input_dir]")
    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
    parser.add_argument("--output_dir",
                        help="output directory where json files, summary, model, gifs, etc are saved. "
                             "default is logs_dir/model_fname, where model_fname consists of "
                             "information from model and model_hparams")
    parser.add_argument("--checkpoint",
                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--resume", action='store_true',
                        help="resume from the latest checkpoint in output_dir")
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--dataset_hparams_dict", type=str, help="a json file of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
    parser.add_argument("--summary_freq", type=int, default=1000,
                        help="save summaries (except for image and eval summaries) every summary_freq steps")
    parser.add_argument("--image_summary_freq", type=int, default=5000,
                        help="save image summaries every image_summary_freq steps")
    parser.add_argument("--eval_summary_freq", type=int, default=0,
                        help="save eval summaries every eval_summary_freq steps")
    parser.add_argument("--progress_freq", type=int, default=100,
                        help="display progress every progress_freq steps")
    parser.add_argument("--metrics_freq", type=int, default=0,
                        help="run and display metrics every metrics_freq steps")
    parser.add_argument("--gif_freq", type=int, default=0,
                        help="save gifs of predicted frames every gif_freq steps")
    parser.add_argument("--save_freq", type=int, default=5000,
                        help="save model every save_freq steps, 0 to disable")
    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)
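
    # When --output_dir is not given, derive a filesystem-friendly run name from the
    # model spec: '=' and ',' become '.', brackets are dropped, and commas inside
    # bracketed lists become '..' so list values remain distinguishable. For example
    # (illustrative names), --model sv2p --model_hparams beta=0.001 would yield
    # logs/model.sv2p.beta.0.001.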
    if args.output_dir is None:
        list_depth = 0
        model_fname = ''
        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
            if t == '[':
                list_depth += 1
            if t == ']':
                list_depth -= 1
            if list_depth and t == ',':
                t = '..'
            if t in '=,':
                t = '.'
            if t in '[]':
                t = ''
            model_fname += t
        args.output_dir = os.path.join(args.logs_dir, model_fname)

    if args.resume:
        if args.checkpoint:
            raise ValueError('resume and checkpoint cannot both be specified')
        args.checkpoint = args.output_dir

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.dataset_hparams_dict:
        with open(args.dataset_hparams_dict) as f:
            dataset_hparams_dict.update(json.loads(f.read()))
    if args.model_hparams_dict:
        with open(args.model_hparams_dict) as f:
            model_hparams_dict.update(json.loads(f.read()))
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict.update(json.loads(f.read()))
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict.update(json.loads(f.read()))
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    train_dataset = VideoDataset(args.input_dir, mode='train', hparams_dict=dataset_hparams_dict,
                                 hparams=args.dataset_hparams)
    val_input_dirs = args.val_input_dirs or [args.input_dir]
    val_datasets = [VideoDataset(val_input_dir, mode='val', hparams_dict=dataset_hparams_dict,
                                 hparams=args.dataset_hparams)
                    for val_input_dir in val_input_dirs]
    if len(val_input_dirs) > 1:
        # use a longer sequence length for the last validation dataset
        if isinstance(val_datasets[-1], datasets.KTHVideoDataset):
            val_datasets[-1].set_sequence_length(40)
        else:
            val_datasets[-1].set_sequence_length(30)

    def override_hparams_dict(dataset):
        # the model's frame counts must agree with the dataset's
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    train_model = VideoPredictionModel(mode='train', hparams_dict=override_hparams_dict(train_dataset),
                                       hparams=args.model_hparams)
    val_models = [VideoPredictionModel(mode='val', hparams_dict=override_hparams_dict(val_dataset),
                                       hparams=args.model_hparams)
                  for val_dataset in val_datasets]

    batch_size = train_model.hparams.batch_size
    with tf.variable_scope('') as training_scope:
        train_model.build_graph(*train_dataset.make_batch(batch_size))
    for val_model, val_dataset in zip(val_models, val_datasets):
        # the validation graphs reuse the training variables
        with tf.variable_scope(training_scope, reuse=True):
            val_model.build_graph(*val_dataset.make_batch(batch_size))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
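    # The hparams files below, together with options.json above, mirror what the
    # --checkpoint loading code earlier in main() reads back, so a run can be
    # resumed or inspected later from its output directory alone.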
"dataset_hparams.json"), "w") as f: f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4)) with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f: f.write(json.dumps(train_model.hparams.values(), sort_keys=True, indent=4)) if args.gif_freq: val_model = val_models[0] val_tensors = OrderedDict() context_images = val_model.inputs['images'][:, :val_model.hparams.context_frames] val_tensors['gen_images_vis'] = tf.concat([context_images, val_model.gen_images], axis=1) if val_model.gen_images_enc is not None: val_tensors['gen_images_enc_vis'] = tf.concat([context_images, val_model.gen_images_enc], axis=1) val_tensors.update({name: tensor for name, tensor in val_model.inputs.items() if tensor.shape.ndims >= 4}) val_tensors['targets'] = val_model.targets val_tensors.update({name: tensor for name, tensor in val_model.outputs.items() if tensor.shape.ndims >= 4}) val_tensor_clips = OrderedDict([(name, tf_utils.tensor_to_clip(output)) for name, output in val_tensors.items()]) with tf.name_scope("parameter_count"): parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()]) saver = tf.train.Saver(max_to_keep=3) summaries = tf.get_collection(tf.GraphKeys.SUMMARIES) image_summaries = set(tf.get_collection(tf_utils.IMAGE_SUMMARIES)) eval_summaries = set(tf.get_collection(tf_utils.EVAL_SUMMARIES)) eval_image_summaries = image_summaries & eval_summaries image_summaries -= eval_image_summaries eval_summaries -= eval_image_summaries if args.summary_freq: summary_op = tf.summary.merge(summaries) if args.image_summary_freq: image_summary_op = tf.summary.merge(list(image_summaries)) if args.eval_summary_freq: eval_summary_op = tf.summary.merge(list(eval_summaries)) eval_image_summary_op = tf.summary.merge(list(eval_image_summaries)) if args.summary_freq or args.image_summary_freq or args.eval_summary_freq: summary_writer = tf.summary.FileWriter(args.output_dir) gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac) config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) global_step = tf.train.get_or_create_global_step() max_steps = train_model.hparams.max_steps with tf.Session(config=config) as sess: print("parameter_count =", sess.run(parameter_count)) sess.run(tf.global_variables_initializer()) train_model.restore(sess, args.checkpoint) start_step = sess.run(global_step) # start at one step earlier to log everything without doing any training # step is relative to the start_step for step in range(-1, max_steps - start_step): if step == 0: start = time.time() def should(freq): return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step)) fetches = {"global_step": global_step} if step >= 0: fetches["train_op"] = train_model.train_op if should(args.progress_freq): fetches['d_losses'] = train_model.d_losses fetches['g_losses'] = train_model.g_losses if isinstance(train_model.learning_rate, tf.Tensor): fetches["learning_rate"] = train_model.learning_rate if should(args.metrics_freq): fetches['metrics'] = train_model.metrics if should(args.summary_freq): fetches["summary"] = summary_op if should(args.image_summary_freq): fetches["image_summary"] = image_summary_op if should(args.eval_summary_freq): fetches["eval_summary"] = eval_summary_op fetches["eval_image_summary"] = eval_image_summary_op run_start_time = time.time() results = sess.run(fetches) run_elapsed_time = time.time() - run_start_time if run_elapsed_time > 1.5: print('session.run took %0.1fs' % run_elapsed_time) 
            if should(args.summary_freq):
                print("recording summary")
                summary_writer.add_summary(results["summary"], results["global_step"])
                print("done")
            if should(args.image_summary_freq):
                print("recording image summary")
                summary_writer.add_summary(
                    tf_utils.convert_tensor_to_gif_summary(results["image_summary"]), results["global_step"])
                print("done")
            if should(args.eval_summary_freq):
                print("recording eval summary")
                summary_writer.add_summary(results["eval_summary"], results["global_step"])
                summary_writer.add_summary(
                    tf_utils.convert_tensor_to_gif_summary(results["eval_image_summary"]), results["global_step"])
                print("done")
            if should(args.summary_freq) or should(args.image_summary_freq) or should(args.eval_summary_freq):
                summary_writer.flush()
            if should(args.progress_freq):
                # global_step will have the correct step count if we resume from a checkpoint
                steps_per_epoch = math.ceil(train_dataset.num_examples_per_epoch() / batch_size)
                train_epoch = math.ceil(results["global_step"] / steps_per_epoch)
                train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                print("progress  global step %d  epoch %d  step %d" %
                      (results["global_step"], train_epoch, train_step))
                if step >= 0:
                    elapsed_time = time.time() - start
                    average_time = elapsed_time / (step + 1)
                    images_per_sec = batch_size / average_time
                    remaining_time = (max_steps - (start_step + step)) * average_time
                    print("          image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
                          (images_per_sec, remaining_time / 60, remaining_time / 60 / 60,
                           remaining_time / 60 / 60 / 24))
                for name, loss in itertools.chain(results['d_losses'].items(), results['g_losses'].items()):
                    print(name, loss)
                if isinstance(train_model.learning_rate, tf.Tensor):
                    print("learning_rate", results["learning_rate"])
            if should(args.metrics_freq):
                # iterate over (name, value) pairs; iterating over the bare dict
                # would only yield the names and fail to unpack
                for name, metric in results['metrics'].items():
                    print(name, metric)
            if should(args.save_freq):
                print("saving model to", args.output_dir)
                saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step)
                print("done")
            if should(args.gif_freq):
                image_dir = os.path.join(args.output_dir, 'images')
                if not os.path.exists(image_dir):
                    os.makedirs(image_dir)
                gif_clips = sess.run(val_tensor_clips)
                gif_step = results["global_step"]
                for name, clip in gif_clips.items():
                    filename = "%08d-%s.gif" % (gif_step, name)
                    print("saving gif to", os.path.join(image_dir, filename))
                    ffmpeg_gif.save_gif(os.path.join(image_dir, filename), clip, fps=4)
                    print("done")


if __name__ == '__main__':
    main()
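
# Illustrative invocation (dataset name, model name, and paths are placeholders
# that depend on your data and hparams layout):
#   python train.py --input_dir data/bair --dataset bair \
#       --model savp --model_hparams_dict hparams/bair/ours_savp/model_hparams.json \
#       --output_dir logs/bair/ours_savp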