from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import csv
import errno
import json
import os
import random

import numpy as np
import tensorflow as tf
from tensorflow.python.util import nest

from video_prediction import datasets, models, metrics
from video_prediction.policies.servo_policy import ServoPolicy


def compute_expectation_np(pix_distrib):
    """Return the expected (row, column) location under each pixel distribution.

    The last axis of `pix_distrib` must have size one. Each map is normalized
    to sum to one over its third-to-last and second-to-last axes (used as
    height and width) before the expectation is taken. Leading axes are kept
    and the result gains a trailing axis of size two ordered (row, column).
    """
    assert pix_distrib.shape[-1] == 1
    normalizer = np.sum(pix_distrib, axis=(-3, -2), keepdims=True)
    probs = pix_distrib / normalizer
    num_rows, num_cols = probs.shape[-3:-1]
    # col_grid[r, c] == c and row_grid[r, c] == r
    col_grid, row_grid = np.meshgrid(np.arange(num_cols), np.arange(num_rows))
    reduce_axes = (-3, -2, -1)
    expected_row = np.sum(row_grid[:, :, None] * probs, axis=reduce_axes)
    expected_col = np.sum(col_grid[:, :, None] * probs, axis=reduce_axes)
    return np.stack([expected_row, expected_col], axis=-1)


def as_heatmap(image, normalize=True):
    """Color single-channel images with matplotlib's viridis colormap.

    The trailing axis of `image` (which must have size one) is removed and
    replaced by three RGB channels. When `normalize` is true, the values are
    first divided by their maximum over the last two remaining axes.
    """
    import matplotlib.pyplot as plt
    values = np.squeeze(image, axis=-1)
    if normalize:
        peak = np.max(values, axis=(-2, -1), keepdims=True)
        values = values / peak
    colored = plt.get_cmap('viridis')(values)
    # the colormap yields RGBA; keep only the color channels
    return colored[..., :3]


def rgb2gray(rgb):
    """Convert colors stored along the last axis to grayscale.

    Only the first three channels are used (any further channel is ignored),
    weighted 0.299, 0.587 and 0.114 respectively. The channel axis is removed
    from the result.
    """
    channel_weights = [0.299, 0.587, 0.114]
    color_channels = rgb[..., :3]
    return np.dot(color_channels, channel_weights)


def resize_and_draw_circle(image, size, center, radius, dpi=128.0, **kwargs):
    """Render `image` with a circle on top of it and return the rendered pixels.

    The image is drawn with matplotlib on an axes that fills the whole figure,
    the figure is saved to an in-memory PNG, and that PNG is read back.

    Args:
        image: image accepted by `Axes.imshow`.
        size: (height, width) used, together with `dpi`, to size the figure.
        center: circle center; it is reversed before being handed to
            matplotlib, so it is given in the opposite order of matplotlib's
            (x, y) convention.
        radius: radius of the circle.
        dpi: resolution of the figure.
        **kwargs: forwarded to `matplotlib.patches.Circle`.

    Returns:
        The first three channels of the image read back from the PNG.
    """
    import io
    import matplotlib.pyplot as plt
    from matplotlib.patches import Circle

    height, width = size
    figure = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
    axes = figure.add_axes([0, 0, 1, 1])
    axes.imshow(image, interpolation='none')
    axes.add_patch(Circle(center[::-1], radius=radius, **kwargs))
    axes.axis("off")
    figure.canvas.draw()
    # express the axes bounding box in inches, as expected by bbox_inches
    to_inches = axes.figure.dpi_scale_trans.inverted()
    axes_bbox = axes.bbox.transformed(to_inches)
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format="png", dpi=axes.figure.dpi, bbox_inches=axes_bbox)
    png_buffer.seek(0)
    rendered = plt.imread(png_buffer)[..., :3]
    plt.close(figure)
    return rendered


