diff --git a/README.md b/README.md
index 752714f779c5dd01baa7ccdca679d4b6e3a93b4f..95b70a56e1dc23594e9102187a2da02e39d6102d 100644
--- a/README.md
+++ b/README.md
@@ -36,10 +36,30 @@ pip install -r requirements.txt
 ```bash
 bash data/download_and_preprocess_dataset.sh bair
 ```
-- Download a pre-trained model (e.g. `ours_savp`) for that dataset:
+- Download a pre-trained model (e.g. `ours_savp`) for the action-free version of that dataset (i.e. `bair_action_free`):
 ```bash
-bash models/download_model.sh bair ours_savp
+bash pretrained_models/download_model.sh bair_action_free ours_savp
 ```
+- Sample predictions from the model:
+```bash
+CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \
+  --dataset_hparams sequence_length=30 \
+  --checkpoint pretrained_models/bair_action_free/ours_savp \
+  --mode test \
+  --results_dir results_test_samples/bair_action_free
+```
+- The predictions are saved as images and GIFs in `results_test_samples/bair_action_free/ours_savp`.
+- Evaluate predictions from the model using full-reference metrics:
+```bash
+CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/bair \
+  --dataset_hparams sequence_length=30 \
+  --checkpoint pretrained_models/bair_action_free/ours_savp \
+  --mode test \
+  --results_dir results_test/bair_action_free \
+  --batch_size 8
+```
+- The results are saved in `results_test/bair_action_free/ours_savp`.
+- See evaluation details of our experiments in [`scripts/generate_all.sh`](scripts/generate_all.sh) and [`scripts/evaluate_all.sh`](scripts/evaluate_all.sh).
 
 ### Model Training
 - To train a model, download and preprocess a dataset (e.g. `bair`):
