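"""Train a video prediction model, periodically writing TensorBoard summaries,
evaluation metrics, checkpoints, and GIFs of predicted frames."""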
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import errno
import itertools
import json
import math
import os
import random
import time
from collections import OrderedDict
import numpy as np
import tensorflow as tf
from video_prediction import datasets, models
from video_prediction.utils import ffmpeg_gif, tf_utils
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
"train, val, test, etc, or a directory containing "
"the tfrecords")
parser.add_argument("--val_input_dirs", type=str, nargs='+', help="directories containing the tfrecords. default: [input_dir]")
parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
"default is logs_dir/model_fname, where model_fname consists of "
"information from model and model_hparams")
parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')
parser.add_argument("--dataset", type=str, help="dataset class name")
parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
parser.add_argument("--dataset_hparams_dict", type=str, help="a json file of dataset hyperparameters")
parser.add_argument("--model", type=str, help="model class name")
parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
parser.add_argument("--summary_freq", type=int, default=1000, help="save summaries (except for image and eval summaries) every summary_freq steps")
parser.add_argument("--image_summary_freq", type=int, default=5000, help="save image summaries every image_summary_freq steps")
parser.add_argument("--eval_summary_freq", type=int, default=0, help="save eval summaries every eval_summary_freq steps")
parser.add_argument("--progress_freq", type=int, default=100, help="display progress every progress_freq steps")
parser.add_argument("--metrics_freq", type=int, default=0, help="run and display metrics every metrics_freq step")
parser.add_argument("--gif_freq", type=int, default=0, help="save gifs of predicted frames every gif_freq steps")
parser.add_argument("--save_freq", type=int, default=5000, help="save model every save_freq steps, 0 to disable")
parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
parser.add_argument("--seed", type=int)
args = parser.parse_args()
if args.seed is not None:
tf.set_random_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
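    # Derive a filesystem-friendly run name from the model name and its
    # hyperparameter string: '=' and ',' become '.', commas inside [...] lists
    # become '..', and the brackets themselves are dropped.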
if args.output_dir is None:
list_depth = 0
model_fname = ''
for t in ('model=%s,%s' % (args.model, args.model_hparams)):
if t == '[':
list_depth += 1
if t == ']':
list_depth -= 1
if list_depth and t == ',':
t = '..'
if t in '=,':
t = '.'
if t in '[]':
t = ''
model_fname += t
args.output_dir = os.path.join(args.logs_dir, model_fname)
if args.resume:
if args.checkpoint:
raise ValueError('resume and checkpoint cannot both be specified')
args.checkpoint = args.output_dir
dataset_hparams_dict = {}
model_hparams_dict = {}
if args.dataset_hparams_dict:
with open(args.dataset_hparams_dict) as f:
dataset_hparams_dict.update(json.loads(f.read()))
if args.model_hparams_dict:
with open(args.model_hparams_dict) as f:
model_hparams_dict.update(json.loads(f.read()))
if args.checkpoint:
checkpoint_dir = os.path.normpath(args.checkpoint)
if not os.path.exists(checkpoint_dir):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
if not os.path.isdir(args.checkpoint):
checkpoint_dir, _ = os.path.split(checkpoint_dir)
with open(os.path.join(checkpoint_dir, "options.json")) as f:
print("loading options from checkpoint %s" % args.checkpoint)
options = json.loads(f.read())
args.dataset = args.dataset or options['dataset']
args.model = args.model or options['model']
try:
with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
dataset_hparams_dict.update(json.loads(f.read()))
except FileNotFoundError:
print("dataset_hparams.json was not loaded because it does not exist")
try:
with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
model_hparams_dict.update(json.loads(f.read()))
model_hparams_dict.pop('num_gpus', None) # backwards-compatibility
except FileNotFoundError:
print("model_hparams.json was not loaded because it does not exist")
print('----------------------------------- Options ------------------------------------')
for k, v in args._get_kwargs():
print(k, "=", v)
print('------------------------------------- End --------------------------------------')
VideoDataset = datasets.get_dataset_class(args.dataset)
train_dataset = VideoDataset(args.input_dir, mode='train', hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
val_input_dirs = args.val_input_dirs or [args.input_dir]
val_datasets = [VideoDataset(val_input_dir, mode='val', hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
for val_input_dir in val_input_dirs]
if len(val_input_dirs) > 1:
if isinstance(val_datasets[-1], datasets.KTHVideoDataset):
val_datasets[-1].set_sequence_length(40)
else:
val_datasets[-1].set_sequence_length(30)
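    # The model must agree with the dataset on the clip layout, so copy
    # context_frames and sequence_length from the dataset hparams, along with
    # the dataset's time_shift (consumed by the model as 'repeat').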
def override_hparams_dict(dataset):
hparams_dict = dict(model_hparams_dict)
hparams_dict['context_frames'] = dataset.hparams.context_frames
hparams_dict['sequence_length'] = dataset.hparams.sequence_length
hparams_dict['repeat'] = dataset.hparams.time_shift
return hparams_dict
VideoPredictionModel = models.get_model_class(args.model)
train_model = VideoPredictionModel(mode='train', hparams_dict=override_hparams_dict(train_dataset), hparams=args.model_hparams)
val_models = [VideoPredictionModel(mode='val', hparams_dict=override_hparams_dict(val_dataset), hparams=args.model_hparams)
for val_dataset in val_datasets]
batch_size = train_model.hparams.batch_size
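    # Build the training graph first, then build each validation graph inside
    # the same variable scope with reuse=True so all models share one set of weights.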
with tf.variable_scope('') as training_scope:
train_model.build_graph(*train_dataset.make_batch(batch_size))
for val_model, val_dataset in zip(val_models, val_datasets):
with tf.variable_scope(training_scope, reuse=True):
val_model.build_graph(*val_dataset.make_batch(batch_size))
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
with open(os.path.join(args.output_dir, "options.json"), "w") as f:
f.write(json.dumps(vars(args), sort_keys=True, indent=4))
with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
f.write(json.dumps(train_model.hparams.values(), sort_keys=True, indent=4))
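    # Collect video tensors from the first validation model for GIF export:
    # generated frames are prefixed with the ground-truth context frames, and
    # every input/output tensor with a time dimension (ndims >= 4) is included.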
if args.gif_freq:
val_model = val_models[0]
val_tensors = OrderedDict()
context_images = val_model.inputs['images'][:, :val_model.hparams.context_frames]
val_tensors['gen_images_vis'] = tf.concat([context_images, val_model.gen_images], axis=1)
if val_model.gen_images_enc is not None:
val_tensors['gen_images_enc_vis'] = tf.concat([context_images, val_model.gen_images_enc], axis=1)
val_tensors.update({name: tensor for name, tensor in val_model.inputs.items() if tensor.shape.ndims >= 4})
val_tensors['targets'] = val_model.targets
val_tensors.update({name: tensor for name, tensor in val_model.outputs.items() if tensor.shape.ndims >= 4})
val_tensor_clips = OrderedDict([(name, tf_utils.tensor_to_clip(output)) for name, output in val_tensors.items()])
with tf.name_scope("parameter_count"):
parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
saver = tf.train.Saver(max_to_keep=3)
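    # Image and eval summaries live in their own collections so each can be
    # merged and fetched at its own frequency; summaries tagged as both are
    # split out and only emitted via eval_image_summary_op.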
summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
image_summaries = set(tf.get_collection(tf_utils.IMAGE_SUMMARIES))
eval_summaries = set(tf.get_collection(tf_utils.EVAL_SUMMARIES))
eval_image_summaries = image_summaries & eval_summaries
image_summaries -= eval_image_summaries
eval_summaries -= eval_image_summaries
if args.summary_freq:
summary_op = tf.summary.merge(summaries)
if args.image_summary_freq:
image_summary_op = tf.summary.merge(list(image_summaries))
if args.eval_summary_freq:
eval_summary_op = tf.summary.merge(list(eval_summaries))
eval_image_summary_op = tf.summary.merge(list(eval_image_summaries))
if args.summary_freq or args.image_summary_freq or args.eval_summary_freq:
summary_writer = tf.summary.FileWriter(args.output_dir)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
global_step = tf.train.get_or_create_global_step()
max_steps = train_model.hparams.max_steps
with tf.Session(config=config) as sess:
print("parameter_count =", sess.run(parameter_count))
sess.run(tf.global_variables_initializer())
train_model.restore(sess, args.checkpoint)
start_step = sess.run(global_step)
# start at one step earlier to log everything without doing any training
# step is relative to the start_step
for step in range(-1, max_steps - start_step):
if step == 0:
start = time.time()
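            # A freq of 0 disables that kind of logging; otherwise log on the
            # warm-up pass (step == -1), every freq-th step, and the final step.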
def should(freq):
return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step))
fetches = {"global_step": global_step}
if step >= 0:
fetches["train_op"] = train_model.train_op
if should(args.progress_freq):
fetches['d_losses'] = train_model.d_losses
fetches['g_losses'] = train_model.g_losses
if isinstance(train_model.learning_rate, tf.Tensor):
fetches["learning_rate"] = train_model.learning_rate
if should(args.metrics_freq):
fetches['metrics'] = train_model.metrics
if should(args.summary_freq):
fetches["summary"] = summary_op
if should(args.image_summary_freq):
fetches["image_summary"] = image_summary_op
if should(args.eval_summary_freq):
fetches["eval_summary"] = eval_summary_op
fetches["eval_image_summary"] = eval_image_summary_op
run_start_time = time.time()
results = sess.run(fetches)
run_elapsed_time = time.time() - run_start_time
if run_elapsed_time > 1.5:
print('session.run took %0.1fs' % run_elapsed_time)
if should(args.summary_freq):
print("recording summary")
summary_writer.add_summary(results["summary"], results["global_step"])
print("done")
if should(args.image_summary_freq):
print("recording image summary")
summary_writer.add_summary(
tf_utils.convert_tensor_to_gif_summary(results["image_summary"]), results["global_step"])
print("done")
if should(args.eval_summary_freq):
print("recording eval summary")
summary_writer.add_summary(results["eval_summary"], results["global_step"])
summary_writer.add_summary(
tf_utils.convert_tensor_to_gif_summary(results["eval_image_summary"]), results["global_step"])
print("done")
if should(args.summary_freq) or should(args.image_summary_freq) or should(args.eval_summary_freq):
summary_writer.flush()
if should(args.progress_freq):
# global_step will have the correct step count if we resume from a checkpoint
steps_per_epoch = math.ceil(train_dataset.num_examples_per_epoch() / batch_size)
train_epoch = math.ceil(results["global_step"] / steps_per_epoch)
train_step = (results["global_step"] - 1) % steps_per_epoch + 1
print("progress global step %d epoch %d step %d" % (results["global_step"], train_epoch, train_step))
if step >= 0:
elapsed_time = time.time() - start
average_time = elapsed_time / (step + 1)
images_per_sec = batch_size / average_time
remaining_time = (max_steps - (start_step + step)) * average_time
print(" image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" %
(images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
for name, loss in itertools.chain(results['d_losses'].items(), results['g_losses'].items()):
print(name, loss)
if isinstance(train_model.learning_rate, tf.Tensor):
print("learning_rate", results["learning_rate"])
if should(args.metrics_freq):
                for name, metric in results['metrics'].items():
print(name, metric)
if should(args.save_freq):
print("saving model to", args.output_dir)
saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step)
print("done")
if should(args.gif_freq):
image_dir = os.path.join(args.output_dir, 'images')
if not os.path.exists(image_dir):
os.makedirs(image_dir)
gif_clips = sess.run(val_tensor_clips)
gif_step = results["global_step"]
for name, clip in gif_clips.items():
filename = "%08d-%s.gif" % (gif_step, name)
print("saving gif to", os.path.join(image_dir, filename))
ffmpeg_gif.save_gif(os.path.join(image_dir, filename), clip, fps=4)
print("done")
if __name__ == '__main__':
main()
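# Example invocation (illustrative only; valid dataset/model names are whatever
# video_prediction.datasets and video_prediction.models register, and the paths
# are placeholders):
#   python train.py --input_dir data/kth --dataset kth --model savp \
#       --model_hparams_dict hparams/model_hparams.json --gif_freq 10000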