def save_image_sequence(prefix_fname, images, overlaid_images=None, centers=None,
                        radius=5, alpha=0.8, time_start_ind=0):
    """Write a sequence of frames to disk as numbered PNG files.

    Frame `t` is written to '<prefix_fname>_<time_start_ind + t>.png', with the
    index zero-padded to two digits.

    Args:
        prefix_fname: path prefix of the output files. Its directory is
            created if it does not exist.
        images: array of frames, iterated over its first axis, with channels
            last. Single-channel frames are converted to heatmaps. Values are
            scaled by 255 when written, so they are expected to lie in [0, 1];
            values outside that range are saturated.
        overlaid_images: optional single-channel frames that are converted to
            heatmaps and blended over a grayscale version of `images`.
        centers: optional per-frame circle centers. When given, each frame is
            resized by the factor that brings its smaller side to 256 pixels
            and a dashed red circle is drawn at the corresponding center.
        radius: radius of the drawn circles.
        alpha: weight of the overlay in the blend (the grayscale image gets
            `1 - alpha`).
        time_start_ind: index used in the file name of the first frame.
    """
    import cv2
    head = os.path.dirname(prefix_fname)
    if head:
        # exist_ok avoids failing when the directory is created concurrently
        # between an existence check and the creation
        os.makedirs(head, exist_ok=True)
    if images.shape[-1] == 1:
        images = as_heatmap(images)
    if overlaid_images is not None:
        assert images.shape[-1] == 3
        assert overlaid_images.shape[-1] == 1
        gray_images = rgb2gray(images)
        overlaid_images = as_heatmap(overlaid_images)
        images = (1 - alpha) * gray_images[..., None] + alpha * overlaid_images
    for t, image in enumerate(images):
        image_fname = '%s_%02d.png' % (prefix_fname, time_start_ind + t)
        if centers is not None:
            scale = np.max(np.array([256, 256]) / np.array(image.shape[:2]))
            image = resize_and_draw_circle(image, np.array(image.shape[:2]) * scale, centers[t], radius,
                                           edgecolor='r', fill=False, linestyle='--', linewidth=2)
        # clip before casting: a float-to-uint8 cast does not saturate, so
        # values slightly outside [0, 1] would otherwise wrap around
        image = np.clip(image * 255.0, 0.0, 255.0).astype(np.uint8)
        # cv2 expects the channels in BGR order
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(image_fname, image)


def save_image_sequences(prefix_fname, images, overlaid_images=None, centers=None,
                         radius=5, alpha=0.8, sample_start_ind=0, time_start_ind=0):
    """Save a batch of image sequences, one sequence per sample.

    The i-th sequence is saved with `save_image_sequence` under the prefix
    '<prefix_fname>_<sample_start_ind + i>', with the sample index zero-padded
    to five digits. `overlaid_images` and `centers`, when given, are iterated
    in lockstep with `images`; the remaining arguments are passed through
    unchanged for every sequence.
    """
    out_dir = os.path.dirname(prefix_fname)
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir)
    num_sequences = len(images)
    # substitute per-sample placeholders so that the three can be zipped
    if overlaid_images is None:
        overlaid_images = [None] * num_sequences
    if centers is None:
        centers = [None] * num_sequences
    sample_ind = sample_start_ind
    for frames, overlays, frame_centers in zip(images, overlaid_images, centers):
        save_image_sequence('%s_%05d' % (prefix_fname, sample_ind), frames, overlays, frame_centers,
                            radius=radius, alpha=alpha, time_start_ind=time_start_ind)
        sample_ind += 1


