diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index 8de2a58b006c1c71d1f9e58d0eaba6ce4b1a23d3..24d84fb6105a6b4a3485680991d9fa46b3b1cb9d 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -2,6 +2,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import re
 import argparse
 import csv
 import errno
@@ -11,10 +12,8 @@ import random
 
 import numpy as np
 import tensorflow as tf
-from tensorflow.python.util import nest
 
-from video_prediction import datasets, models, metrics
-from video_prediction.policies.servo_policy import ServoPolicy
+from video_prediction import datasets, models
 
 
 def compute_expectation_np(pix_distrib):
@@ -142,23 +141,27 @@
 def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False, subtasks=None):
     context_frames = model_hparams.context_frames
     context_images = results['images'][:, :context_frames]
-    images = results['eval_images']
-    metric_names = ['psnr', 'ssim', 'ssim_scikit', 'ssim_finn', 'ssim_mcnet', 'vgg_csim']
-    metric_fns = [metrics.peak_signal_to_noise_ratio_np,
-                  metrics.structural_similarity_np,
-                  metrics.structural_similarity_scikit_np,
-                  metrics.structural_similarity_finn_np,
-                  metrics.structural_similarity_mcnet_np,
-                  None]
+
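+    # the diversity metric has no per-subtask (max/avg/min) variants, so it is saved on its own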
+    if 'eval_diversity' in results:
+        metric = results['eval_diversity']
+        metric_name = 'diversity'
+        subtask_dir = task_dir + '_%s' % metric_name
+        save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
+                     metric, sample_start_ind=sample_start_ind)
+
     subtasks = subtasks or ['max']
-    for metric_name, metric_fn in zip(metric_names, metric_fns):
-        for subtask in subtasks:
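+        # recover each metric name from result keys like 'eval_ssim/max', skipping the 'eval_gen_images_*' image outputs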
+    for subtask in subtasks:
+        metric_names = []
+        for k in results.keys():
+            m = re.match(r'eval_(\w+)/%s' % subtask, k)
+            if m and not k.startswith('eval_gen_images_'):
+                metric_names.append(m.group(1))
+        for metric_name in metric_names:
             subtask_dir = task_dir + '_%s_%s' % (metric_name, subtask)
             gen_images = results.get('eval_gen_images_%s/%s' % (metric_name, subtask), results.get('eval_gen_images'))
-            if metric_fn is not None:  # recompute using numpy implementation
-                metric = metric_fn(images, gen_images, keep_axis=(0, 1))
-            else:
-                metric = results['eval_%s/%s' % (metric_name, subtask)]
+            metric = results['eval_%s/%s' % (metric_name, subtask)]
             save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
                          metric, sample_start_ind=sample_start_ind)
             if only_metrics:
@@ -170,133 +171,28 @@ def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_
                                  gen_images, sample_start_ind=sample_start_ind)
 
 
-def save_prediction_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False):
-    context_frames = model_hparams.context_frames
-    sequence_length = model_hparams.sequence_length
-    context_images, images = np.split(results['images'], [context_frames], axis=1)
-    gen_images = results['gen_images'][:, context_frames - sequence_length:]
-    psnr = metrics.peak_signal_to_noise_ratio_np(images, gen_images, keep_axis=(0, 1))
-    mse = metrics.mean_squared_error_np(images, gen_images, keep_axis=(0, 1))
-    ssim = metrics.structural_similarity_np(images, gen_images, keep_axis=(0, 1))
-    save_metrics(os.path.join(task_dir, 'metrics', 'psnr'),
-                 psnr, sample_start_ind=sample_start_ind)
-    save_metrics(os.path.join(task_dir, 'metrics', 'mse'),
-                 mse, sample_start_ind=sample_start_ind)
-    save_metrics(os.path.join(task_dir, 'metrics', 'ssim'),
-                 ssim, sample_start_ind=sample_start_ind)
-    if only_metrics:
-        return
-
-    save_image_sequences(os.path.join(task_dir, 'inputs', 'context_image'),
-                         context_images, sample_start_ind=sample_start_ind)
-    save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_image'),
-                         gen_images, sample_start_ind=sample_start_ind)
-
-
-def save_motion_results(task_dir, results, model_hparams, draw_center=False,
-                        sample_start_ind=0, only_metrics=False):
-    context_frames = model_hparams.context_frames
-    sequence_length = model_hparams.sequence_length
-    pix_distribs = results['pix_distribs'][:, context_frames:]
-    gen_pix_distribs = results['gen_pix_distribs'][:, context_frames - sequence_length:]
-    pix_dist = metrics.expected_pixel_distance_np(pix_distribs, gen_pix_distribs, keep_axis=(0, 1))
-    save_metrics(os.path.join(task_dir, 'metrics', 'pix_dist'),
-                 pix_dist, sample_start_ind=sample_start_ind)
-    if only_metrics:
-        return
-
-    context_images, images = np.split(results['images'], [context_frames], axis=1)
-    gen_images = results['gen_images'][:, context_frames - sequence_length:]
-    initial_pix_distrib = results['pix_distribs'][:, 0:1]
-    num_motions = pix_distribs.shape[-1]
-    for i in range(num_motions):
-        output_name_posfix = '%d' % i if num_motions > 1 else ''
-        centers = compute_expectation_np(initial_pix_distrib[..., i:i + 1]) if draw_center else None
-        save_image_sequences(os.path.join(task_dir, 'inputs', 'pix_distrib%s' % output_name_posfix),
-                             context_images[:, 0:1], initial_pix_distrib[..., i:i + 1], centers, sample_start_ind=sample_start_ind)
-        centers = compute_expectation_np(gen_pix_distribs[..., i:i + 1]) if draw_center else None
-        save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_pix_distrib%s' % output_name_posfix),
-                             gen_images, gen_pix_distribs[..., i:i + 1], centers, sample_start_ind=sample_start_ind)
-
-
-def save_servo_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False):
-    context_frames = model_hparams.context_frames
-    sequence_length = model_hparams.sequence_length
-    context_images, images = np.split(results['images'], [context_frames], axis=1)
-    gen_images = results['gen_images'][:, context_frames - sequence_length:]
-    goal_image = results['goal_image']
-    # TODO: should exclude "context" actions assuming that they are passed in to the network
-    actions = results['actions']
-    gen_actions = results['gen_actions']
-    goal_image_mse = metrics.mean_squared_error_np(goal_image, gen_images[:, -1], keep_axis=0)
-    action_mse = metrics.mean_squared_error_np(actions, gen_actions, keep_axis=(0, 1))
-    save_metrics(os.path.join(task_dir, 'metrics', 'goal_image_mse'),
-                 goal_image_mse[:, None], sample_start_ind=sample_start_ind)
-    save_metrics(os.path.join(task_dir, 'metrics', 'action_mse'),
-                 action_mse, sample_start_ind=sample_start_ind)
-    if only_metrics:
-        return
-
-    save_image_sequences(os.path.join(task_dir, 'inputs', 'context_image'),
-                         context_images, sample_start_ind=sample_start_ind)
-    save_image_sequences(os.path.join(task_dir, 'inputs', 'goal_image'),
-                         goal_image[:, None], sample_start_ind=sample_start_ind)
-    save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_image'),
-                         gen_images, sample_start_ind=sample_start_ind)
-    gen_image_goal_diffs = np.abs(gen_images - goal_image[:, None])
-    save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_image_goal_diff'),
-                         gen_image_goal_diffs, sample_start_ind=sample_start_ind)
-
-
 def main():
     """
     results_dir
     ├── output_dir                              # condition / method
-    │   ├── prediction                          # task
+    │   ├── prediction_eval_lpips_max           # task: best sample in terms of LPIPS similarity
     │   │   ├── inputs
     │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
     │   │   │   └── ...
     │   │   ├── outputs
-    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the ones in the loss)
+    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the future ones)
     │   │   │   └── ...
     │   │   └── metrics
-    │   │       ├── psnr.csv
-    │   │       ├── mse.csv
-    │   │       └── ssim.csv
-    │   ├── prediction_eval_vgg_csim_max        # task: best sample in terms of VGG cosine similarity
+    │   │       └── lpips.csv
+    │   ├── prediction_eval_ssim_max            # task: best sample in terms of SSIM
     │   │   ├── inputs
     │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
     │   │   │   └── ...
     │   │   ├── outputs
-    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the ones in the loss)
-    │   │   │   └── ...
-    │   │   └── metrics
-    │   │       └── vgg_csim.csv
-    │   ├── servo
-    │   │   ├── inputs
-    │   │   │   ├── context_image_00000_00.png
-    │   │   │   ├── ...
-    │   │   │   ├── goal_image_00000_00.png     # only one goal image per sample
-    │   │   │   └── ...
-    │   │   ├── outputs
-    │   │   │   ├── gen_image_00000_00.png
-    │   │   │   ├── ...
-    │   │   │   ├── gen_image_goal_diff_00000_00.png
-    │   │   │   └── ...
-    │   │   └── metrics
-    │   │       ├── action_mse.csv
-    │   │       └── goal_image_mse.csv
-    │   ├── motion
-    │   │   ├── inputs
-    │   │   │   ├── pix_distrib_00000_00.png
-    │   │   │   └── ...
-    │   │   ├── outputs
-    │   │   │   ├── gen_pix_distrib_00000_00.png
-    │   │   │   ├── ...
-    │   │   │   ├── gen_pix_distrib_overlaid_00000_00.png
+    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the future ones)
     │   │   │   └── ...
     │   │   └── metrics
-    │   │       └── pix_dist.csv
+    │   │       └── ssim.csv
     │   └── ...
     └── ...
     """
@@ -320,8 +216,7 @@ def main():
     parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
     parser.add_argument("--num_epochs", type=int, default=1)
 
-    parser.add_argument("--tasks", type=str, nargs='+', help='tasks to evaluate (e.g. prediction, prediction_eval, servo, motion)')
-    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'min'], help='subtasks to evaluate (e.g. max, avg, min). only applicable to prediction_eval')
+    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'avg', 'min'], help='subtasks to evaluate (e.g. max, avg, min)')
     parser.add_argument("--only_metrics", action='store_true')
     parser.add_argument("--num_stochastic_samples", type=int, default=100)
 
@@ -343,10 +238,10 @@ def main():
     model_hparams_dict = {}
     if args.checkpoint:
         checkpoint_dir = os.path.normpath(args.checkpoint)
-        if not os.path.exists(checkpoint_dir):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
         if not os.path.isdir(args.checkpoint):
             checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
         with open(os.path.join(checkpoint_dir, "options.json")) as f:
             print("loading options from checkpoint %s" % args.checkpoint)
             options = json.loads(f.read())
@@ -360,7 +255,6 @@ def main():
         try:
             with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                 model_hparams_dict = json.loads(f.read())
-                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
         except FileNotFoundError:
             print("model_hparams.json was not loaded because it does not exist")
         args.output_dir = args.output_dir or os.path.join(args.results_dir, os.path.split(checkpoint_dir)[1])
@@ -377,21 +271,28 @@
     print('------------------------------------- End --------------------------------------')
 
     VideoDataset = datasets.get_dataset_class(args.dataset)
-    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed,
-                           hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
-
-    def override_hparams_dict(dataset):
-        hparams_dict = dict(model_hparams_dict)
-        hparams_dict['context_frames'] = dataset.hparams.context_frames
-        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
-        hparams_dict['repeat'] = dataset.hparams.time_shift
-        return hparams_dict
+    dataset = VideoDataset(
+        args.input_dir,
+        mode=args.mode,
+        num_epochs=args.num_epochs,
+        seed=args.seed,
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
 
     VideoPredictionModel = models.get_model_class(args.model)
-    model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams,
-                                 eval_num_samples=args.num_stochastic_samples, eval_parallel_iterations=args.eval_parallel_iterations)
-    context_frames = model.hparams.context_frames
-    sequence_length = model.hparams.sequence_length
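+    # the dataset's hparams take precedence over the ones loaded from the checkpoint for these keys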
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        mode='test',
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams,
+        eval_num_samples=args.num_stochastic_samples,
+        eval_parallel_iterations=args.eval_parallel_iterations)
 
     if args.num_samples:
         if args.num_samples > dataset.num_examples_per_epoch():
@@ -400,44 +299,12 @@ def main():
     else:
         num_examples_per_epoch = dataset.num_examples_per_epoch()
     if num_examples_per_epoch % args.batch_size != 0:
-        raise ValueError('batch_size should evenly divide the dataset')
-
-    inputs, target = dataset.make_batch(args.batch_size)
-    if not isinstance(model, models.GroundTruthVideoPredictionModel):
-        # remove ground truth data past context_frames to prevent accidentally using it
-        for k, v in inputs.items():
-            if k != 'actions':
-                inputs[k] = v[:, :context_frames]
+        raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
 
+    inputs = dataset.make_batch(args.batch_size)
     input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
-    target_ph = tf.placeholder(target.dtype, target.shape, 'targets_ph')
-
     with tf.variable_scope(''):
-        model.build_graph(input_phs, target_ph)
-
-    tasks = args.tasks
-    if tasks is None:
-        tasks = ['prediction_eval']
-        if 'pix_distribs' in inputs:
-            tasks.append('motion')
-
-    if 'servo' in tasks:
-        servo_model = VideoPredictionModel(mode='test', hparams_dict=model_hparams_dict, hparams=args.model_hparams)
-        cem_batch_size = 200
-        plan_horizon = sequence_length - 1
-        image_shape = inputs['images'].shape.as_list()[2:]
-        state_shape = inputs['states'].shape.as_list()[2:]
-        action_shape = inputs['actions'].shape.as_list()[2:]
-        servo_input_phs = {
-            'images': tf.placeholder(tf.float32, shape=[cem_batch_size, context_frames] + image_shape),
-            'states': tf.placeholder(tf.float32, shape=[cem_batch_size, 1] + state_shape),
-            'actions': tf.placeholder(tf.float32, shape=[cem_batch_size, plan_horizon] + action_shape),
-        }
-        if isinstance(servo_model, models.GroundTruthVideoPredictionModel):
-            images_shape = inputs['images'].shape.as_list()[1:]
-            servo_input_phs['images'] = tf.placeholder(tf.float32, shape=[cem_batch_size] + images_shape)
-        with tf.variable_scope('', reuse=True):
-            servo_model.build_graph(servo_input_phs)
+        model.build_graph(input_phs)
 
     output_dir = args.output_dir
     if not os.path.exists(output_dir):
@@ -452,156 +319,39 @@
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
     config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
     sess = tf.Session(config=config)
+    sess.graph.as_default()
 
     model.restore(sess, args.checkpoint)
 
-    if 'servo' in tasks:
-        servo_policy = ServoPolicy(servo_model, sess)
-
     sample_ind = 0
     while True:
         if args.num_samples and sample_ind >= args.num_samples:
             break
         try:
-            input_results, target_result = sess.run([inputs, target])
+            input_results = sess.run(inputs)
         except tf.errors.OutOfRangeError:
             break
         print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
 
-        if 'prediction_eval' in tasks:
-            feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
-            feed_dict.update({target_ph: target_result})
-            # compute "best" metrics using the computation graph (if available) or explicitly with python logic
-            if model.eval_outputs and model.eval_metrics:
-                fetches = {'images': model.inputs['images']}
-                fetches.update(model.eval_outputs.items())
-                fetches.update(model.eval_metrics.items())
-                results = sess.run(fetches, feed_dict=feed_dict)
-            else:
-                metric_names = ['psnr', 'ssim', 'ssim_scikit', 'ssim_finn', 'ssim_mcnet', 'vgg_csim']
-                metric_fns = [metrics.peak_signal_to_noise_ratio_np,
-                              metrics.structural_similarity_np,
-                              metrics.structural_similarity_scikit_np,
-                              metrics.structural_similarity_finn_np,
-                              metrics.structural_similarity_mcnet_np,
-                              metrics.vgg_cosine_similarity_np]
-
-                all_gen_images = []
-                all_metrics = [np.empty((args.num_stochastic_samples, args.batch_size, sequence_length - context_frames)) for _ in metric_names]
-                for s in range(args.num_stochastic_samples):
-                    gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
-                    all_gen_images.append(gen_images)
-                    for metric_name, metric_fn, all_metric in zip(metric_names, metric_fns, all_metrics):
-                        metric = metric_fn(gen_images, target_result, keep_axis=(0, 1))
-                        all_metric[s] = metric
-
-                results = {'images': input_results['images'],
-                           'eval_images': target_result}
-                for metric_name, all_metric in zip(metric_names, all_metrics):
-                    for subtask in args.eval_substasks:
-                        results['eval_gen_images_%s/%s' % (metric_name, subtask)] = np.empty_like(all_gen_images[0])
-                        results['eval_%s/%s' % (metric_name, subtask)] = np.empty_like(all_metric[0])
-
-                for i in range(args.batch_size):
-                    for metric_name, all_metric in zip(metric_names, all_metrics):
-                        ordered = np.argsort(np.mean(all_metric, axis=-1)[:, i])  # mean over time and sort over samples
-                        for subtask in args.eval_substasks:
-                            if subtask == 'max':
-                                sidx = ordered[-1]
-                            elif subtask == 'min':
-                                sidx = ordered[0]
-                            else:
-                                raise NotImplementedError
-                            results['eval_gen_images_%s/%s' % (metric_name, subtask)][i] = all_gen_images[sidx][i]
-                            results['eval_%s/%s' % (metric_name, subtask)][i] = all_metric[sidx][i]
-            save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'),
-                                         results, model.hparams, sample_ind, args.only_metrics, args.eval_substasks)
-
-        if 'prediction' in tasks or 'motion' in tasks:  # do these together
-            feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
-            fetches = {'images': model.inputs['images'],
-                       'gen_images': model.outputs['gen_images']}
-            if 'motion' in tasks:
-                fetches.update({'pix_distribs': model.inputs['pix_distribs'],
-                                'gen_pix_distribs': model.outputs['gen_pix_distribs']})
-
-            if args.num_stochastic_samples:
-                all_results = [sess.run(fetches, feed_dict=feed_dict) for _ in range(args.num_stochastic_samples)]
-                all_results = nest.map_structure(lambda *x: np.stack(x), *all_results)
-                all_context_images, all_images = np.split(all_results['images'], [context_frames], axis=2)
-                all_gen_images = all_results['gen_images'][:, :, context_frames - sequence_length:]
-                all_mse = metrics.mean_squared_error_np(all_images, all_gen_images, keep_axis=(0, 1))
-                all_mse_argsort = np.argsort(all_mse, axis=0)
-
-                for subtask, argsort_ind in zip(['_best', '_median', '_worst'],
-                                                [0, args.num_stochastic_samples // 2, -1]):
-                    all_mse_inds = all_mse_argsort[argsort_ind]
-                    gather = lambda x: np.array([x[ind, sample_ind] for sample_ind, ind in enumerate(all_mse_inds)])
-                    results = nest.map_structure(gather, all_results)
-                    if 'prediction' in tasks:
-                        save_prediction_results(os.path.join(output_dir, 'prediction' + subtask),
-                                                results, model.hparams, sample_ind, args.only_metrics)
-                    if 'motion' in tasks:
-                        draw_center = isinstance(model, models.NonTrainableVideoPredictionModel)
-                        save_motion_results(os.path.join(output_dir, 'motion' + subtask),
-                                            results, model.hparams, draw_center, sample_ind, args.only_metrics)
-            else:
-                results = sess.run(fetches, feed_dict=feed_dict)
-                if 'prediction' in tasks:
-                    save_prediction_results(os.path.join(output_dir, 'prediction'),
-                                            results, model.hparams, sample_ind, args.only_metrics)
-                if 'motion' in tasks:
-                    draw_center = isinstance(model, models.NonTrainableVideoPredictionModel)
-                    save_motion_results(os.path.join(output_dir, 'motion'),
-                                        results, model.hparams, draw_center, sample_ind, args.only_metrics)
-
-        if 'servo' in tasks:
-            images = input_results['images']
-            states = input_results['states']
-            gen_actions = []
-            gen_images = []
-            for images_, states_ in zip(images, states):
-                obs = {'context_images': images_[:context_frames],
-                       'context_state': states_[0],
-                       'goal_image': images_[-1]}
-                if isinstance(servo_model, models.GroundTruthVideoPredictionModel):
-                    obs['context_images'] = images_
-                gen_actions_, gen_images_ = servo_policy.act(obs, servo_model.outputs['gen_images'])
-                gen_actions.append(gen_actions_)
-                gen_images.append(gen_images_)
-            gen_actions = np.stack(gen_actions)
-            gen_images = np.stack(gen_images)
-            results = {'images': input_results['images'],
-                       'actions': input_results['actions'],
-                       'goal_image': input_results['images'][:, -1],
-                       'gen_actions': gen_actions,
-                       'gen_images': gen_images}
-            save_servo_results(os.path.join(output_dir, 'servo'),
-                               results, servo_model.hparams, sample_ind, args.only_metrics)
-
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        # compute "best" metrics using the computation graph
+        fetches = {'images': model.inputs['images']}
+        fetches.update(model.eval_outputs.items())
+        fetches.update(model.eval_metrics.items())
+        results = sess.run(fetches, feed_dict=feed_dict)
+        save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'),
+                                     results, model.hparams, sample_ind, args.only_metrics, args.eval_substasks)
         sample_ind += args.batch_size
 
     metric_fnames = []
-    if 'prediction_eval' in tasks:
-        metric_names = ['psnr', 'ssim', 'ssim_finn', 'vgg_csim']
-        subtasks = ['max']
-        for metric_name in metric_names:
-            for subtask in subtasks:
-                metric_fnames.append(
-                    os.path.join(output_dir, 'prediction_eval_%s_%s' % (metric_name, subtask), 'metrics', metric_name))
-    if 'prediction' in tasks:
-        subtask = '_best' if args.num_stochastic_samples else ''
-        metric_fnames.extend([
-            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'psnr'),
-            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'mse'),
-            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'ssim'),
-        ])
-    if 'motion' in tasks:
-        subtask = '_best' if args.num_stochastic_samples else ''
-        metric_fnames.append(os.path.join(output_dir, 'motion' + subtask, 'metrics', 'pix_dist'))
-    if 'servo' in tasks:
-        metric_fnames.append(os.path.join(output_dir, 'servo', 'metrics', 'goal_image_mse'))
-        metric_fnames.append(os.path.join(output_dir, 'servo', 'metrics', 'action_mse'))
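+    # collect the per-task CSVs of the best ('max') samples for the summary below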
+    metric_names = ['psnr', 'ssim', 'lpips']
+    subtasks = ['max']
+    for metric_name in metric_names:
+        for subtask in subtasks:
+            metric_fnames.append(
+                os.path.join(output_dir, 'prediction_eval_%s_%s' % (metric_name, subtask), 'metrics', metric_name))
 
     for metric_fname in metric_fnames:
         task_name, _, metric_name = metric_fname.split('/')[-3:]
diff --git a/scripts/generate.py b/scripts/generate.py
index bfc5e89eb8651895aacca3d96be8630ab4ee1517..55eab9d221553d65190f7ba25d344b86171c6fa2 100644
--- a/scripts/generate.py
+++ b/scripts/generate.py
@@ -61,10 +61,10 @@ def main():
     model_hparams_dict = {}
     if args.checkpoint:
         checkpoint_dir = os.path.normpath(args.checkpoint)
-        if not os.path.exists(checkpoint_dir):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
         if not os.path.isdir(args.checkpoint):
             checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
         with open(os.path.join(checkpoint_dir, "options.json")) as f:
             print("loading options from checkpoint %s" % args.checkpoint)
             options = json.loads(f.read())
@@ -78,7 +78,6 @@ def main():
         try:
             with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                 model_hparams_dict = json.loads(f.read())
-                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
         except FileNotFoundError:
             print("model_hparams.json was not loaded because it does not exist")
         args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1])
@@ -97,18 +96,25 @@
     print('------------------------------------- End --------------------------------------')
 
     VideoDataset = datasets.get_dataset_class(args.dataset)
-    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed,
-                           hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
-
-    def override_hparams_dict(dataset):
-        hparams_dict = dict(model_hparams_dict)
-        hparams_dict['context_frames'] = dataset.hparams.context_frames
-        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
-        hparams_dict['repeat'] = dataset.hparams.time_shift
-        return hparams_dict
+    dataset = VideoDataset(
+        args.input_dir,
+        mode=args.mode,
+        num_epochs=args.num_epochs,
+        seed=args.seed,
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
 
     VideoPredictionModel = models.get_model_class(args.model)
-    model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        mode='test',
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams)
 
     if args.num_samples:
         if args.num_samples > dataset.num_examples_per_epoch():
@@ -117,17 +122,10 @@ def main():
     else:
         num_examples_per_epoch = dataset.num_examples_per_epoch()
     if num_examples_per_epoch % args.batch_size != 0:
-        raise ValueError('batch_size should evenly divide the dataset')
-
-    inputs, _ = dataset.make_batch(args.batch_size)
-    if not isinstance(model, models.GroundTruthVideoPredictionModel):
-        # remove ground truth data past context_frames to prevent accidentally using it
-        for k, v in inputs.items():
-            if k != 'actions':
-                inputs[k] = v[:, :model.hparams.context_frames]
+        raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
 
+    inputs = dataset.make_batch(args.batch_size)
     input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
-
     with tf.variable_scope(''):
         model.build_graph(input_phs)
 
@@ -144,6 +142,7 @@ def main():
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
     config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
     sess = tf.Session(config=config)
+    sess.graph.as_default()
 
     model.restore(sess, args.checkpoint)
 
@@ -170,7 +169,11 @@
                 gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
                 for t, gen_image in enumerate(gen_images_):
                     gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
-                    gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
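+                    # tile single-channel (grayscale) predictions to 3 channels; color ones are RGB, so convert to BGR for OpenCV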
+                    if gen_image.shape[-1] == 1:
+                        gen_image = np.tile(gen_image, (1, 1, 3))
+                    else:
+                        gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
 
         sample_ind += args.batch_size