diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index 730aba76bb42eec3fb25260562cac54003a8bfbc..c3de4e189257e9b3e0206db165183fe36376f816 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -1,5 +1,6 @@
 import argparse
 import csv
+import errno
 import json
 import os
 import random
@@ -248,37 +249,46 @@ def main():
     ├── output_dir                              # condition / method
     │   ├── prediction                          # task
     │   │   ├── inputs
-    │   │   │   ├── context_image_00000_00.jpg  # indexed by sample index and time step
+    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
     │   │   │   └── ...
     │   │   ├── outputs
-    │   │   │   ├── gen_image_00000_00.jpg      # predicted images (only the ones in the loss)
+    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the ones in the loss)
     │   │   │   └── ...
     │   │   └── metrics
     │   │       ├── psnr.csv
     │   │       ├── mse.csv
     │   │       └── ssim.csv
+    │   ├── prediction_eval_vgg_csim_max        # task: best sample in terms of VGG cosine similarity
+    │   │   ├── inputs
+    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
+    │   │   │   └── ...
+    │   │   ├── outputs
+    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the ones in the loss)
+    │   │   │   └── ...
+    │   │   └── metrics
+    │   │       └── vgg_csim.csv
     │   ├── servo
     │   │   ├── inputs
-    │   │   │   ├── context_image_00000_00.jpg
+    │   │   │   ├── context_image_00000_00.png
     │   │   │   ├── ...
-    │   │   │   ├── goal_image_00000_00.jpg     # only one goal image per sample
+    │   │   │   ├── goal_image_00000_00.png     # only one goal image per sample
     │   │   │   └── ...
     │   │   ├── outputs
-    │   │   │   ├── gen_image_00000_00.jpg
+    │   │   │   ├── gen_image_00000_00.png
     │   │   │   ├── ...
-    │   │   │   ├── gen_image_goal_diff_00000_00.jpg
+    │   │   │   ├── gen_image_goal_diff_00000_00.png
     │   │   │   └── ...
     │   │   └── metrics
     │   │       ├── action_mse.csv
     │   │       └── goal_image_mse.csv
     │   ├── motion
     │   │   ├── inputs
-    │   │   │   ├── pix_distrib_00000_00.jpg
+    │   │   │   ├── pix_distrib_00000_00.png
     │   │   │   └── ...
     │   │   ├── outputs
-    │   │   │   ├── gen_pix_distrib_00000_00.jpg
+    │   │   │   ├── gen_pix_distrib_00000_00.png
     │   │   │   ├── ...
-    │   │   │   ├── gen_pix_distrib_overlaid_00000_00.jpg
+    │   │   │   ├── gen_pix_distrib_overlaid_00000_00.png
     │   │   │   └── ...
     │   │   └── metrics
     │   │       └── pix_dist.csv
@@ -286,41 +296,38 @@ def main():
     └── ...
     """
     parser = argparse.ArgumentParser()
-    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories train,"
-                                                                     "val, test, etc, or a directory containing the tfrecords")
-    parser.add_argument("--output_dir", default=None, help="where to put output files")
-    parser.add_argument("--results_dir", type=str, default='results')
-    parser.add_argument("--seed", type=int, default=7)
-    parser.add_argument("--checkpoint", type=str,
-                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000) "
-                             "to resume training from or use for testing. Can specify multiple checkpoints. "
-                             "If more than one checkpoint is provided, the global step from the checkpoints "
-                             "are not restored.")
-    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val')
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_dir is specified")
+    parser.add_argument("--output_dir", help="output directory where results are saved. default is results_dir/model_fname, "
+                                             "where model_fname is the directory name of checkpoint")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
 
-    parser.add_argument("--batch_size", type=int, default=16, help="number of samples in batch")
-    parser.add_argument("--num_samples", type=int, help="number of samples for the table of sequence (all of them by default)")
+    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')
 
-    parser.add_argument("--num_gpus", type=int, default=1)
-    parser.add_argument("--eval_parallel_iterations", type=int, default=10)
-    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
     parser.add_argument("--dataset", type=str, help="dataset class name")
     parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
     parser.add_argument("--model", type=str, help="model class name")
     parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
 
-    parser.add_argument("--tasks", type=str, nargs='+', help='tasks to evaluation (e.g. prediction, servo, motion)')
-    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'min'], help='subtasks to evaluation (e.g. max, avg, min)')
+    parser.add_argument("--batch_size", type=int, default=16, help="number of samples in batch")
+    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
+    parser.add_argument("--num_epochs", type=int, default=1)
+
+    parser.add_argument("--tasks", type=str, nargs='+', help='tasks to evaluate (e.g. prediction, prediction_eval, servo, motion)')
+    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'min'], help='subtasks to evaluate (e.g. max, avg, min). only applicable to prediction_eval')
     parser.add_argument("--only_metrics", action='store_true')
-    parser.add_argument("--num_stochastic_samples", type=int, default=0)
+    parser.add_argument("--num_stochastic_samples", type=int, default=100)
 
-    args = parser.parse_args()
+    parser.add_argument("--gt_inputs_dir", type=str, help="directory containing input ground truth images for ismple dataset")
+    parser.add_argument("--gt_outputs_dir", type=str, help="directory containing output ground truth images for ismple dataset")
 
-    cuda_visible_devices = os.environ['CUDA_VISIBLE_DEVICES']
-    if cuda_visible_devices == '':
-        assert args.num_gpus == 0
-    else:
-        assert len(cuda_visible_devices.split(',')) == args.num_gpus
+    parser.add_argument("--eval_parallel_iterations", type=int, default=10)
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", type=int, default=7)
+
+    args = parser.parse_args()
 
     if args.seed is not None:
         tf.set_random_seed(args.seed)
@@ -331,6 +338,8 @@ def main():
     model_hparams_dict = {}
     if args.checkpoint:
         checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
         if not os.path.isdir(args.checkpoint):
             checkpoint_dir, _ = os.path.split(checkpoint_dir)
         with open(os.path.join(checkpoint_dir, "options.json")) as f:
@@ -342,11 +351,11 @@ def main():
             with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                 dataset_hparams_dict = json.loads(f.read())
         except FileNotFoundError:
-            print("model_hparams.json was not loaded because it does not exist")
+            print("dataset_hparams.json was not loaded because it does not exist")
         try:
             with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                 model_hparams_dict = json.loads(f.read())
-                model_hparams_dict.pop('num_gpus', None)
+                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility: num_gpus is no longer a model hparam
         except FileNotFoundError:
             print("model_hparams.json was not loaded because it does not exist")
         args.output_dir = args.output_dir or os.path.join(args.results_dir, os.path.split(checkpoint_dir)[1])
@@ -363,7 +372,8 @@ def main():
     print('------------------------------------- End --------------------------------------')
 
     VideoDataset = datasets.get_dataset_class(args.dataset)
-    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=1, hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
+    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed,
+                           hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
 
     def override_hparams_dict(dataset):
         hparams_dict = dict(model_hparams_dict)
@@ -373,13 +383,14 @@ def main():
         return hparams_dict
 
     VideoPredictionModel = models.get_model_class(args.model)
-    model = VideoPredictionModel(mode='test', num_gpus=args.num_gpus, eval_parallel_iterations=args.eval_parallel_iterations,
-                                 hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams)
+    model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams,
+                                 eval_num_samples=args.num_stochastic_samples, eval_parallel_iterations=args.eval_parallel_iterations)
     context_frames = model.hparams.context_frames
     sequence_length = model.hparams.sequence_length
 
     inputs, target = dataset.make_batch(args.batch_size)
     if not isinstance(model, models.GroundTruthVideoPredictionModel):
+        # remove ground truth data past context_frames to prevent accidentally using it
         for k, v in inputs.items():
             if k != 'actions':
                 inputs[k] = v[:, :context_frames]
@@ -446,10 +457,47 @@ def main():
         if 'prediction_eval' in tasks:
             feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
             feed_dict.update({target_ph: target_result})
-            fetches = {'images': model.inputs['images']}
-            fetches.update(model.eval_outputs.items())
-            fetches.update(model.eval_metrics.items())
-            results = sess.run(fetches, feed_dict=feed_dict)
+            # compute "best" metrics using the computation graph (if available) or explicitly with python logic
+            if model.eval_outputs and model.eval_metrics:
+                fetches = {'images': model.inputs['images']}
+                fetches.update(model.eval_outputs.items())
+                fetches.update(model.eval_metrics.items())
+                results = sess.run(fetches, feed_dict=feed_dict)
+            else:
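+                # otherwise, draw num_stochastic_samples predictions one sess.run at a time
+                # and score them in numpy using the *_np metric implementations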
+                metric_names = ['psnr', 'ssim', 'ssim_scikit', 'ssim_finn', 'vgg_csim']
+                metric_fns = [metrics.peak_signal_to_noise_ratio_np,
+                              metrics.structural_similarity_np,
+                              metrics.structural_similarity_scikit_np,
+                              metrics.structural_similarity_finn_np,
+                              metrics.vgg_cosine_similarity_np]
+
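+                # collect the generated images and per-frame metrics for every stochastic
+                # sample; each metric array has shape
+                # (num_stochastic_samples, batch_size, sequence_length - context_frames)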
+                all_gen_images = []
+                all_metrics = [np.empty((args.num_stochastic_samples, args.batch_size, sequence_length - context_frames)) for _ in metric_names]
+                for s in range(args.num_stochastic_samples):
+                    gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
+                    all_gen_images.append(gen_images)
+                    for metric_name, metric_fn, all_metric in zip(metric_names, metric_fns, all_metrics):
+                        metric = metric_fn(gen_images, target_result, keep_axis=(0, 1))
+                        all_metric[s] = metric
+
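+                # preallocate the per-subtask result arrays, each shaped like the images
+                # and metrics of a single stochastic sample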
+                results = {}
+                for metric_name, all_metric in zip(metric_names, all_metrics):
+                    for subtask in args.eval_substasks:
+                        results['eval_gen_images_%s/%s' % (metric_name, subtask)] = np.empty_like(all_gen_images[0])
+                        results['eval_%s/%s' % (metric_name, subtask)] = np.empty_like(all_metric[0])
+
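+                # for each batch element, keep the stochastic sample whose time-averaged
+                # metric is highest ('max') or lowest ('min'); all metrics used here are
+                # higher-is-better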
+                for i in range(args.batch_size):
+                    for metric_name, all_metric in zip(metric_names, all_metrics):
+                        ordered = np.argsort(np.mean(all_metric, axis=-1)[:, i])  # mean over time and sort over samples
+                        for subtask in args.eval_substasks:
+                            if subtask == 'max':
+                                sidx = ordered[-1]
+                            elif subtask == 'min':
+                                sidx = ordered[0]
+                            else:
+                                raise NotImplementedError
+                            results['eval_gen_images_%s/%s' % (metric_name, subtask)][i] = all_gen_images[sidx][i]
+                            results['eval_%s/%s' % (metric_name, subtask)][i] = all_metric[sidx][i]
             save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'),
                                          results, model.hparams, sample_ind, args.only_metrics, args.eval_substasks)
 
diff --git a/scripts/train.py b/scripts/train.py
index 7723c774f823b61dbf7e8ef72273d88ede2ebe9a..c657c73a87c3ab8c5b4c5caf0f7deb22885bf337 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -35,24 +35,17 @@ def main():
 
     parser.add_argument("--summary_freq", type=int, default=1000, help="save summaries (except for image and eval summaries) every summary_freq steps")
     parser.add_argument("--image_summary_freq", type=int, default=5000, help="save image summaries every image_summary_freq steps")
-    parser.add_argument("--eval_summary_freq", type=int, default=50000, help="save eval summaries every eval_summary_freq steps")
+    parser.add_argument("--eval_summary_freq", type=int, default=0, help="save eval summaries every eval_summary_freq steps")
     parser.add_argument("--progress_freq", type=int, default=100, help="display progress every progress_freq steps")
     parser.add_argument("--metrics_freq", type=int, default=0, help="run and display metrics every metrics_freq step")
     parser.add_argument("--gif_freq", type=int, default=0, help="save gifs of predicted frames every gif_freq steps")
     parser.add_argument("--save_freq", type=int, default=5000, help="save model every save_freq steps, 0 to disable")
 
-    parser.add_argument("--num_gpus", type=int, default=1)
     parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
     parser.add_argument("--seed", type=int)
 
     args = parser.parse_args()
 
-    cuda_visible_devices = os.environ['CUDA_VISIBLE_DEVICES']
-    if cuda_visible_devices == '':
-        assert args.num_gpus == 0
-    else:
-        assert len(cuda_visible_devices.split(',')) == args.num_gpus
-
     if args.seed is not None:
         tf.set_random_seed(args.seed)
         np.random.seed(args.seed)
@@ -129,8 +122,8 @@ def main():
         return hparams_dict
 
     VideoPredictionModel = models.get_model_class(args.model)
-    train_model = VideoPredictionModel(mode='train', num_gpus=args.num_gpus, hparams_dict=override_hparams_dict(train_dataset), hparams=args.model_hparams)
-    val_models = [VideoPredictionModel(mode='val', num_gpus=args.num_gpus, hparams_dict=override_hparams_dict(val_dataset), hparams=args.model_hparams)
+    train_model = VideoPredictionModel(mode='train', hparams_dict=override_hparams_dict(train_dataset), hparams=args.model_hparams)
+    val_models = [VideoPredictionModel(mode='val', hparams_dict=override_hparams_dict(val_dataset), hparams=args.model_hparams)
                   for val_dataset in val_datasets]
 
     batch_size = train_model.hparams.batch_size
diff --git a/video_prediction/models/base_model.py b/video_prediction/models/base_model.py
index 10317fa553a5e93ffad5f5308d2ec7ded672d5d4..e74d8fc48a867ada490208c6f68de663cfe33efb 100644
--- a/video_prediction/models/base_model.py
+++ b/video_prediction/models/base_model.py
@@ -1,5 +1,6 @@
 import functools
 import itertools
+import os
 from collections import OrderedDict
 
 import tensorflow as tf
@@ -18,7 +19,7 @@ from . import vgg_network
 
 class BaseVideoPredictionModel:
     def __init__(self, mode='train', hparams_dict=None, hparams=None,
-                 num_gpus=1, eval_parallel_iterations=1):
+                 num_gpus=None, eval_num_samples=100, eval_parallel_iterations=1):
         """
         Base video prediction model.
 
@@ -33,7 +34,17 @@ class BaseVideoPredictionModel:
                 These values overrides any values in hparams_dict (if any).
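+            num_gpus: number of GPUs to use; inferred from CUDA_VISIBLE_DEVICES if None.
+            eval_num_samples: number of stochastic samples to draw per example when
+                computing the evaluation outputs and metrics.
+            eval_parallel_iterations: number of parallel iterations to use when
+                computing the evaluation outputs and metrics.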
         """
         self.mode = mode
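+        # infer the number of usable GPUs from CUDA_VISIBLE_DEVICES when num_gpus is not given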
+        # note: an unset CUDA_VISIBLE_DEVICES is treated like an empty one here (i.e. no
+        # visible GPUs), even though CUDA itself would expose all devices in that case
+        cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
+        if cuda_visible_devices == '':
+            max_num_gpus = 0
+        else:
+            max_num_gpus = len(cuda_visible_devices.split(','))
+        if num_gpus is None:
+            num_gpus = max_num_gpus
+        elif num_gpus > max_num_gpus:
+            raise ValueError('num_gpus=%d is greater than the number of visible devices %d' % (num_gpus, max_num_gpus))
         self.num_gpus = num_gpus
+        self.eval_num_samples = eval_num_samples
         self.eval_parallel_iterations = eval_parallel_iterations
         self.hparams = self.parse_hparams(hparams_dict, hparams)
         if not self.hparams.context_frames:
@@ -116,7 +127,9 @@ class BaseVideoPredictionModel:
             metrics[metric_name] = metric_fn(target_images, gen_images)
         return metrics
 
-    def eval_outputs_and_metrics_fn(self, inputs, outputs, targets, num_samples=100, parallel_iterations=1):
+    def eval_outputs_and_metrics_fn(self, inputs, outputs, targets, num_samples=None, parallel_iterations=None):
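+        # per-call arguments take precedence over the defaults passed to the constructor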
+        num_samples = num_samples or self.eval_num_samples
+        parallel_iterations = parallel_iterations or self.eval_parallel_iterations
         eval_outputs = OrderedDict()
         eval_metrics = OrderedDict()
         metric_fns = [
@@ -224,7 +237,6 @@ class VideoPredictionModel(BaseVideoPredictionModel):
                  mode='train',
                  hparams_dict=None,
                  hparams=None,
-                 num_gpus=1,
                  **kwargs):
         """
         Trainable video prediction model with CPU and multi-GPU support.
@@ -248,7 +260,7 @@ class VideoPredictionModel(BaseVideoPredictionModel):
                 where `name` must be defined in `self.get_default_hparams()`.
                 These values overrides any values in hparams_dict (if any).
         """
-        super(VideoPredictionModel, self).__init__(mode, hparams_dict, hparams, num_gpus=num_gpus, **kwargs)
+        super(VideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
 
         self.generator_fn = functools.partial(generator_fn, hparams=self.hparams)
         self.encoder_fn = functools.partial(encoder_fn, hparams=self.hparams) if encoder_fn else None
@@ -344,7 +356,7 @@ class VideoPredictionModel(BaseVideoPredictionModel):
             vgg_cdist_weight=0.0,
             feature_l2_weight=0.0,
             ae_l2_weight=0.0,
-            state_weight=1e-4,
+            state_weight=0.0,
             tv_weight=0.0,
             gan_weight=0.0,
             vae_gan_weight=0.0,
@@ -485,8 +497,7 @@ class VideoPredictionModel(BaseVideoPredictionModel):
             with tf.name_scope("metrics"):
                 metrics = self.metrics_fn(inputs, outputs, targets)
             with tf.name_scope("eval_outputs_and_metrics"):
-                eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs, targets,
-                                                                              parallel_iterations=self.eval_parallel_iterations)
+                eval_outputs, eval_metrics = self.eval_outputs_and_metrics_fn(inputs, outputs, targets)
         else:
             d_losses = {}
             g_losses = {}
@@ -509,7 +520,7 @@ class VideoPredictionModel(BaseVideoPredictionModel):
 
         global_step = tf.train.get_or_create_global_step()
 
-        if self.num_gpus <= 1:
+        if self.num_gpus <= 1:  # cpu or 1 gpu
             outputs_tuple, losses_tuple, metrics_tuple = self.tower_fn(self.inputs, self.targets)
             self.gen_images, self.gen_images_enc, self.outputs, self.eval_outputs = outputs_tuple
             self.d_losses, self.g_losses, g_losses_post = losses_tuple