from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import argparse
import csv
import errno
import json
import os
import random

import numpy as np
import tensorflow as tf

from video_prediction import datasets, models


def save_image_sequence(prefix_fname, images, time_start_ind=0):
    import cv2
    head, tail = os.path.split(prefix_fname)
    if head and not os.path.exists(head):
        os.makedirs(head)
    for t, image in enumerate(images):
        image_fname = '%s_%02d.png' % (prefix_fname, time_start_ind + t)
        image = (image * 255.0).astype(np.uint8)
        if image.shape[-1] == 1:
            image = np.tile(image, (1, 1, 3))
        else:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(image_fname, image)


def save_image_sequences(prefix_fname, images, sample_start_ind=0, time_start_ind=0):
    head, tail = os.path.split(prefix_fname)
    if head and not os.path.exists(head):
        os.makedirs(head)
    for i, images_ in enumerate(images):
        images_fname = '%s_%05d' % (prefix_fname, sample_start_ind + i)
        save_image_sequence(images_fname, images_, time_start_ind=time_start_ind)


def save_metrics(prefix_fname, metrics, sample_start_ind=0):
    head, tail = os.path.split(prefix_fname)
    if head and not os.path.exists(head):
        os.makedirs(head)
    assert metrics.ndim == 2
    file_mode = 'w' if sample_start_ind == 0 else 'a'
    with open('%s.csv' % prefix_fname, file_mode, newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        if sample_start_ind == 0:
            writer.writerow(map(str, ['sample_ind'] + list(range(metrics.shape[1])) + ['mean']))
        for i, metrics_row in enumerate(metrics):
            writer.writerow(map(str, [sample_start_ind + i] + list(metrics_row) + [np.mean(metrics_row)]))


def load_metrics(prefix_fname):
    with open('%s.csv' % prefix_fname, newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter='\t', quotechar='|')
        rows = list(reader)
        # skip header (first row), indices (first column), and means (last column)
        metrics = np.array(rows)[1:, 1:-1].astype(np.float32)
    return metrics


def merge_hparams(hparams0, hparams1):
    hparams0 = hparams0 or []
    hparams1 = hparams1 or []
    if not isinstance(hparams0, (list, tuple)):
        hparams0 = [hparams0]
    if not isinstance(hparams1, (list, tuple)):
        hparams1 = [hparams1]
    hparams = list(hparams0) + list(hparams1)
    # simplify into the content if possible
    if len(hparams) == 1:
        hparams, = hparams
    return hparams


def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False, subtasks=None):
    sequence_length = model_hparams.sequence_length
    context_frames = model_hparams.context_frames
    future_length = sequence_length - context_frames

    context_images = results['images'][:, :context_frames]

    if 'eval_diversity' in results:
        metric = results['eval_diversity']
        metric_name = 'diversity'
        subtask_dir = task_dir + '_%s' % metric_name
        save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
                     metric, sample_start_ind=sample_start_ind)

    subtasks = subtasks or ['max']
    for subtask in subtasks:
        metric_names = []
        for k in results.keys():
            if re.match('eval_(\w+)/%s' % subtask, k) and not re.match('eval_gen_images_(\w+)/%s' % subtask, k):
                m = re.match('eval_(\w+)/%s' % subtask, k)
                metric_names.append(m.group(1))
        for metric_name in metric_names:
            subtask_dir = task_dir + '_%s_%s' % (metric_name, subtask)
            gen_images = results.get('eval_gen_images_%s/%s' % (metric_name, subtask), results.get('eval_gen_images'))
            # only keep the future frames
            gen_images = gen_images[:, -future_length:]
            metric = results['eval_%s/%s' % (metric_name, subtask)]
            save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
                         metric, sample_start_ind=sample_start_ind)
            if only_metrics:
                continue

            save_image_sequences(os.path.join(subtask_dir, 'inputs', 'context_image'),
                                 context_images, sample_start_ind=sample_start_ind)
            save_image_sequences(os.path.join(subtask_dir, 'outputs', 'gen_image'),
                                 gen_images, sample_start_ind=sample_start_ind)