def save_metrics(prefix_fname, metrics, sample_start_ind=0):
    """Write a 2-D array of metric values to '<prefix_fname>.csv'.

    The file is tab-separated with one row per row of `metrics`: the sample
    index (`sample_start_ind` plus the row position), the values of the row,
    and their mean. When `sample_start_ind` is zero the file is (re)created
    and a header row is written first; otherwise the rows are appended.
    """
    out_dir = os.path.dirname(prefix_fname)
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir)
    assert metrics.ndim == 2
    is_first_batch = sample_start_ind == 0
    with open('%s.csv' % prefix_fname, 'w' if is_first_batch else 'a', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        if is_first_batch:
            header = ['sample_ind'] + list(range(metrics.shape[1])) + ['mean']
            writer.writerow([str(field) for field in header])
        for offset, metrics_row in enumerate(metrics):
            fields = [sample_start_ind + offset] + list(metrics_row) + [np.mean(metrics_row)]
            writer.writerow([str(field) for field in fields])


def load_metrics(prefix_fname):
    """Read the metric values stored in '<prefix_fname>.csv' as a float32 array.

    The file is parsed as tab-separated values. The first row (header), the
    first column (sample indices) and the last column (means) are discarded.
    """
    with open('%s.csv' % prefix_fname, newline='') as csvfile:
        table = np.array(list(csv.reader(csvfile, delimiter='\t', quotechar='|')))
    # keep only the per-time-step values
    values = table[1:, 1:-1]
    return values.astype(np.float32)


def merge_hparams(hparams0, hparams1):
    """Concatenate two hyperparameter specifications.

    Each argument may be a falsy value (treated as empty), a single item, or
    a list/tuple of items. The merged items are returned as a list, except
    that a result with exactly one item is returned as that item itself.
    """
    def as_list(hparams):
        hparams = hparams or []
        if isinstance(hparams, (list, tuple)):
            return list(hparams)
        return [hparams]

    merged = as_list(hparams0) + as_list(hparams1)
    # simplify into the content if possible
    return merged[0] if len(merged) == 1 else merged


def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False, subtasks=None):
    """Save the metrics and images of the prediction_eval task.

    For every metric and every subtask, results are written under the
    directory '<task_dir>_<metric>_<subtask>'. The generated images are read
    from `results['eval_gen_images_<metric>/<subtask>']`, falling back to
    `results['eval_gen_images']`. Metrics that have a numpy implementation are
    recomputed against `results['eval_images']`; the others are read from
    `results['eval_<metric>/<subtask>']`.

    Args:
        task_dir: prefix of the output directories.
        results: dict of arrays as described above; also needs 'images'.
        model_hparams: object providing `context_frames`.
        sample_start_ind: index of the first sample of this batch.
        only_metrics: if true, no images are written.
        subtasks: subtask names; defaults to ['max'] when empty or None.
    """
    context_images = results['images'][:, :model_hparams.context_frames]
    gt_images = results['eval_images']
    # a metric function of None means the value is taken from `results`
    metric_entries = [
        ('psnr', metrics.peak_signal_to_noise_ratio_np),
        ('ssim', metrics.structural_similarity_np),
        ('ssim_scikit', metrics.structural_similarity_scikit_np),
        ('ssim_finn', metrics.structural_similarity_finn_np),
        ('vgg_csim', None),
    ]
    if not subtasks:
        subtasks = ['max']
    for metric_name, metric_fn in metric_entries:
        for subtask in subtasks:
            subtask_dir = task_dir + '_%s_%s' % (metric_name, subtask)
            gen_images = results.get('eval_gen_images_%s/%s' % (metric_name, subtask),
                                     results.get('eval_gen_images'))
            if metric_fn is None:
                metric = results['eval_%s/%s' % (metric_name, subtask)]
            else:
                # recompute using numpy implementation
                metric = metric_fn(gt_images, gen_images, keep_axis=(0, 1))
            save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
                         metric, sample_start_ind=sample_start_ind)
            if not only_metrics:
                save_image_sequences(os.path.join(subtask_dir, 'inputs', 'context_image'),
                                     context_images, sample_start_ind=sample_start_ind)
                save_image_sequences(os.path.join(subtask_dir, 'outputs', 'gen_image'),
                                     gen_images, sample_start_ind=sample_start_ind)