@@ -54,7 +74,7 @@ CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset b
 ```
 - To view training and validation information (e.g. loss plots, GIFs of predictions), run `tensorboard --logdir logs/bair_action_free --port 6006` and open http://localhost:6006.
 - For multi-GPU training, set `CUDA_VISIBLE_DEVICES` to a comma-separated list of devices, e.g. `CUDA_VISIBLE_DEVICES=0,1,2,3`. To use the CPU, set `CUDA_VISIBLE_DEVICES=""`.
-- See more training details for other datasets and models in `scripts/train_all.sh`.
+- See more training details for other datasets and models in [`scripts/train_all.sh`](scripts/train_all.sh).
 
 ### Datasets
 Download the datasets using the following script. These datasets are collected by other researchers. Please cite their papers if you use the data.
@@ -65,6 +85,8 @@ bash data/download_and_preprocess_dataset.sh dataset_name
 - `bair`: [BAIR robot pushing dataset](https://sites.google.com/view/sna-visual-mpc/). [[Citation](data/bibtex/sna.txt)]
 - `kth`: [KTH human actions dataset](http://www.nada.kth.se/cvap/actions/). [[Citation](data/bibtex/kth.txt)]
 
+To use a different dataset, preprocess it into TFRecords files and define a class for it. See [`kth_dataset.py`](video_prediction/datasets/kth_dataset.py) for an example where the original dataset is given as videos.
+
 ## Models
 
 
diff --git a/data/download_and_preprocess_dataset.sh b/data/download_and_preprocess_dataset.sh
index 93652923b68bee6867fd0c9bf9e6b6abc04949b6..a1a6412fc6468b52f8213cae9de075b533c66dfc 100644
--- a/data/download_and_preprocess_dataset.sh
+++ b/data/download_and_preprocess_dataset.sh
@@ -12,7 +12,7 @@ if [ $1 = "bair" ]; then
   mkdir -p ${TARGET_DIR}
   TAR_FNAME=bair_robot_pushing_dataset_v0.tar
   URL=http://rail.eecs.berkeley.edu/datasets/${TAR_FNAME}
-  echo "Downloading $1 dataset (this takes a while)"
+  echo "Downloading '$1' dataset (this takes a while)"
   wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME}
   tar -xvf ${TARGET_DIR}/${TAR_FNAME} --strip-components=1 -C ${TARGET_DIR}
   rm ${TARGET_DIR}/${TAR_FNAME}
@@ -23,7 +23,7 @@ elif [ $1 = "kth" ]; then
   TARGET_DIR=./data/kth
   mkdir -p ${TARGET_DIR}
   mkdir -p ${TARGET_DIR}/raw
-  echo "Downloading $1 dataset (this takes a while)"
+  echo "Downloading '$1' dataset (this takes a while)"
   for ACTION in walking jogging running boxing handwaving handclapping; do
     ZIP_FNAME=${ACTION}.zip
     URL=http://www.nada.kth.se/cvap/actions/${ZIP_FNAME}
@@ -36,4 +36,4 @@ else
   echo "Invalid dataset name: '$1' (choose from 'bair', 'kth')" >&2
   exit 1
 fi
-echo "Succesfully finished downloading and preprocessing dataset $1"
+echo "Successfully finished downloading and preprocessing dataset '$1'"
diff --git a/hparams/bair/sv2p_time_variant/model_hparams.json b/hparams/bair/sv2p_time_variant/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..9a10efc29454ab436c40dc703f74823ad238f431
--- /dev/null
+++ b/hparams/bair/sv2p_time_variant/model_hparams.json
@@ -0,0 +1,12 @@
+{
+    "batch_size": 32,
+    "lr": 0.001,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "l1_weight": 0.0,
+    "l2_weight": 1.0,
+    "kl_weight": 0.008,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.0,
+    "state_weight": 0.0001
+}
\ No newline at end of file
diff --git a/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json b/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..329acdf7184f1c945acd703c79bd9d57ae475046
--- /dev/null
+++ b/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json
@@ -0,0 +1,12 @@
+{
+    "batch_size": 32,
+    "lr": 0.001,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "l1_weight": 0.0,
+    "l2_weight": 1.0,
+    "kl_weight": 0.008,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.0,
+    "state_weight": 0.0
+}
\ No newline at end of file
diff --git a/hparams/kth/sv2p_time_invariant/model_hparams.json b/hparams/kth/sv2p_time_invariant/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..f1273998fd03bd3404eb73c93ba7108ce98f89d3
--- /dev/null
+++ b/hparams/kth/sv2p_time_invariant/model_hparams.json
@@ -0,0 +1,12 @@
+{
+    "batch_size": 16,
+    "lr": 0.001,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "l1_weight": 0.0,
+    "l2_weight": 1.0,
+    "kl_weight": 0.008,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.0,
+    "state_weight": 0.0
+}
\ No newline at end of file
diff --git a/hparams/kth/sv2p_time_variant/model_hparams.json b/hparams/kth/sv2p_time_variant/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..f1273998fd03bd3404eb73c93ba7108ce98f89d3
--- /dev/null
+++ b/hparams/kth/sv2p_time_variant/model_hparams.json
@@ -0,0 +1,12 @@
+{
+    "batch_size": 16,
+    "lr": 0.001,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "l1_weight": 0.0,
+    "l2_weight": 1.0,
+    "kl_weight": 0.008,
+    "video_sn_vae_gan_weight": 0.0,
+    "video_sn_gan_weight": 0.0,
+    "state_weight": 0.0
+}
\ No newline at end of file
diff --git a/pretrained_models/download_model.sh b/pretrained_models/download_model.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6382e4d98aea65987cba7974687b1a8d785a9ca3
--- /dev/null
+++ b/pretrained_models/download_model.sh
@@ -0,0 +1,73 @@
+#!/usr/bin/env bash
+
+# exit if any command fails
+set -e
+
+if [ "$#" -ne 2 ]; then
+  echo "Usage: $0 DATASET_NAME MODEL_NAME" >&2
+  exit 1
+fi
+DATASET_NAME=$1
+MODEL_NAME=$2
+
+declare -A model_name_to_fname
+if [ ${DATASET_NAME} = "bair_action_free" ]; then
+  model_name_to_fname=(
+    [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2
+    [ours_gan]=${DATASET_NAME}_ours_gan
+    [ours_savp]=${DATASET_NAME}_ours_savp
+    [ours_vae]=${DATASET_NAME}_ours_vae_l1
+    [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1
+    [ours_vae_l2]=${DATASET_NAME}_ours_vae_l2
+    [sv2p_time_invariant]=${DATASET_NAME}_sv2p_time_invariant
+  )
+elif [ ${DATASET_NAME} = "kth" ]; then
+  model_name_to_fname=(
+    [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2
+    [ours_gan]=${DATASET_NAME}_ours_gan
+    [ours_savp]=${DATASET_NAME}_ours_savp
+    [ours_vae]=${DATASET_NAME}_ours_vae_l1
+    [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1
+    [sv2p_time_invariant]=${DATASET_NAME}_sv2p_time_invariant
+    [sv2p_time_variant]=${DATASET_NAME}_sv2p_time_variant
+  )
+elif [ ${DATASET_NAME} = "bair" ]; then
+  model_name_to_fname=(
+    [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1
+    [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2
+    [ours_gan]=${DATASET_NAME}_ours_gan
+    [ours_savp]=${DATASET_NAME}_ours_savp
+    [ours_vae]=${DATASET_NAME}_ours_vae_l1
+    [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1
+    [ours_vae_l2]=${DATASET_NAME}_ours_vae_l2
+    [sna_l1]=${DATASET_NAME}_sna_l1
+    [sna_l2]=${DATASET_NAME}_sna_l2
+    [sv2p_time_variant]=${DATASET_NAME}_sv2p_time_variant
+  )
+else
+  echo "Invalid dataset name: '${DATASET_NAME}' (choose from 'bair_action_free', 'kth', 'bair')" >&2
+  exit 1
+fi
+
+if ! [[ ${model_name_to_fname[${MODEL_NAME}]} ]]; then
+  echo "Invalid model name '${MODEL_NAME}' when dataset name is '${DATASET_NAME}'. Valid model names are:" >&2
+  for model_name in "${!model_name_to_fname[@]}"; do
+    echo "'${model_name}'" >&2
+  done
+  exit 1
+fi
+TARGET_DIR=./pretrained_models/${DATASET_NAME}/${MODEL_NAME}
+mkdir -p ${TARGET_DIR}
+TAR_FNAME=${model_name_to_fname[${MODEL_NAME}]}.tar.gz
+URL=https://people.eecs.berkeley.edu/~alexlee_gk/projects/savp/pretrained_models/${TAR_FNAME}
+echo "Downloading '${TAR_FNAME}'"
+wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME}
+tar -xvf ${TARGET_DIR}/${TAR_FNAME} -C ${TARGET_DIR}
+rm ${TARGET_DIR}/${TAR_FNAME}
+
+echo "Successfully finished downloading pretrained model '${MODEL_NAME}' on dataset '${DATASET_NAME}'"
diff --git a/scripts/evaluate_all.sh b/scripts/evaluate_all.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c57f5c895da22f167a0a1bb4204a225965c8a48c
--- /dev/null
+++ b/scripts/evaluate_all.sh
@@ -0,0 +1,44 @@
+# BAIR action-free robot pushing dataset
+dataset=bair_action_free
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sv2p_time_invariant \
+; do
+   CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/bair --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 8
+done
+
+# KTH human actions dataset
+# use batch_size=1 to ensure reproducibility when sampling subclips within a sequence
+dataset=kth
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sv2p_time_variant \
+    sv2p_time_invariant \
+; do
+    CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/kth --dataset_hparams sequence_length=40 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 1
+done
+
+# BAIR action-conditioned robot pushing dataset
+dataset=bair
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sna_l1 \
+    sna_l2 \
+    sv2p_time_variant \
+; do
+    CUDA_VISIBLE_DEVICES=1 python scripts/evaluate.py --input_dir data/bair --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 8
+done
diff --git a/scripts/generate.py b/scripts/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cf8a25bd22d038a9b80e2b90e7f7585805ab59c
--- /dev/null
+++ b/scripts/generate.py
@@ -0,0 +1,167 @@
+import argparse
+import errno
+import json
+import os
+import random
+
+import cv2
+import numpy as np
+import tensorflow as tf
+
+from video_prediction import datasets, models
+from video_prediction.utils.ffmpeg_gif import save_gif
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
+                                                                     "train, val, test, etc, or a directory containing "
+                                                                     "the tfrecords")
+    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified")
+    parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified")
+    parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified")
+    parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is "
+                                                 "results_gif_dir/model_fname")
+    parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is "
+                                                 "results_png_dir/model_fname")
+    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+
+    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')
+
+    parser.add_argument("--dataset", type=str, help="dataset class name")
+    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
+    parser.add_argument("--model", type=str, help="model class name")
+    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
+
+    parser.add_argument("--batch_size", type=int, default=16, help="number of samples in batch")
+    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
+    parser.add_argument("--num_epochs", type=int, default=1)
+
+    parser.add_argument("--num_stochastic_samples", type=int, default=5)
+    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
+    parser.add_argument("--fps", type=int, default=4)
+
+    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--seed", type=int, default=7)
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        tf.set_random_seed(args.seed)
+        np.random.seed(args.seed)
+        random.seed(args.seed)
+
+    args.results_gif_dir = args.results_gif_dir or args.results_dir
+    args.results_png_dir = args.results_png_dir or args.results_dir
+    dataset_hparams_dict = {}
+    model_hparams_dict = {}
+    if args.checkpoint:
+        checkpoint_dir = os.path.normpath(args.checkpoint)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
+        if not os.path.isdir(args.checkpoint):
+            checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        with open(os.path.join(checkpoint_dir, "options.json")) as f:
+            print("loading options from checkpoint %s" % args.checkpoint)
+            options = json.loads(f.read())
+            args.dataset = args.dataset or options['dataset']
+            args.model = args.model or options['model']
+        try:
+            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
+                dataset_hparams_dict = json.loads(f.read())
+        except FileNotFoundError:
+            print("dataset_hparams.json was not loaded because it does not exist")
+        try:
+            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
+                model_hparams_dict = json.loads(f.read())
+                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
+        except FileNotFoundError:
+            print("model_hparams.json was not loaded because it does not exist")
+        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1])
+        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1])
+    else:
+        if not args.dataset:
+            raise ValueError('dataset is required when checkpoint is not specified')
+        if not args.model:
+            raise ValueError('model is required when checkpoint is not specified')
+        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
+        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)
+
+    print('----------------------------------- Options ------------------------------------')
+    for k, v in args._get_kwargs():
+        print(k, "=", v)
+    print('------------------------------------- End --------------------------------------')
+
+    VideoDataset = datasets.get_dataset_class(args.dataset)
+    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed,
+                           hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
+
+    def override_hparams_dict(dataset):
+        hparams_dict = dict(model_hparams_dict)
+        hparams_dict['context_frames'] = dataset.hparams.context_frames
+        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
+        hparams_dict['repeat'] = dataset.hparams.time_shift
+        return hparams_dict
+
+    VideoPredictionModel = models.get_model_class(args.model)
+    model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams)
+
+    inputs, target = dataset.make_batch(args.batch_size)
+    if not isinstance(model, models.GroundTruthVideoPredictionModel):
+        # remove ground truth data past context_frames to prevent accidentally using it
+        for k, v in inputs.items():
+            if k != 'actions':
+                inputs[k] = v[:, :model.hparams.context_frames]
+
+    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
+    target_ph = tf.placeholder(target.dtype, target.shape, 'targets_ph')
+
+    with tf.variable_scope(''):
+        model.build_graph(input_phs, target_ph)
+
+    for output_dir in (args.output_gif_dir, args.output_png_dir):
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        with open(os.path.join(output_dir, "options.json"), "w") as f:
+            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
+        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
+        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
+
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
+    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+    sess = tf.Session(config=config)
+
+    model.restore(sess, args.checkpoint)
+
+    sample_ind = 0
+    while True:
+        if args.num_samples and sample_ind >= args.num_samples:
+            break
+        try:
+            input_results, target_result = sess.run([inputs, target])
+        except tf.errors.OutOfRangeError:
+            break
+        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
+
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        for stochastic_sample_ind in range(args.num_stochastic_samples):
+            gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
+            for i, gen_images_ in enumerate(gen_images):
+                gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
+
+                gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
+                save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
+                         gen_images_[:args.gif_length] if args.gif_length else gen_images_, fps=args.fps)
+
+                for t, gen_image in enumerate(gen_images_):
+                    gen_image_fname = 'gen_image_%05d_%02d_%02d.png' % (sample_ind + i, stochastic_sample_ind, t)
+                    gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
+                    cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
+
+        sample_ind += args.batch_size
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/generate_all.sh b/scripts/generate_all.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c3736b36df1840cbc46b89526cef0c908b500760
--- /dev/null
+++ b/scripts/generate_all.sh
@@ -0,0 +1,55 @@
+# BAIR action-free robot pushing dataset
+dataset=bair_action_free
+CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair --dataset bair \
+    --dataset_hparams sequence_length=30 --model ground_truth --mode test \
+    --output_gif_dir results_test_2afc/${dataset}/ground_truth \
+    --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    sv2p_time_invariant \
+; do
+    CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \
+        --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test \
+        --results_gif_dir results_test_2afc/${dataset} \
+        --results_png_dir results_test_samples/${dataset} --gif_length 10
+done
+
+# KTH human actions dataset
+# use batch_size=1 to ensure reproducibility when sampling subclips within a sequence
+dataset=kth
+CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/kth --dataset kth \
+    --dataset_hparams sequence_length=40 --model ground_truth --mode test \
+    --output_gif_dir results_test_2afc/${dataset}/ground_truth \
+    --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10 --batch_size 1
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    sv2p_time_invariant \
+    sv2p_time_variant \
+; do
+    CUDA_VISIBLE_DEVICES=1 python scripts/generate.py --input_dir data/kth \
+        --dataset_hparams sequence_length=40 --checkpoint models/${dataset}/${method_dir} --mode test \
+        --results_gif_dir results_test_2afc/${dataset} \
+        --results_png_dir results_test_samples/${dataset} --gif_length 10 --batch_size 1
+done
+
+# BAIR action-conditioned robot pushing dataset
+dataset=bair
+CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair --dataset bair \
+    --dataset_hparams sequence_length=30 --model ground_truth --mode test \
+    --output_gif_dir results_test_2afc/${dataset}/ground_truth \
+    --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10
+for method_dir in \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    sv2p_time_variant \
+; do
+    CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \
+        --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test \
+        --results_gif_dir results_test_2afc/${dataset} \
+        --results_png_dir results_test_samples/${dataset} --gif_length 10
+done
diff --git a/scripts/plot_results.py b/scripts/plot_results.py
new file mode 100644
index 0000000000000000000000000000000000000000..db85f0f707c675c89840285f79f3999155f3337d
--- /dev/null
+++ b/scripts/plot_results.py
@@ -0,0 +1,216 @@
+import argparse
+import glob
+import os
+
+import numpy as np
+
+
+def load_metrics(prefix_fname):
+    import csv
+    with open('%s.csv' % prefix_fname, newline='') as csvfile:
+        reader = csv.reader(csvfile, delimiter='\t', quotechar='|')
+        rows = list(reader)
+        # skip header (first row), indices (first column), and means (last column)
+        metrics = np.array(rows)[1:, 1:-1].astype(np.float32)
+    return metrics
+
+
+def plot_metric(metric, start_x=0, color=None, label=None, zorder=None):
+    import matplotlib.pyplot as plt
+    metric_mean = np.mean(metric, axis=0)
+    metric_se = np.std(metric, axis=0) / np.sqrt(len(metric))
+    kwargs = {}
+    if color:
+        kwargs['color'] = color
+    if zorder:
+        kwargs['zorder'] = zorder
+    plt.errorbar(np.arange(len(metric_mean)) + start_x,
+                 metric_mean, yerr=metric_se, linewidth=2,
+                 label=label, **kwargs)
+    # metric_std = np.std(metric, axis=0)
+    # plt.plot(np.arange(len(metric_mean)) + start_x, metric_mean,
+    #          linewidth=2, color=color, label=label)
+    # plt.fill_between(np.arange(len(metric_mean)) + start_x,
+    #                  metric_mean - metric_std, metric_mean + metric_std,
+    #                  color=color, alpha=0.5)
+
+
+def get_color(method_name):
+    import matplotlib.pyplot as plt
+    color_mapping = {
+        'ours_vae_gan': plt.cm.Vega20(0),
+        'ours_gan': plt.cm.Vega20(2),
+        'ours_vae': plt.cm.Vega20(4),
+        'ours_vae_l1': plt.cm.Vega20(4),
+        'ours_vae_l2': plt.cm.Vega20(14),
+        'ours_deterministic': plt.cm.Vega20(6),
+        'ours_deterministic_l1': plt.cm.Vega20(6),
+        'ours_deterministic_l2': plt.cm.Vega20(10),
+        'sna_l1': plt.cm.Vega20(8),
+        'sna_l2': plt.cm.Vega20(9),
+        'sv2p_time_variant': plt.cm.Vega20(16),
+        'sv2p_time_invariant': plt.cm.Vega20(16),
+        'svg_lp': plt.cm.Vega20(18),
+        'svg_fp': plt.cm.Vega20(18),
+        'svg_fp_resized_data_loader': plt.cm.Vega20(18),
+    }
+    if method_name in color_mapping:
+        color = color_mapping[method_name]
+    else:
+        color = None
+        for k, v in color_mapping.items():
+            if method_name.startswith(k):
+                color = v
+                break
+    return color
+
+
+def get_method_name(method_name):
+    method_name_mapping = {
+        'ours_vae_gan': 'Ours, SAVP',
+        'ours_gan': 'Ours, GAN-only',
+        'ours_vae': 'Ours, VAE-only',
+        'ours_vae_l1': 'Ours, VAE-only, $\mathcal{L}_1$',
+        'ours_vae_l2': 'Ours, VAE-only, $\mathcal{L}_2$',
+        'ours_deterministic': 'Ours, deterministic',
+        'ours_deterministic_l1': 'Ours, deterministic, $\mathcal{L}_1$',
+        'ours_deterministic_l2': 'Ours, deterministic, $\mathcal{L}_2$',
+        'sna_l1': 'SNA, $\mathcal{L}_1$ (Ebert et al.)',
+        'sna_l2': 'SNA, $\mathcal{L}_2$ (Ebert et al.)',
+        'sv2p_time_variant': 'SV2P time-variant (Babaeizadeh et al.)',
+        'sv2p_time_invariant': 'SV2P time-invariant (Babaeizadeh et al.)',
+    }
+    return method_name_mapping.get(method_name, method_name)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("results_dir", type=str)
+    parser.add_argument("--method_dirs", type=str, nargs='+', help='directories in results_dir (all of them by default)')
+    parser.add_argument("--method_names", type=str, nargs='+', help='method names for the header')
+    parser.add_argument("--web_dir", type=str, help='default is results_dir/web')
+    parser.add_argument("--plot_fname", type=str, default='metrics.pdf')
+    parser.add_argument('--usetex', '--use_tex', action='store_true')
+    parser.add_argument('--save', action='store_true')
+    args = parser.parse_args()
+
+    if args.save:
+        import matplotlib
+        matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!
+    import matplotlib.pyplot as plt
+
+    if args.usetex:
+        plt.rc('text', usetex=True)
+        plt.rc('text.latex', preview=True)
+        plt.rc('font', family='serif')
+
+    if args.web_dir is None:
+        args.web_dir = os.path.join(args.results_dir, 'web')
+
+    if args.method_dirs is None:
+        unsorted_method_dirs = os.listdir(args.results_dir)
+        # exclude web_dir and all directories that starts with web
+        if args.web_dir in unsorted_method_dirs:
+            unsorted_method_dirs.remove(args.web_dir)
+        unsorted_method_dirs = [method_dir for method_dir in unsorted_method_dirs if not os.path.basename(method_dir).startswith('web')]
+        # put ground_truth and repeat in the front (if any)
+        method_dirs = []
+        for first_method_dir in ['ground_truth', 'repeat']:
+            if first_method_dir in unsorted_method_dirs:
+                unsorted_method_dirs.remove(first_method_dir)
+                method_dirs.append(first_method_dir)
+        method_dirs.extend(sorted(unsorted_method_dirs))
+    else:
+        method_dirs = list(args.method_dirs)
+    if args.method_names is None:
+        method_names = [get_method_name(method_dir) for method_dir in method_dirs]
+    else:
+        method_names = list(args.method_names)
+    if args.usetex:
+        method_names = [method_name.replace('kl_weight', r'$\lambda_{\textsc{kl}}$') for method_name in method_names]
+    method_dirs = [os.path.join(args.results_dir, method_dir) for method_dir in method_dirs]
+
+    # infer task and metric names from first method
+    metric_fnames = sorted(glob.glob('%s/*_max/metrics/*.csv' % glob.escape(method_dirs[0])))
+    task_names = []
+    metric_names = []
+    for metric_fname in metric_fnames:
+        head, tail = os.path.split(metric_fname)
+        task_name = head.split('/')[-2]
+        metric_name, _ = os.path.splitext(tail)
+        task_names.append(task_name)
+        metric_names.append(metric_name)
+
+    # save plots
+    dataset_name = os.path.split(os.path.normpath(args.results_dir))[1]
+    plots_dir = os.path.join(args.web_dir, 'plots')
+    if not os.path.exists(plots_dir):
+        os.makedirs(plots_dir)
+
+    if dataset_name in ('bair', 'bair_action_free'):
+        context_frames = 2
+        training_sequence_length = 12
+    elif dataset_name == 'kth':
+        context_frames = 10
+        training_sequence_length = 20
+    else:
+        raise NotImplementedError
+
+    fig = plt.figure(figsize=(12, 5))
+    i_task = 0
+    for task_name, metric_name in zip(task_names, metric_names):
+        if not task_name.endswith('max'):
+            continue
+        # use ssim for all datasets
+        # if metric_name not in ('psnr', 'ssim', 'vgg_csim'):
+        #     continue
+        if dataset_name in ('bair', 'bair_action_free'):
+            if metric_name not in ('psnr', 'ssim_finn', 'vgg_csim'):
+                continue
+        elif dataset_name == 'kth':
+            if metric_name not in ('psnr', 'ssim_scikit', 'vgg_csim'):
+                continue
+
+        plt.subplot(1, 3, i_task + 1)  # hard-coded 3 columns
+
+        for method_name, method_dir in zip(method_names, method_dirs):
+            metric = load_metrics(os.path.join(method_dir, task_name, 'metrics', metric_name))
+            plot_metric(metric, context_frames + 1, color=get_color(os.path.basename(method_dir)), label=method_name)
+
+        plt.grid(axis='y')
+        plt.axvline(x=training_sequence_length, linewidth=1, color='k')
+        plt.xlabel('Time Step', fontsize=15)
+        plt.ylabel({
+            'psnr': 'Average PSNR',
+            'ssim': 'Average SSIM',
+            'ssim_scikit': 'Average SSIM',
+            'ssim_finn': 'Average SSIM',
+            'vgg_csim': 'Average VGG cosine similarity',
+        }[metric_name], fontsize=15)
+        plt.xlim((context_frames + 1, metric.shape[1] + context_frames))
+        plt.tick_params(labelsize=10)
+
+        if i_task == 1:
+            # plt.title({
+            #     'bair': 'Action-conditioned BAIR Dataset',
+            #     'bair_action_free': 'Action-free BAIR Dataset',
+            #     'kth': 'KTH Dataset',
+            # }[dataset_name], fontsize=16)
+            if len(method_names) <= 4 and sum([len(method_name) for method_name in method_names]) < 90:
+                ncol = len(method_names)
+            else:
+                ncol = (len(method_names) + 1) // 2
+            plt.legend(bbox_to_anchor=(0.5, -0.12), loc='upper center', ncol=ncol, fontsize=15)
+        i_task += 1
+    fig.tight_layout(rect=(0, 0.1, 1, 1))
+
+    if args.save:
+        plt.show(block=False)
+        print("Saving to", os.path.join(plots_dir, args.plot_fname))
+        plt.savefig(os.path.join(plots_dir, args.plot_fname), bbox_inches='tight')
+    else:
+        plt.show()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/plot_results_all.sh b/scripts/plot_results_all.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6fd01c4ce0d66b6863f16679ea54f1a35fe50d26
--- /dev/null
+++ b/scripts/plot_results_all.sh
@@ -0,0 +1,78 @@
+python scripts/plot_results.py results_test/bair_action_free --method_dirs \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sv2p_time_invariant \
+    svg_lp \
+    --save --use_tex --plot_fname metrics_all.pdf
+
+python scripts/plot_results.py results_test/bair --method_dirs \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sna_l1 \
+    sna_l2 \
+    sv2p_time_variant \
+    --save --use_tex --plot_fname metrics_all.pdf
+
+python scripts/plot_results.py results_test/kth --method_dirs \
+    ours_vae_gan \
+    ours_gan \
+    ours_vae_l1 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    sv2p_time_variant \
+    sv2p_time_invariant \
+    svg_fp_resized_data_loader \
+    --save --use_tex --plot_fname metrics_all.pdf
+
+
+python scripts/plot_results.py results_test/bair_action_free --method_dirs \
+    sv2p_time_invariant \
+    svg_lp \
+    ours_vae_gan \
+    --save --use_tex --plot_fname metrics.pdf; \
+python scripts/plot_results.py results_test/bair_action_free --method_dirs \
+    ours_deterministic \
+    ours_vae \
+    ours_gan \
+    ours_vae_gan \
+    --save --use_tex --plot_fname metrics_ablation.pdf; \
+python scripts/plot_results.py results_test/bair_action_free --method_dirs \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    --save --use_tex --plot_fname metrics_ablation_l1_l2.pdf; \
+python scripts/plot_results.py results_test/kth --method_dirs \
+    sv2p_time_variant \
+    svg_fp_resized_data_loader \
+    ours_vae_gan \
+    --save --use_tex --plot_fname metrics.pdf; \
+python scripts/plot_results.py results_test/kth --method_dirs \
+    ours_deterministic \
+    ours_vae \
+    ours_gan \
+    ours_vae_gan \
+    --save --use_tex --plot_fname metrics_ablation.pdf; \
+python scripts/plot_results.py results_test/bair --method_dirs \
+    sv2p_time_variant \
+    ours_deterministic \
+    ours_vae \
+    ours_gan \
+    ours_vae_gan \
+    --save --use_tex --plot_fname metrics.pdf; \
+python scripts/plot_results.py results_test/bair --method_dirs \
+    sna_l1 \
+    sna_l2 \
+    ours_deterministic_l1 \
+    ours_deterministic_l2 \
+    ours_vae_l1 \
+    ours_vae_l2 \
+    --save --use_tex --plot_fname metrics_ablation_l1_l2.pdf