diff --git a/scripts/train.py b/scripts/train.py
index 185033074e3c2c1cc6c08008cc85343a63139942..7723c774f823b61dbf7e8ef72273d88ede2ebe9a 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -7,7 +7,6 @@ import os
 import random
 import time
 from collections import OrderedDict
-from distutils.util import strtobool
 
 import numpy as np
 import tensorflow as tf
@@ -33,7 +32,6 @@ def main():
     parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
     parser.add_argument("--model", type=str, help="model class name")
     parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
-    parser.add_argument("--max_steps", type=int, default=300000, help="number of training steps (0 to disable)")
 
     parser.add_argument("--summary_freq", type=int, default=1000, help="save summaries (except for image and eval summaries) every summary_freq steps")
     parser.add_argument("--image_summary_freq", type=int, default=5000, help="save image summaries every image_summary_freq steps")
@@ -103,7 +101,7 @@ def main():
         try:
             with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                 model_hparams_dict = json.loads(f.read())
-                model_hparams_dict.pop('num_gpus', None)
+                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility with older model_hparams.json files
         except FileNotFoundError:
             print("model_hparams.json was not loaded because it does not exist")
 
@@ -187,7 +185,7 @@ def main():
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
     config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
     global_step = tf.train.get_or_create_global_step()
-
+    max_steps = train_model.hparams.max_steps
     with tf.Session(config=config) as sess:
         print("parameter_count =", sess.run(parameter_count))
 
@@ -197,12 +195,12 @@ def main():
         start_step = sess.run(global_step)
         # start at one step earlier to log everything without doing any training
         # step is relative to the start_step
-        for step in range(-1, args.max_steps - start_step):
+        for step in range(-1, max_steps - start_step):
             if step == 0:
                 start = time.time()
 
             def should(freq):
-                return freq and ((step + 1) % freq == 0 or (step + 1) in (0, args.max_steps - start_step))
+                return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step))
 
             fetches = {"global_step": global_step}
             if step >= 0:
@@ -256,7 +254,7 @@ def main():
                     elapsed_time = time.time() - start
                     average_time = elapsed_time / (step + 1)
                     images_per_sec = batch_size / average_time
-                    remaining_time = (args.max_steps - (start_step + step)) * average_time
+                    remaining_time = (max_steps - (start_step + step)) * average_time
                     print("          image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
                           (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
 
diff --git a/video_prediction/models/base_model.py b/video_prediction/models/base_model.py
index 8014f6b4d22a780c4c6caef2e67b408e45191300..10317fa553a5e93ffad5f5308d2ec7ded672d5d4 100644
--- a/video_prediction/models/base_model.py
+++ b/video_prediction/models/base_model.py
@@ -145,6 +145,9 @@ class BaseVideoPredictionModel:
             with tf.variable_scope('vgg', reuse=tf.AUTO_REUSE):
                 _, target_vgg_features = vp.metrics._with_flat_batch(vgg_network.vgg16)(targets)
 
+            def sort_criterion(x):
+                return tf.reduce_mean(x, axis=0)  # mean over the time axis; ranks each batch example when updating the running min/max
+
             def accum_gen_images_and_metrics_fn(a, unused):
                 with tf.variable_scope(self.generator_scope, reuse=True):
                     gen_images, _ = self.generator_fn(inputs)
@@ -159,8 +162,8 @@ class BaseVideoPredictionModel:
                         metric /= len(target_vgg_features)
                     else:
                         metric = metric_fn(targets, gen_images, keep_axis=(0, 1))  # time, batch_size
-                    cond_min = tf.less(tf.reduce_mean(metric, axis=0), tf.reduce_mean(a['eval_%s/min' % name], axis=0))
-                    cond_max = tf.greater(tf.reduce_mean(metric, axis=0), tf.reduce_mean(a['eval_%s/max' % name], axis=0))
+                    cond_min = tf.less(sort_criterion(metric), sort_criterion(a['eval_%s/min' % name]))
+                    cond_max = tf.greater(sort_criterion(metric), sort_criterion(a['eval_%s/max' % name]))
                     a['eval_%s/min' % name] = where_axis1(cond_min, metric, a['eval_%s/min' % name])
                     a['eval_%s/sum' % name] = metric + a['eval_%s/sum' % name]
                     a['eval_%s/max' % name] = where_axis1(cond_max, metric, a['eval_%s/max' % name])
@@ -308,23 +311,31 @@ class VideoPredictionModel(BaseVideoPredictionModel):
         Returns:
             A dict with the following hyperparameters.
 
-            context_frames: the number of ground-truth frames to pass in at
-                start. Must be specified during instantiation.
-            sequence_length: the number of frames in the video sequence,
-                including the context frames, so this model predicts
-                `sequence_length - context_frames` future frames. Must be
-                specified during instantiation.
+            batch_size: number of examples per training batch.
             lr: learning rate. if decay steps is non-zero, this is the
                 learning rate for steps <= decay_step.
             end_lr: learning rate for steps >= end_decay_step if decay_steps
                 is non-zero, ignored otherwise.
             decay_steps: (decay_step, end_decay_step) tuple.
+            max_steps: number of training steps (0 to disable).
             beta1: momentum term of Adam.
             beta2: momentum term of Adam.
+            context_frames: the number of ground-truth frames to pass in at
+                start. Must be specified during instantiation.
+            sequence_length: the number of frames in the video sequence,
+                including the context frames, so this model predicts
+                `sequence_length - context_frames` future frames. Must be
+                specified during instantiation.
         """
         default_hparams = super(VideoPredictionModel, self).get_default_hparams_dict()
         hparams = dict(
             batch_size=16,
+            lr=0.001,
+            end_lr=0.0,
+            decay_steps=(200000, 300000),
+            max_steps=300000,
+            beta1=0.9,
+            beta2=0.999,
             context_frames=0,
             sequence_length=0,
             clip_length=10,
@@ -357,11 +368,6 @@ class VideoPredictionModel(BaseVideoPredictionModel):
             kl_anneal_k=-1.0,
             kl_anneal_steps=(50000, 100000),
             z_l1_weight=0.0,
-            lr=0.001,
-            end_lr=0.0,
-            decay_steps=(200000, 300000),
-            beta1=0.9,
-            beta2=0.999,
         )
         return dict(itertools.chain(default_hparams.items(), hparams.items()))