def save_prediction_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False):
    """Save the metrics and images of the prediction task.

    `results['images']` is split along axis 1 into the first `context_frames`
    frames and the remaining ones; the remaining ones are compared against the
    last `sequence_length - context_frames` frames of `results['gen_images']`.
    PSNR, MSE and SSIM are written to '<task_dir>/metrics', and unless
    `only_metrics` is set, the context and generated images are written to
    '<task_dir>/inputs' and '<task_dir>/outputs'.
    """
    context_frames = model_hparams.context_frames
    # negative start index: keeps the trailing frames of the generated sequence
    first_gen_ind = context_frames - model_hparams.sequence_length
    context_images, images = np.split(results['images'], [context_frames], axis=1)
    gen_images = results['gen_images'][:, first_gen_ind:]

    metric_fns = [('psnr', metrics.peak_signal_to_noise_ratio_np),
                  ('mse', metrics.mean_squared_error_np),
                  ('ssim', metrics.structural_similarity_np)]
    # compute every metric before anything is written
    metric_values = [(metric_name, metric_fn(images, gen_images, keep_axis=(0, 1)))
                     for metric_name, metric_fn in metric_fns]
    for metric_name, metric in metric_values:
        save_metrics(os.path.join(task_dir, 'metrics', metric_name),
                     metric, sample_start_ind=sample_start_ind)

    if not only_metrics:
        save_image_sequences(os.path.join(task_dir, 'inputs', 'context_image'),
                             context_images, sample_start_ind=sample_start_ind)
        save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_image'),
                             gen_images, sample_start_ind=sample_start_ind)


def save_motion_results(task_dir, results, model_hparams, draw_center=False,
                        sample_start_ind=0, only_metrics=False):
    """Save the metrics and images of the motion task.

    The expected pixel distance between the ground truth pixel distributions
    after the context frames and the last `sequence_length - context_frames`
    generated ones is written to '<task_dir>/metrics/pix_dist'. Unless
    `only_metrics` is set, the initial and the generated pixel distributions
    are also saved, overlaid on the first context image and on the generated
    images respectively, once per channel of the last axis of the
    distributions. When `draw_center` is set, the expected location of each
    distribution is marked with a circle.
    """
    context_frames = model_hparams.context_frames
    # negative start index: keeps the trailing frames of the generated sequences
    first_gen_ind = context_frames - model_hparams.sequence_length
    pix_distribs = results['pix_distribs'][:, context_frames:]
    gen_pix_distribs = results['gen_pix_distribs'][:, first_gen_ind:]
    pix_dist = metrics.expected_pixel_distance_np(pix_distribs, gen_pix_distribs, keep_axis=(0, 1))
    save_metrics(os.path.join(task_dir, 'metrics', 'pix_dist'),
                 pix_dist, sample_start_ind=sample_start_ind)
    if only_metrics:
        return

    context_images, _ = np.split(results['images'], [context_frames], axis=1)
    gen_images = results['gen_images'][:, first_gen_ind:]
    initial_pix_distrib = results['pix_distribs'][:, 0:1]
    first_context_image = context_images[:, 0:1]
    num_motions = pix_distribs.shape[-1]
    for motion_ind in range(num_motions):
        # the channel index is only appended when there are several channels
        name_suffix = '%d' % motion_ind if num_motions > 1 else ''
        channel = slice(motion_ind, motion_ind + 1)

        initial_distrib = initial_pix_distrib[..., channel]
        centers = compute_expectation_np(initial_distrib) if draw_center else None
        save_image_sequences(os.path.join(task_dir, 'inputs', 'pix_distrib%s' % name_suffix),
                             first_context_image, initial_distrib, centers, sample_start_ind=sample_start_ind)

        gen_distrib = gen_pix_distribs[..., channel]
        centers = compute_expectation_np(gen_distrib) if draw_center else None
        save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_pix_distrib%s' % name_suffix),
                             gen_images, gen_distrib, centers, sample_start_ind=sample_start_ind)


