diff --git a/README.md b/README.md index 752714f779c5dd01baa7ccdca679d4b6e3a93b4f..95b70a56e1dc23594e9102187a2da02e39d6102d 100644 --- a/README.md +++ b/README.md @@ -36,10 +36,30 @@ pip install -r requirements.txt ```bash bash data/download_and_preprocess_dataset.sh bair ``` -- Download a pre-trained model (e.g. `ours_savp`) for that dataset: +- Download a pre-trained model (e.g. `ours_savp`) for the action-free version of that dataset (i.e. `bair_action_free`): ```bash -bash models/download_model.sh bair ours_savp +bash pretrained_models/download_model.sh bair_action_free ours_savp ``` +- Sample predictions from the model: +```bash +CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \ + --dataset_hparams sequence_length=30 \ + --checkpoint pretrained_models/bair_action_free/ours_savp \ + --mode test \ + --results_dir results_test_samples/bair_action_free +``` +- The predictions are saved as images and GIFs in `results_test_samples/bair_action_free/ours_savp`. +- Evaluate predictions from the model using full-reference metrics: +```bash +CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/bair \ + --dataset_hparams sequence_length=30 \ + --checkpoint pretrained_models/bair_action_free/ours_savp \ + --mode test \ + --results_dir results_test/bair_action_free \ + --batch_size 8 +``` +- The results are saved in `results_test/bair_action_free/ours_savp`. +- See evaluation details of our experiments in [`scripts/generate_all.sh`](scripts/generate_all.sh) and [`scripts/evaluate_all.sh`](scripts/evaluate_all.sh). ### Model Training - To train a model, download and preprocess a dataset (e.g. `bair`): @@ -54,7 +74,7 @@ CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset b ``` - To view training and validation information (e.g. loss plots, GIFs of predictions), run `tensorboard --logdir logs/bair_action_free --port 6006` and open http://localhost:6006. 
- For multi-GPU training, set `CUDA_VISIBLE_DEVICES` to a comma-separated list of devices, e.g. `CUDA_VISIBLE_DEVICES=0,1,2,3`. To use the CPU, set `CUDA_VISIBLE_DEVICES=""`. -- See more training details for other datasets and models in `scripts/train_all.sh`. +- See more training details for other datasets and models in [`scripts/train_all.sh`](scripts/train_all.sh). ### Datasets Download the datasets using the following script. These datasets are collected by other researchers. Please cite their papers if you use the data. @@ -65,6 +85,8 @@ bash data/download_and_preprocess_dataset.sh dataset_name - `bair`: [BAIR robot pushing dataset](https://sites.google.com/view/sna-visual-mpc/). [[Citation](data/bibtex/sna.txt)] - `kth`: [KTH human actions dataset](http://www.nada.kth.se/cvap/actions/). [[Citation](data/bibtex/kth.txt)] +To use a different dataset, preprocess it into TFRecords files and define a class for it. See [`kth_dataset.py`](video_prediction/datasets/kth_dataset.py) for an example where the original dataset is given as videos. 
+ ## Models diff --git a/data/download_and_preprocess_dataset.sh b/data/download_and_preprocess_dataset.sh index 93652923b68bee6867fd0c9bf9e6b6abc04949b6..a1a6412fc6468b52f8213cae9de075b533c66dfc 100644 --- a/data/download_and_preprocess_dataset.sh +++ b/data/download_and_preprocess_dataset.sh @@ -12,7 +12,7 @@ if [ $1 = "bair" ]; then mkdir -p ${TARGET_DIR} TAR_FNAME=bair_robot_pushing_dataset_v0.tar URL=http://rail.eecs.berkeley.edu/datasets/${TAR_FNAME} - echo "Downloading $1 dataset (this takes a while)" + echo "Downloading '$1' dataset (this takes a while)" wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME} tar -xvf ${TARGET_DIR}/${TAR_FNAME} --strip-components=1 -C ${TARGET_DIR} rm ${TARGET_DIR}/${TAR_FNAME} @@ -23,7 +23,7 @@ elif [ $1 = "kth" ]; then TARGET_DIR=./data/kth mkdir -p ${TARGET_DIR} mkdir -p ${TARGET_DIR}/raw - echo "Downloading $1 dataset (this takes a while)" + echo "Downloading '$1' dataset (this takes a while)" for ACTION in walking jogging running boxing handwaving handclapping; do ZIP_FNAME=${ACTION}.zip URL=http://www.nada.kth.se/cvap/actions/${ZIP_FNAME} @@ -36,4 +36,4 @@ else echo "Invalid dataset name: '$1' (choose from 'bair', 'kth')" >&2 exit 1 fi -echo "Succesfully finished downloading and preprocessing dataset $1" +echo "Succesfully finished downloading and preprocessing dataset '$1'" diff --git a/hparams/bair/sv2p_time_variant/model_hparams.json b/hparams/bair/sv2p_time_variant/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..9a10efc29454ab436c40dc703f74823ad238f431 --- /dev/null +++ b/hparams/bair/sv2p_time_variant/model_hparams.json @@ -0,0 +1,12 @@ +{ + "batch_size": 32, + "lr": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "l1_weight": 0.0, + "l2_weight": 1.0, + "kl_weight": 0.008, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.0, + "state_weight": 0.0001 +} \ No newline at end of file diff --git a/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json 
b/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..329acdf7184f1c945acd703c79bd9d57ae475046 --- /dev/null +++ b/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json @@ -0,0 +1,12 @@ +{ + "batch_size": 32, + "lr": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "l1_weight": 0.0, + "l2_weight": 1.0, + "kl_weight": 0.008, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.0, + "state_weight": 0.0 +} \ No newline at end of file diff --git a/hparams/kth/sv2p_time_invariant/model_hparams.json b/hparams/kth/sv2p_time_invariant/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..f1273998fd03bd3404eb73c93ba7108ce98f89d3 --- /dev/null +++ b/hparams/kth/sv2p_time_invariant/model_hparams.json @@ -0,0 +1,12 @@ +{ + "batch_size": 16, + "lr": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "l1_weight": 0.0, + "l2_weight": 1.0, + "kl_weight": 0.008, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.0, + "state_weight": 0.0 +} \ No newline at end of file diff --git a/hparams/kth/sv2p_time_variant/model_hparams.json b/hparams/kth/sv2p_time_variant/model_hparams.json new file mode 100644 index 0000000000000000000000000000000000000000..f1273998fd03bd3404eb73c93ba7108ce98f89d3 --- /dev/null +++ b/hparams/kth/sv2p_time_variant/model_hparams.json @@ -0,0 +1,12 @@ +{ + "batch_size": 16, + "lr": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "l1_weight": 0.0, + "l2_weight": 1.0, + "kl_weight": 0.008, + "video_sn_vae_gan_weight": 0.0, + "video_sn_gan_weight": 0.0, + "state_weight": 0.0 +} \ No newline at end of file diff --git a/pretrained_models/download_model.sh b/pretrained_models/download_model.sh new file mode 100644 index 0000000000000000000000000000000000000000..6382e4d98aea65987cba7974687b1a8d785a9ca3 --- /dev/null +++ b/pretrained_models/download_model.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash + +# exit if any command fails +set -e + +if [ "$#" 
-ne 2 ]; then + echo "Usage: $0 DATASET_NAME MODEL_NAME" >&2 + exit 1 +fi +DATASET_NAME=$1 +MODEL_NAME=$2 + +declare -A model_name_to_fname +if [ ${DATASET_NAME} = "bair_action_free" ]; then + model_name_to_fname=( + [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2 + [ours_gan]=${DATASET_NAME}_ours_gan + [ours_savp]=${DATASET_NAME}_ours_savp + [ours_vae]=${DATASET_NAME}_ours_vae_l1 + [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1 + [ours_vae_l2]=${DATASET_NAME}_ours_vae_l2 + [sv2p_time_invariant]=${DATASET_NAME}_sv2p_time_invariant + ) +elif [ ${DATASET_NAME} = "kth" ]; then + model_name_to_fname=( + [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2 + [ours_gan]=${DATASET_NAME}_ours_gan + [ours_savp]=${DATASET_NAME}_ours_savp + [ours_vae]=${DATASET_NAME}_ours_vae_l1 + [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1 + [sv2p_time_invariant]=${DATASET_NAME}_sv2p_time_invariant + [sv2p_time_variant]=${DATASET_NAME}_sv2p_time_variant + ) +elif [ ${DATASET_NAME} = "bair" ]; then + model_name_to_fname=( + [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1 + [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2 + [ours_gan]=${DATASET_NAME}_ours_gan + [ours_savp]=${DATASET_NAME}_ours_savp + [ours_vae]=${DATASET_NAME}_ours_vae_l1 + [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1 + [ours_vae_l2]=${DATASET_NAME}_ours_vae_l2 + [sna_l1]=${DATASET_NAME}_sna_l1 + [sna_l2]=${DATASET_NAME}_sna_l2 + [sv2p_time_variant]=${DATASET_NAME}_sv2p_time_variant + ) +else + echo 
"Invalid dataset name: '${DATASET_NAME}' (choose from 'bair_action_free', 'kth', 'bair)" >&2 + exit 1 +fi + +if ! [[ ${model_name_to_fname[${MODEL_NAME}]} ]]; then + echo "Invalid model name '${MODEL_NAME}' when dataset name is '${DATASET_NAME}'. Valid mode names are:" >&2 + for model_name in "${!model_name_to_fname[@]}"; do + echo "'${model_name}'" >&2 + done + exit 1 +fi +TARGET_DIR=./pretrained_models/${DATASET_NAME}/${MODEL_NAME} +mkdir -p ${TARGET_DIR} +TAR_FNAME=${model_name_to_fname[${MODEL_NAME}]}.tar.gz +URL=https://people.eecs.berkeley.edu/~alexlee_gk/projects/savp/pretrained_models/${TAR_FNAME} +echo "Downloading '${TAR_FNAME}'" +wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME} +tar -xvf ${TARGET_DIR}/${TAR_FNAME} -C ${TARGET_DIR} +rm ${TARGET_DIR}/${TAR_FNAME} + +echo "Succesfully finished downloading pretrained model '${MODEL_NAME}' on dataset '${DATASET_NAME}'" diff --git a/scripts/evaluate_all.sh b/scripts/evaluate_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..c57f5c895da22f167a0a1bb4204a225965c8a48c --- /dev/null +++ b/scripts/evaluate_all.sh @@ -0,0 +1,44 @@ +# BAIR action-free robot pushing dataset +dataset=bair_action_free +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_vae_l2 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sv2p_time_invariant \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/bair --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 8 +done + +# KTH human actions dataset +# use batch_size=1 to ensure reproducibility when sampling subclips within a sequence +dataset=kth +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sv2p_time_variant \ + sv2p_time_invariant \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/kth --dataset_hparams sequence_length=40 --checkpoint 
models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 1 +done + +# BAIR action-conditioned robot pushing dataset +dataset=bair +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_vae_l2 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sna_l1 \ + sna_l2 \ + sv2p_time_variant \ +; do + CUDA_VISIBLE_DEVICES=1 python scripts/evaluate.py --input_dir data/bair --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 8 +done diff --git a/scripts/generate.py b/scripts/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..9cf8a25bd22d038a9b80e2b90e7f7585805ab59c --- /dev/null +++ b/scripts/generate.py @@ -0,0 +1,167 @@ +import argparse +import errno +import json +import os +import random + +import cv2 +import numpy as np +import tensorflow as tf + +from video_prediction import datasets, models +from video_prediction.utils.ffmpeg_gif import save_gif + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " + "train, val, test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified") + parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified") + parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified") + parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is " + "results_gif_dir/model_fname") + parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. 
default is " + "results_png_dir/model_fname") + parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)") + + parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.') + + parser.add_argument("--dataset", type=str, help="dataset class name") + parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters") + parser.add_argument("--model", type=str, help="model class name") + parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") + + parser.add_argument("--batch_size", type=int, default=16, help="number of samples in batch") + parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)") + parser.add_argument("--num_epochs", type=int, default=1) + + parser.add_argument("--num_stochastic_samples", type=int, default=5) + parser.add_argument("--gif_length", type=int, help="default is sequence_length") + parser.add_argument("--fps", type=int, default=4) + + parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") + parser.add_argument("--seed", type=int, default=7) + + args = parser.parse_args() + + if args.seed is not None: + tf.set_random_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + args.results_gif_dir = args.results_gif_dir or args.results_dir + args.results_png_dir = args.results_png_dir or args.results_dir + dataset_hparams_dict = {} + model_hparams_dict = {} + if args.checkpoint: + checkpoint_dir = os.path.normpath(args.checkpoint) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + if not os.path.isdir(args.checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + 
print("loading options from checkpoint %s" % args.checkpoint) + options = json.loads(f.read()) + args.dataset = args.dataset or options['dataset'] + args.model = args.model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: + dataset_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("dataset_hparams.json was not loaded because it does not exist") + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict = json.loads(f.read()) + model_hparams_dict.pop('num_gpus', None) # backwards-compatibility + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1]) + args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1]) + else: + if not args.dataset: + raise ValueError('dataset is required when checkpoint is not specified') + if not args.model: + raise ValueError('model is required when checkpoint is not specified') + args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model) + args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model) + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + VideoDataset = datasets.get_dataset_class(args.dataset) + dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed, + hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams) + + def override_hparams_dict(dataset): + hparams_dict = dict(model_hparams_dict) + hparams_dict['context_frames'] = dataset.hparams.context_frames + 
hparams_dict['sequence_length'] = dataset.hparams.sequence_length + hparams_dict['repeat'] = dataset.hparams.time_shift + return hparams_dict + + VideoPredictionModel = models.get_model_class(args.model) + model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams) + + inputs, target = dataset.make_batch(args.batch_size) + if not isinstance(model, models.GroundTruthVideoPredictionModel): + # remove ground truth data past context_frames to prevent accidentally using it + for k, v in inputs.items(): + if k != 'actions': + inputs[k] = v[:, :model.hparams.context_frames] + + input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} + target_ph = tf.placeholder(target.dtype, target.shape, 'targets_ph') + + with tf.variable_scope(''): + model.build_graph(input_phs, target_ph) + + for output_dir in (args.output_gif_dir, args.output_png_dir): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + with open(os.path.join(output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys=True, indent=4)) + with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4)) + with open(os.path.join(output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) + sess = tf.Session(config=config) + + model.restore(sess, args.checkpoint) + + sample_ind = 0 + while True: + if args.num_samples and sample_ind >= args.num_samples: + break + try: + input_results, target_result = sess.run([inputs, target]) + except tf.errors.OutOfRangeError: + break + print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size)) + + feed_dict = {input_ph: input_results[name] for 
name, input_ph in input_phs.items()} + for stochastic_sample_ind in range(args.num_stochastic_samples): + gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict) + for i, gen_images_ in enumerate(gen_images): + gen_images_ = (gen_images_ * 255.0).astype(np.uint8) + + gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + save_gif(os.path.join(args.output_gif_dir, gen_images_fname), + gen_images_[:args.gif_length] if args.gif_length else gen_images_, fps=args.fps) + + for t, gen_image in enumerate(gen_images_): + gen_image_fname = 'gen_image_%05d_%02d_%02d.png' % (sample_ind + i, stochastic_sample_ind, t) + gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR) + cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image) + + sample_ind += args.batch_size + + +if __name__ == '__main__': + main() diff --git a/scripts/generate_all.sh b/scripts/generate_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..c3736b36df1840cbc46b89526cef0c908b500760 --- /dev/null +++ b/scripts/generate_all.sh @@ -0,0 +1,55 @@ +# BAIR action-free robot pushing dataset +dataset=bair_action_free +CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair --dataset bair \ + --dataset_hparams sequence_length=30 --model ground_truth --mode test \ + --output_gif_dir results_test_2afc/${dataset}/ground_truth \ + --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10 +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + sv2p_time_invariant \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \ + --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test \ + --results_gif_dir results_test_2afc/${dataset} \ + --results_png_dir results_test_samples/${dataset} --gif_length 10 +done + +# KTH human actions dataset +# use batch_size=1 to ensure reproducibility when sampling subclips within a sequence 
+dataset=kth +CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/kth --dataset kth \ + --dataset_hparams sequence_length=40 --model ground_truth --mode test \ + --output_gif_dir results_test_2afc/${dataset}/ground_truth \ + --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10 --batch_size 1 +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + sv2p_time_invariant \ + sv2p_time_variant \ +; do + CUDA_VISIBLE_DEVICES=1 python scripts/generate.py --input_dir data/kth \ + --dataset_hparams sequence_length=40 --checkpoint models/${dataset}/${method_dir} --mode test \ + --results_gif_dir results_test_2afc/${dataset} \ + --results_png_dir results_test_samples/${dataset} --gif_length 10 --batch_size 1 +done + +# BAIR action-conditioned robot pushing dataset +dataset=bair +CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair --dataset bair \ + --dataset_hparams sequence_length=30 --model ground_truth --mode test \ + --output_gif_dir results_test_2afc/${dataset}/ground_truth \ + --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10 +for method_dir in \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + sv2p_time_variant \ +; do + CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \ + --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test \ + --results_gif_dir results_test_2afc/${dataset} \ + --results_png_dir results_test_samples/${dataset} --gif_length 10 +done diff --git a/scripts/plot_results.py b/scripts/plot_results.py new file mode 100644 index 0000000000000000000000000000000000000000..db85f0f707c675c89840285f79f3999155f3337d --- /dev/null +++ b/scripts/plot_results.py @@ -0,0 +1,216 @@ +import argparse +import glob +import os + +import numpy as np + + +def load_metrics(prefix_fname): + import csv + with open('%s.csv' % prefix_fname, newline='') as csvfile: + reader = csv.reader(csvfile, delimiter='\t', 
quotechar='|') + rows = list(reader) + # skip header (first row), indices (first column), and means (last column) + metrics = np.array(rows)[1:, 1:-1].astype(np.float32) + return metrics + + +def plot_metric(metric, start_x=0, color=None, label=None, zorder=None): + import matplotlib.pyplot as plt + metric_mean = np.mean(metric, axis=0) + metric_se = np.std(metric, axis=0) / np.sqrt(len(metric)) + kwargs = {} + if color: + kwargs['color'] = color + if zorder: + kwargs['zorder'] = zorder + plt.errorbar(np.arange(len(metric_mean)) + start_x, + metric_mean, yerr=metric_se, linewidth=2, + label=label, **kwargs) + # metric_std = np.std(metric, axis=0) + # plt.plot(np.arange(len(metric_mean)) + start_x, metric_mean, + # linewidth=2, color=color, label=label) + # plt.fill_between(np.arange(len(metric_mean)) + start_x, + # metric_mean - metric_std, metric_mean + metric_std, + # color=color, alpha=0.5) + + +def get_color(method_name): + import matplotlib.pyplot as plt + color_mapping = { + 'ours_vae_gan': plt.cm.Vega20(0), + 'ours_gan': plt.cm.Vega20(2), + 'ours_vae': plt.cm.Vega20(4), + 'ours_vae_l1': plt.cm.Vega20(4), + 'ours_vae_l2': plt.cm.Vega20(14), + 'ours_deterministic': plt.cm.Vega20(6), + 'ours_deterministic_l1': plt.cm.Vega20(6), + 'ours_deterministic_l2': plt.cm.Vega20(10), + 'sna_l1': plt.cm.Vega20(8), + 'sna_l2': plt.cm.Vega20(9), + 'sv2p_time_variant': plt.cm.Vega20(16), + 'sv2p_time_invariant': plt.cm.Vega20(16), + 'svg_lp': plt.cm.Vega20(18), + 'svg_fp': plt.cm.Vega20(18), + 'svg_fp_resized_data_loader': plt.cm.Vega20(18), + } + if method_name in color_mapping: + color = color_mapping[method_name] + else: + color = None + for k, v in color_mapping.items(): + if method_name.startswith(k): + color = v + break + return color + + +def get_method_name(method_name): + method_name_mapping = { + 'ours_vae_gan': 'Ours, SAVP', + 'ours_gan': 'Ours, GAN-only', + 'ours_vae': 'Ours, VAE-only', + 'ours_vae_l1': 'Ours, VAE-only, $\mathcal{L}_1$', + 'ours_vae_l2': 'Ours, 
VAE-only, $\mathcal{L}_2$', + 'ours_deterministic': 'Ours, deterministic', + 'ours_deterministic_l1': 'Ours, deterministic, $\mathcal{L}_1$', + 'ours_deterministic_l2': 'Ours, deterministic, $\mathcal{L}_2$', + 'sna_l1': 'SNA, $\mathcal{L}_1$ (Ebert et al.)', + 'sna_l2': 'SNA, $\mathcal{L}_2$ (Ebert et al.)', + 'sv2p_time_variant': 'SV2P time-variant (Babaeizadeh et al.)', + 'sv2p_time_invariant': 'SV2P time-invariant (Babaeizadeh et al.)', + } + return method_name_mapping.get(method_name, method_name) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("results_dir", type=str) + parser.add_argument("--method_dirs", type=str, nargs='+', help='directories in results_dir (all of them by default)') + parser.add_argument("--method_names", type=str, nargs='+', help='method names for the header') + parser.add_argument("--web_dir", type=str, help='default is results_dir/web') + parser.add_argument("--plot_fname", type=str, default='metrics.pdf') + parser.add_argument('--usetex', '--use_tex', action='store_true') + parser.add_argument('--save', action='store_true') + args = parser.parse_args() + + if args.save: + import matplotlib + matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab! 
+ import matplotlib.pyplot as plt + + if args.usetex: + plt.rc('text', usetex=True) + plt.rc('text.latex', preview=True) + plt.rc('font', family='serif') + + if args.web_dir is None: + args.web_dir = os.path.join(args.results_dir, 'web') + + if args.method_dirs is None: + unsorted_method_dirs = os.listdir(args.results_dir) + # exclude web_dir and all directories that starts with web + if args.web_dir in unsorted_method_dirs: + unsorted_method_dirs.remove(args.web_dir) + unsorted_method_dirs = [method_dir for method_dir in unsorted_method_dirs if not os.path.basename(method_dir).startswith('web')] + # put ground_truth and repeat in the front (if any) + method_dirs = [] + for first_method_dir in ['ground_truth', 'repeat']: + if first_method_dir in unsorted_method_dirs: + unsorted_method_dirs.remove(first_method_dir) + method_dirs.append(first_method_dir) + method_dirs.extend(sorted(unsorted_method_dirs)) + else: + method_dirs = list(args.method_dirs) + if args.method_names is None: + method_names = [get_method_name(method_dir) for method_dir in method_dirs] + else: + method_names = list(args.method_names) + if args.usetex: + method_names = [method_name.replace('kl_weight', r'$\lambda_{\textsc{kl}}$') for method_name in method_names] + method_dirs = [os.path.join(args.results_dir, method_dir) for method_dir in method_dirs] + + # infer task and metric names from first method + metric_fnames = sorted(glob.glob('%s/*_max/metrics/*.csv' % glob.escape(method_dirs[0]))) + task_names = [] + metric_names = [] + for metric_fname in metric_fnames: + head, tail = os.path.split(metric_fname) + task_name = head.split('/')[-2] + metric_name, _ = os.path.splitext(tail) + task_names.append(task_name) + metric_names.append(metric_name) + + # save plots + dataset_name = os.path.split(os.path.normpath(args.results_dir))[1] + plots_dir = os.path.join(args.web_dir, 'plots') + if not os.path.exists(plots_dir): + os.makedirs(plots_dir) + + if dataset_name in ('bair', 'bair_action_free'): + 
context_frames = 2 + training_sequence_length = 12 + elif dataset_name == 'kth': + context_frames = 10 + training_sequence_length = 20 + else: + raise NotImplementedError + + fig = plt.figure(figsize=(12, 5)) + i_task = 0 + for task_name, metric_name in zip(task_names, metric_names): + if not task_name.endswith('max'): + continue + # use ssim for all datasets + # if metric_name not in ('psnr', 'ssim', 'vgg_csim'): + # continue + if dataset_name in ('bair', 'bair_action_free'): + if metric_name not in ('psnr', 'ssim_finn', 'vgg_csim'): + continue + elif dataset_name == 'kth': + if metric_name not in ('psnr', 'ssim_scikit', 'vgg_csim'): + continue + + plt.subplot(1, 3, i_task + 1) # hard-coded 3 columns + + for method_name, method_dir in zip(method_names, method_dirs): + metric = load_metrics(os.path.join(method_dir, task_name, 'metrics', metric_name)) + plot_metric(metric, context_frames + 1, color=get_color(os.path.basename(method_dir)), label=method_name) + + plt.grid(axis='y') + plt.axvline(x=training_sequence_length, linewidth=1, color='k') + plt.xlabel('Time Step', fontsize=15) + plt.ylabel({ + 'psnr': 'Average PSNR', + 'ssim': 'Average SSIM', + 'ssim_scikit': 'Average SSIM', + 'ssim_finn': 'Average SSIM', + 'vgg_csim': 'Average VGG cosine similarity', + }[metric_name], fontsize=15) + plt.xlim((context_frames + 1, metric.shape[1] + context_frames)) + plt.tick_params(labelsize=10) + + if i_task == 1: + # plt.title({ + # 'bair': 'Action-conditioned BAIR Dataset', + # 'bair_action_free': 'Action-free BAIR Dataset', + # 'kth': 'KTH Dataset', + # }[dataset_name], fontsize=16) + if len(method_names) <= 4 and sum([len(method_name) for method_name in method_names]) < 90: + ncol = len(method_names) + else: + ncol = (len(method_names) + 1) // 2 + plt.legend(bbox_to_anchor=(0.5, -0.12), loc='upper center', ncol=ncol, fontsize=15) + i_task += 1 + fig.tight_layout(rect=(0, 0.1, 1, 1)) + + if args.save: + plt.show(block=False) + print("Saving to", os.path.join(plots_dir, 
args.plot_fname)) + plt.savefig(os.path.join(plots_dir, args.plot_fname), bbox_inches='tight') + else: + plt.show() + + +if __name__ == '__main__': + main() diff --git a/scripts/plot_results_all.sh b/scripts/plot_results_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..6fd01c4ce0d66b6863f16679ea54f1a35fe50d26 --- /dev/null +++ b/scripts/plot_results_all.sh @@ -0,0 +1,78 @@ +python scripts/plot_results.py results_test/bair_action_free --method_dirs \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_vae_l2 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sv2p_time_invariant \ + svg_lp \ + --save --use_tex --plot_fname metrics_all.pdf + +python scripts/plot_results.py results_test/bair --method_dirs \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_vae_l2 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sna_l1 \ + sna_l2 \ + sv2p_time_variant \ + --save --use_tex --plot_fname metrics_all.pdf + +python scripts/plot_results.py results_test/kth --method_dirs \ + ours_vae_gan \ + ours_gan \ + ours_vae_l1 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + sv2p_time_variant \ + sv2p_time_invariant \ + svg_fp_resized_data_loader \ + --save --use_tex --plot_fname metrics_all.pdf + + +python scripts/plot_results.py results_test/bair_action_free --method_dirs \ + sv2p_time_invariant \ + svg_lp \ + ours_vae_gan \ + --save --use_tex --plot_fname metrics.pdf; \ +python scripts/plot_results.py results_test/bair_action_free --method_dirs \ + ours_deterministic \ + ours_vae \ + ours_gan \ + ours_vae_gan \ + --save --use_tex --plot_fname metrics_ablation.pdf; \ +python scripts/plot_results.py results_test/bair_action_free --method_dirs \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + ours_vae_l1 \ + ours_vae_l2 \ + --save --use_tex --plot_fname metrics_ablation_l1_l2.pdf; \ +python scripts/plot_results.py results_test/kth --method_dirs \ + sv2p_time_variant \ + svg_fp_resized_data_loader \ + ours_vae_gan \ + --save 
--use_tex --plot_fname metrics.pdf; \ +python scripts/plot_results.py results_test/kth --method_dirs \ + ours_deterministic \ + ours_vae \ + ours_gan \ + ours_vae_gan \ + --save --use_tex --plot_fname metrics_ablation.pdf; \ +python scripts/plot_results.py results_test/bair --method_dirs \ + sv2p_time_variant \ + ours_deterministic \ + ours_vae \ + ours_gan \ + ours_vae_gan \ + --save --use_tex --plot_fname metrics.pdf; \ +python scripts/plot_results.py results_test/bair --method_dirs \ + sna_l1 \ + sna_l2 \ + ours_deterministic_l1 \ + ours_deterministic_l2 \ + ours_vae_l1 \ + ours_vae_l2 \ + --save --use_tex --plot_fname metrics_ablation_l1_l2.pdf