diff --git a/scripts/train.py b/scripts/train.py
index 185033074e3c2c1cc6c08008cc85343a63139942..7723c774f823b61dbf7e8ef72273d88ede2ebe9a 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -7,7 +7,6 @@ import os
 import random
 import time
 from collections import OrderedDict
-from distutils.util import strtobool
 
 import numpy as np
 import tensorflow as tf
@@ -33,7 +32,6 @@ def main():
     parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
     parser.add_argument("--model", type=str, help="model class name")
     parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
-    parser.add_argument("--max_steps", type=int, default=300000, help="number of training steps (0 to disable)")
 
     parser.add_argument("--summary_freq", type=int, default=1000, help="save summaries (except for image and eval summaries) every summary_freq steps")
     parser.add_argument("--image_summary_freq", type=int, default=5000, help="save image summaries every image_summary_freq steps")
@@ -103,7 +101,7 @@ def main():
         try:
             with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                 model_hparams_dict = json.loads(f.read())
-                model_hparams_dict.pop('num_gpus', None)
+                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
         except FileNotFoundError:
             print("model_hparams.json was not loaded because it does not exist")
 
@@ -187,7 +185,7 @@ def main():
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
     config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
     global_step = tf.train.get_or_create_global_step()
-
+    max_steps = train_model.hparams.max_steps
     with tf.Session(config=config) as sess:
         print("parameter_count =", sess.run(parameter_count))
 
@@ -197,12 +195,12 @@ def main():
         start_step = sess.run(global_step)
         # start at one step earlier to log everything without doing any training
         # step is relative to the start_step
-        for step in range(-1, args.max_steps - start_step):
+        for step in range(-1, max_steps - start_step):
             if step == 0:
                 start = time.time()
 
             def should(freq):
-                return freq and ((step + 1) % freq == 0 or (step + 1) in (0, args.max_steps - start_step))
+                return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step))
 
             fetches = {"global_step": global_step}
             if step >= 0:
@@ -256,7 +254,7 @@ def main():
                 elapsed_time = time.time() - start
                 average_time = elapsed_time / (step + 1)
                 images_per_sec = batch_size / average_time
-                remaining_time = (args.max_steps - (start_step + step)) * average_time
+                remaining_time = (max_steps - (start_step + step)) * average_time
                 print("  image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
                       (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
 
diff --git a/video_prediction/models/base_model.py b/video_prediction/models/base_model.py
index 8014f6b4d22a780c4c6caef2e67b408e45191300..10317fa553a5e93ffad5f5308d2ec7ded672d5d4 100644
--- a/video_prediction/models/base_model.py
+++ b/video_prediction/models/base_model.py
@@ -145,6 +145,9 @@ class BaseVideoPredictionModel:
             with tf.variable_scope('vgg', reuse=tf.AUTO_REUSE):
                 _, target_vgg_features = vp.metrics._with_flat_batch(vgg_network.vgg16)(targets)
 
+        def sort_criterion(x):
+            return tf.reduce_mean(x, axis=0)
+
         def accum_gen_images_and_metrics_fn(a, unused):
             with tf.variable_scope(self.generator_scope, reuse=True):
                 gen_images, _ = self.generator_fn(inputs)
@@ -159,8 +162,8 @@ class BaseVideoPredictionModel:
                     metric /= len(target_vgg_features)
                 else:
                     metric = metric_fn(targets, gen_images, keep_axis=(0, 1))  # time, batch_size
-                cond_min = tf.less(tf.reduce_mean(metric, axis=0), tf.reduce_mean(a['eval_%s/min' % name], axis=0))
-                cond_max = tf.greater(tf.reduce_mean(metric, axis=0), tf.reduce_mean(a['eval_%s/max' % name], axis=0))
+                cond_min = tf.less(sort_criterion(metric), sort_criterion(a['eval_%s/min' % name]))
+                cond_max = tf.greater(sort_criterion(metric), sort_criterion(a['eval_%s/max' % name]))
                 a['eval_%s/min' % name] = where_axis1(cond_min, metric, a['eval_%s/min' % name])
                 a['eval_%s/sum' % name] = metric + a['eval_%s/sum' % name]
                 a['eval_%s/max' % name] = where_axis1(cond_max, metric, a['eval_%s/max' % name])
@@ -308,23 +311,31 @@ class VideoPredictionModel(BaseVideoPredictionModel):
         Returns:
             A dict with the following hyperparameters.
 
-            context_frames: the number of ground-truth frames to pass in at
-                start. Must be specified during instantiation.
-            sequence_length: the number of frames in the video sequence,
-                including the context frames, so this model predicts
-                `sequence_length - context_frames` future frames. Must be
-                specified during instantiation.
+            batch_size: batch size for training.
             lr: learning rate. if decay steps is non-zero, this is the
                 learning rate for steps <= decay_step.
             end_lr: learning rate for steps >= end_decay_step if decay_steps
                 is non-zero, ignored otherwise.
             decay_steps: (decay_step, end_decay_step) tuple.
+            max_steps: number of training steps.
             beta1: momentum term of Adam.
             beta2: momentum term of Adam.
+            context_frames: the number of ground-truth frames to pass in at
+                start. Must be specified during instantiation.
+            sequence_length: the number of frames in the video sequence,
+                including the context frames, so this model predicts
+                `sequence_length - context_frames` future frames. Must be
+                specified during instantiation.
         """
         default_hparams = super(VideoPredictionModel, self).get_default_hparams_dict()
         hparams = dict(
             batch_size=16,
+            lr=0.001,
+            end_lr=0.0,
+            decay_steps=(200000, 300000),
+            max_steps=300000,
+            beta1=0.9,
+            beta2=0.999,
             context_frames=0,
             sequence_length=0,
             clip_length=10,
@@ -357,11 +368,6 @@ class VideoPredictionModel(BaseVideoPredictionModel):
             kl_anneal_k=-1.0,
             kl_anneal_steps=(50000, 100000),
             z_l1_weight=0.0,
-            lr=0.001,
-            end_lr=0.0,
-            decay_steps=(200000, 300000),
-            beta1=0.9,
-            beta2=0.999,
        )
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
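
Note (reviewer illustration, not part of the patch): with --max_steps removed
from scripts/train.py, the training horizon now travels with the model
hyperparameters, so a run that previously passed `--max_steps 300000` would
instead set it through the existing --model_hparams flag, e.g.
`--model_hparams max_steps=300000` (assuming the comma-separated key=value
parsing that flag's help text describes). The relocated defaults rely on the
merge at the end of get_default_hparams_dict, where dict() keeps the last
occurrence of a duplicate key, so subclass values win over the base class's.
A minimal standalone sketch of that behaviour (values illustrative):

    import itertools

    default_hparams = {"batch_size": 16, "max_steps": 300000}  # base-class defaults
    hparams = {"batch_size": 32}                               # hypothetical subclass override
    merged = dict(itertools.chain(default_hparams.items(), hparams.items()))
    assert merged == {"batch_size": 32, "max_steps": 300000}   # later items win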
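Note (reviewer illustration, not part of the patch): sort_criterion only names
the ranking the min/max bookkeeping already used: each batch element is scored
by its metric averaged over the time axis, and the running min/max entries are
replaced per batch element. A NumPy sketch, under the assumption that the
repo's where_axis1 helper selects along the batch axis (axis 1):

    import numpy as np

    def sort_criterion(x):        # x has shape (time, batch)
        return x.mean(axis=0)     # score each batch element over time

    def where_axis1(cond, x, y):  # assumed semantics of the repo helper
        return np.where(cond[None, :], x, y)

    running_min = np.full((3, 2), np.inf)  # accumulator a['eval_.../min']
    metric = np.random.rand(3, 2)          # per-(time, batch) metric values
    cond_min = sort_criterion(metric) < sort_criterion(running_min)
    running_min = where_axis1(cond_min, metric, running_min)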