def save_servo_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False):
    """Save the metrics and images of the servo task.

    Two metrics are written to '<task_dir>/metrics': the MSE between the goal
    image and the last generated image of each sample ('goal_image_mse'), and
    the MSE between the ground truth and the generated actions
    ('action_mse'). Unless `only_metrics` is set, the context images, the goal
    image, the generated images and the absolute difference between the
    generated images and the goal image are saved as well.
    """
    context_frames = model_hparams.context_frames
    # negative start index: keeps the trailing frames of the generated sequence
    first_gen_ind = context_frames - model_hparams.sequence_length
    context_images, _ = np.split(results['images'], [context_frames], axis=1)
    gen_images = results['gen_images'][:, first_gen_ind:]
    goal_image = results['goal_image']
    # TODO: should exclude "context" actions assuming that they are passed in to the network
    actions = results['actions']
    gen_actions = results['gen_actions']

    final_gen_image = gen_images[:, -1]
    goal_image_mse = metrics.mean_squared_error_np(goal_image, final_gen_image, keep_axis=0)
    action_mse = metrics.mean_squared_error_np(actions, gen_actions, keep_axis=(0, 1))
    # the goal image metric has one value per sample; add an axis to make it 2-D
    save_metrics(os.path.join(task_dir, 'metrics', 'goal_image_mse'),
                 goal_image_mse[:, None], sample_start_ind=sample_start_ind)
    save_metrics(os.path.join(task_dir, 'metrics', 'action_mse'),
                 action_mse, sample_start_ind=sample_start_ind)
    if only_metrics:
        return

    # the single goal image is given an axis of size one to form a sequence
    goal_sequence = goal_image[:, None]
    save_image_sequences(os.path.join(task_dir, 'inputs', 'context_image'),
                         context_images, sample_start_ind=sample_start_ind)
    save_image_sequences(os.path.join(task_dir, 'inputs', 'goal_image'),
                         goal_sequence, sample_start_ind=sample_start_ind)
    save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_image'),
                         gen_images, sample_start_ind=sample_start_ind)
    save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_image_goal_diff'),
                         np.abs(gen_images - goal_sequence), sample_start_ind=sample_start_ind)