def main():
    """
    results_dir
    ├── output_dir                              # condition / method
    │   ├── prediction_eval_lpips_max           # task: best sample in terms of LPIPS similarity
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the future ones)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── lpips.csv
    │   ├── prediction_eval_ssim_max            # task: best sample in terms of SSIM
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the future ones)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── ssim.csv
    │   └── ...
    └── ...
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
                                                                     "train, val, test, etc, or a directory containing "
                                                                     "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_dir is specified")
    parser.add_argument("--output_dir", help="output directory where results are saved. default is results_dir/model_fname, "
                                             "where model_fname is the directory name of checkpoint")
    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")

    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")

    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)

    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'avg', 'min'], help='subtasks to evaluate (e.g. max, avg, min)')
    parser.add_argument("--only_metrics", action='store_true')
    parser.add_argument("--num_stochastic_samples", type=int, default=100)

    parser.add_argument("--gt_inputs_dir", type=str, help="directory containing input ground truth images for ismple dataset")
    parser.add_argument("--gt_outputs_dir", type=str, help="directory containing output ground truth images for ismple dataset")

    parser.add_argument("--eval_parallel_iterations", type=int, default=10)
    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_dir = args.output_dir or os.path.join(args.results_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_dir = args.output_dir or os.path.join(args.results_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(
        args.input_dir,
        mode=args.mode,
        num_epochs=args.num_epochs,
        seed=args.seed,
        hparams_dict=dataset_hparams_dict,
        hparams=args.dataset_hparams)

    VideoPredictionModel = models.get_model_class(args.model)
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': dataset.hparams.context_frames,
        'sequence_length': dataset.hparams.sequence_length,
        'repeat': dataset.hparams.time_shift,
    })
    model = VideoPredictionModel(
        mode=args.mode,
        hparams_dict=hparams_dict,
        hparams=args.model_hparams,
        eval_num_samples=args.num_stochastic_samples,
        eval_parallel_iterations=args.eval_parallel_iterations)

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        #bing0
        #raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
        pass
    #Bing if it is era 5 data we used dataset.make_batch_v2
    #inputs = dataset.make_batch(args.batch_size)
    inputs = dataset.make_batch_v2(args.batch_size)
    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
    with tf.variable_scope(''):
        model.build_graph(input_phs)

    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    with open(os.path.join(output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)
    sess.graph.as_default()

    model.restore(sess, args.checkpoint)

    sample_ind = 0
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results = sess.run(inputs)
        except tf.errors.OutOfRangeError:
            break
        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))

        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
        # compute "best" metrics using the computation graph
        fetches = {'images': model.inputs['images']}
        fetches.update(model.eval_outputs.items())
        fetches.update(model.eval_metrics.items())
        results = sess.run(fetches, feed_dict=feed_dict)
        save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'),
                                     results, model.hparams, sample_ind, args.only_metrics, args.eval_substasks)
        sample_ind += args.batch_size

    metric_fnames = []
    metric_names = ['psnr', 'ssim', 'lpips']
    subtasks = ['max']
    for metric_name in metric_names:
        for subtask in subtasks:
            metric_fnames.append(
                os.path.join(output_dir, 'prediction_eval_%s_%s' % (metric_name, subtask), 'metrics', metric_name))

    for metric_fname in metric_fnames:
        task_name, _, metric_name = metric_fname.split('/')[-3:]
        metric = load_metrics(metric_fname)
        print('=' * 31)
        print(task_name, metric_name)
        print('-' * 31)
        metric_header_format = '{:>10} {:>20}'
        metric_row_format = '{:>10} {:>10.4f} ({:>7.4f})'
        print(metric_header_format.format('time step', os.path.split(metric_fname)[1]))
        for t, (metric_mean, metric_std) in enumerate(zip(metric.mean(axis=0), metric.std(axis=0))):
            print(metric_row_format.format(t, metric_mean, metric_std))
        print(metric_row_format.format('mean (std)', metric.mean(), metric.std()))
        print('=' * 31)


if __name__ == '__main__':
    main()