def main():
    """
    Evaluate a video prediction model on a dataset and save the results.

    Parses the command-line options, builds the dataset and the model
    (loading the options and hyperparameters stored next to the checkpoint,
    if one is given), runs the requested tasks batch by batch, and finally
    prints per-time-step summaries of the saved metrics.

    The results are laid out as follows:

    results_dir
    ├── output_dir                              # condition / method
    │   ├── prediction                          # task
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the ones in the loss)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       ├── psnr.csv
    │   │       ├── mse.csv
    │   │       └── ssim.csv
    │   ├── prediction_eval_vgg_csim_max        # task: best sample in terms of VGG cosine similarity
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the ones in the loss)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── vgg_csim.csv
    │   ├── servo
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png
    │   │   │   ├── ...
    │   │   │   ├── goal_image_00000_00.png     # only one goal image per sample
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png
    │   │   │   ├── ...
    │   │   │   ├── gen_image_goal_diff_00000_00.png
    │   │   │   └── ...
    │   │   └── metrics
    │   │       ├── action_mse.csv
    │   │       └── goal_image_mse.csv
    │   ├── motion
    │   │   ├── inputs
    │   │   │   ├── pix_distrib_00000_00.png
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_pix_distrib_00000_00.png
    │   │   │   ├── ...
    │   │   │   ├── gen_pix_distrib_overlaid_00000_00.png
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── pix_dist.csv
    │   └── ...
    └── ...
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
                                                                     "train, val, test, etc, or a directory containing "
                                                                     "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_dir is specified")
    parser.add_argument("--output_dir", help="output directory where results are saved. default is results_dir/model_fname, "
                                             "where model_fname is the directory name of checkpoint")
    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")

    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")

    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)

    parser.add_argument("--tasks", type=str, nargs='+', help='tasks to evaluate (e.g. prediction, prediction_eval, servo, motion)')
    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'min'], help='subtasks to evaluate (e.g. max, avg, min). only applicable to prediction_eval')
    parser.add_argument("--only_metrics", action='store_true')
    parser.add_argument("--num_stochastic_samples", type=int, default=100)

    parser.add_argument("--gt_inputs_dir", type=str, help="directory containing input ground truth images for ismple dataset")
    parser.add_argument("--gt_outputs_dir", type=str, help="directory containing output ground truth images for ismple dataset")

    parser.add_argument("--eval_parallel_iterations", type=int, default=10)
    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    # options and hyperparameters stored next to the checkpoint provide the
    # defaults; values given on the command line take precedence
    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        # a checkpoint name (e.g. checkpoint_dir/model-200000) is a prefix
        # rather than an existing path, so it has to be reduced to its
        # directory before checking for existence
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_dir = args.output_dir or os.path.join(args.results_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_dir = args.output_dir or os.path.join(args.results_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed,
                           hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)

    def override_hparams_dict(dataset):
        # the sequence-related hyperparameters of the model follow the dataset
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams,
                                 eval_num_samples=args.num_stochastic_samples, eval_parallel_iterations=args.eval_parallel_iterations)
    context_frames = model.hparams.context_frames
    sequence_length = model.hparams.sequence_length

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset')

    inputs, target = dataset.make_batch(args.batch_size)
    if not isinstance(model, models.GroundTruthVideoPredictionModel):
        # remove ground truth data past context_frames to prevent accidentally using it
        for k, v in inputs.items():
            if k != 'actions':
                inputs[k] = v[:, :context_frames]

    # the model is built on placeholders; every batch is first fetched from
    # the dataset and then fed back in through these
    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
    target_ph = tf.placeholder(target.dtype, target.shape, 'targets_ph')

    with tf.variable_scope(''):
        model.build_graph(input_phs, target_ph)

    tasks = args.tasks
    if tasks is None:
        tasks = ['prediction_eval']
        if 'pix_distribs' in inputs:
            tasks.append('motion')

    if 'servo' in tasks:
        # a second graph of the same model, sharing the variables, with the
        # batch size used by the servo policy
        servo_model = VideoPredictionModel(mode='test', hparams_dict=model_hparams_dict, hparams=args.model_hparams)
        cem_batch_size = 200
        plan_horizon = sequence_length - 1
        image_shape = inputs['images'].shape.as_list()[2:]
        state_shape = inputs['states'].shape.as_list()[2:]
        action_shape = inputs['actions'].shape.as_list()[2:]
        servo_input_phs = {
            'images': tf.placeholder(tf.float32, shape=[cem_batch_size, context_frames] + image_shape),
            'states': tf.placeholder(tf.float32, shape=[cem_batch_size, 1] + state_shape),
            'actions': tf.placeholder(tf.float32, shape=[cem_batch_size, plan_horizon] + action_shape),
        }
        if isinstance(servo_model, models.GroundTruthVideoPredictionModel):
            images_shape = inputs['images'].shape.as_list()[1:]
            servo_input_phs['images'] = tf.placeholder(tf.float32, shape=[cem_batch_size] + images_shape)
        with tf.variable_scope('', reuse=True):
            servo_model.build_graph(servo_input_phs)

    # record the configuration of this evaluation next to its results
    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    with open(os.path.join(output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)

    model.restore(sess, args.checkpoint)

    if 'servo' in tasks:
        servo_policy = ServoPolicy(servo_model, sess)

    sample_ind = 0
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results, target_result = sess.run([inputs, target])
        except tf.errors.OutOfRangeError:
            # the dataset is exhausted
            break
        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))

        if 'prediction_eval' in tasks:
            feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
            feed_dict.update({target_ph: target_result})
            # compute "best" metrics using the computation graph (if available) or explicitly with python logic
            if model.eval_outputs and model.eval_metrics:
                fetches = {'images': model.inputs['images']}
                fetches.update(model.eval_outputs.items())
                fetches.update(model.eval_metrics.items())
                results = sess.run(fetches, feed_dict=feed_dict)
            else:
                metric_names = ['psnr', 'ssim', 'ssim_scikit', 'ssim_finn', 'vgg_csim']
                metric_fns = [metrics.peak_signal_to_noise_ratio_np,
                              metrics.structural_similarity_np,
                              metrics.structural_similarity_scikit_np,
                              metrics.structural_similarity_finn_np,
                              metrics.vgg_cosine_similarity_np]

                # sample several predictions and score each of them
                all_gen_images = []
                all_metrics = [np.empty((args.num_stochastic_samples, args.batch_size, sequence_length - context_frames)) for _ in metric_names]
                for s in range(args.num_stochastic_samples):
                    gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
                    all_gen_images.append(gen_images)
                    for metric_name, metric_fn, all_metric in zip(metric_names, metric_fns, all_metrics):
                        metric = metric_fn(gen_images, target_result, keep_axis=(0, 1))
                        all_metric[s] = metric

                # save_prediction_eval_results also reads the input images and
                # the ground truth that the predictions are scored against
                results = {'images': input_results['images'],
                           'eval_images': target_result}
                for metric_name, all_metric in zip(metric_names, all_metrics):
                    for subtask in args.eval_substasks:
                        results['eval_gen_images_%s/%s' % (metric_name, subtask)] = np.empty_like(all_gen_images[0])
                        results['eval_%s/%s' % (metric_name, subtask)] = np.empty_like(all_metric[0])

                # for every sample in the batch, pick the prediction selected by each subtask
                for i in range(args.batch_size):
                    for metric_name, all_metric in zip(metric_names, all_metrics):
                        ordered = np.argsort(np.mean(all_metric, axis=-1)[:, i])  # mean over time and sort over samples
                        for subtask in args.eval_substasks:
                            if subtask == 'max':
                                sidx = ordered[-1]
                            elif subtask == 'min':
                                sidx = ordered[0]
                            else:
                                raise NotImplementedError
                            results['eval_gen_images_%s/%s' % (metric_name, subtask)][i] = all_gen_images[sidx][i]
                            results['eval_%s/%s' % (metric_name, subtask)][i] = all_metric[sidx][i]
            save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'),
                                         results, model.hparams, sample_ind, args.only_metrics, args.eval_substasks)

        if 'prediction' in tasks or 'motion' in tasks:  # do these together
            feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
            fetches = {'images': model.inputs['images'],
                       'gen_images': model.outputs['gen_images']}
            if 'motion' in tasks:
                fetches.update({'pix_distribs': model.inputs['pix_distribs'],
                                'gen_pix_distribs': model.outputs['gen_pix_distribs']})

            if args.num_stochastic_samples:
                # stack the results of all the stochastic samples along a new leading axis
                all_results = [sess.run(fetches, feed_dict=feed_dict) for _ in range(args.num_stochastic_samples)]
                all_results = nest.map_structure(lambda *x: np.stack(x), *all_results)
                all_context_images, all_images = np.split(all_results['images'], [context_frames], axis=2)
                all_gen_images = all_results['gen_images'][:, :, context_frames - sequence_length:]
                all_mse = metrics.mean_squared_error_np(all_images, all_gen_images, keep_axis=(0, 1))
                all_mse_argsort = np.argsort(all_mse, axis=0)

                for subtask, argsort_ind in zip(['_best', '_median', '_worst'],
                                                [0, args.num_stochastic_samples // 2, -1]):
                    # index of the selected stochastic sample for each element of the batch
                    all_mse_inds = all_mse_argsort[argsort_ind]
                    gather = lambda x: np.array([x[ind, batch_ind] for batch_ind, ind in enumerate(all_mse_inds)])
                    results = nest.map_structure(gather, all_results)
                    if 'prediction' in tasks:
                        save_prediction_results(os.path.join(output_dir, 'prediction' + subtask),
                                                results, model.hparams, sample_ind, args.only_metrics)
                    if 'motion' in tasks:
                        draw_center = isinstance(model, models.NonTrainableVideoPredictionModel)
                        save_motion_results(os.path.join(output_dir, 'motion' + subtask),
                                            results, model.hparams, draw_center, sample_ind, args.only_metrics)
            else:
                results = sess.run(fetches, feed_dict=feed_dict)
                if 'prediction' in tasks:
                    save_prediction_results(os.path.join(output_dir, 'prediction'),
                                            results, model.hparams, sample_ind, args.only_metrics)
                if 'motion' in tasks:
                    draw_center = isinstance(model, models.NonTrainableVideoPredictionModel)
                    save_motion_results(os.path.join(output_dir, 'motion'),
                                        results, model.hparams, draw_center, sample_ind, args.only_metrics)

        if 'servo' in tasks:
            images = input_results['images']
            states = input_results['states']
            gen_actions = []
            gen_images = []
            # the policy is run on one sample at a time, using the last image as the goal
            for images_, states_ in zip(images, states):
                obs = {'context_images': images_[:context_frames],
                       'context_state': states_[0],
                       'goal_image': images_[-1]}
                if isinstance(servo_model, models.GroundTruthVideoPredictionModel):
                    obs['context_images'] = images_
                gen_actions_, gen_images_ = servo_policy.act(obs, servo_model.outputs['gen_images'])
                gen_actions.append(gen_actions_)
                gen_images.append(gen_images_)
            gen_actions = np.stack(gen_actions)
            gen_images = np.stack(gen_images)
            results = {'images': input_results['images'],
                       'actions': input_results['actions'],
                       'goal_image': input_results['images'][:, -1],
                       'gen_actions': gen_actions,
                       'gen_images': gen_images}
            save_servo_results(os.path.join(output_dir, 'servo'),
                               results, servo_model.hparams, sample_ind, args.only_metrics)

        sample_ind += args.batch_size

    # the session is not needed anymore; release its resources
    sess.close()

    # collect the metric files to summarize
    metric_fnames = []
    if 'prediction_eval' in tasks:
        metric_names = ['psnr', 'ssim', 'ssim_finn', 'vgg_csim']
        # only the 'max' subtask is summarized, and only if it was evaluated
        # (otherwise its metric files do not exist)
        subtasks = [subtask for subtask in ['max'] if subtask in args.eval_substasks]
        for metric_name in metric_names:
            for subtask in subtasks:
                metric_fnames.append(
                    os.path.join(output_dir, 'prediction_eval_%s_%s' % (metric_name, subtask), 'metrics', metric_name))
    if 'prediction' in tasks:
        subtask = '_best' if args.num_stochastic_samples else ''
        metric_fnames.extend([
            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'psnr'),
            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'mse'),
            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'ssim'),
        ])
    if 'motion' in tasks:
        subtask = '_best' if args.num_stochastic_samples else ''
        metric_fnames.append(os.path.join(output_dir, 'motion' + subtask, 'metrics', 'pix_dist'))
    if 'servo' in tasks:
        metric_fnames.append(os.path.join(output_dir, 'servo', 'metrics', 'goal_image_mse'))
        metric_fnames.append(os.path.join(output_dir, 'servo', 'metrics', 'action_mse'))

    for metric_fname in metric_fnames:
        # metric_fname is <output_dir>/<task_name>/metrics/<metric_name>;
        # os.path is used so that the separator of the platform is respected
        metrics_dir, metric_name = os.path.split(metric_fname)
        task_name = os.path.basename(os.path.dirname(metrics_dir))
        metric = load_metrics(metric_fname)
        print('=' * 31)
        print(task_name, metric_name)
        print('-' * 31)
        metric_header_format = '{:>10} {:>20}'
        metric_row_format = '{:>10} {:>10.4f} ({:>7.4f})'
        print(metric_header_format.format('time step', metric_name))
        for t, (metric_mean, metric_std) in enumerate(zip(metric.mean(axis=0), metric.std(axis=0))):
            print(metric_row_format.format(t, metric_mean, metric_std))
        print(metric_row_format.format('mean (std)', metric.mean(), metric.std()))
        print('=' * 31)


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()