import itertools
import os
from collections import OrderedDict

import numpy as np
import six
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.training import device_setter
from tensorflow.python.util import nest

from video_prediction.utils import ffmpeg_gif
from video_prediction.utils import gif_summary

IMAGE_SUMMARIES = "image_summaries"
EVAL_SUMMARIES = "eval_summaries"


def local_device_setter(num_devices=1,
                        ps_device_type='cpu',
                        worker_device='/cpu:0',
                        ps_ops=None,
                        ps_strategy=None):
    if ps_ops is None:
        ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']

    if ps_strategy is None:
        ps_strategy = device_setter._RoundRobinStrategy(num_devices)
    if not six.callable(ps_strategy):
        raise TypeError("ps_strategy must be callable")

    def _local_device_chooser(op):
        current_device = pydev.DeviceSpec.from_string(op.device or "")

        node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
        if node_def.op in ps_ops:
            ps_device_spec = pydev.DeviceSpec.from_string(
                '/{}:{}'.format(ps_device_type, ps_strategy(op)))

            ps_device_spec.merge_from(current_device)
            return ps_device_spec.to_string()
        else:
            worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
            worker_device_spec.merge_from(current_device)
            return worker_device_spec.to_string()

    return _local_device_chooser
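

# Example usage of `local_device_setter` (a sketch; `num_gpus`, `build_tower`,
# and `inputs` are hypothetical):
#
#     for i in range(num_gpus):
#         device_setter = local_device_setter(worker_device='/gpu:%d' % i)
#         with tf.device(device_setter):
#             tower_loss = build_tower(inputs)  # variables stay on /cpu:0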


def replace_read_ops(loss_or_losses, var_list):
    """
    Replaces the read ops of each variable in `var_list` with new read ops
    obtained from `read_value()`, thus forcing the consumers to read the most
    up-to-date values of the variables (which might incur copies across devices).
    The graph is seeded from the tensor(s) `loss_or_losses`.
    """
    # ops between var ops and the loss
    ops = set(ge.get_walks_intersection_ops([var.op for var in var_list], loss_or_losses))
    if not ops:  # loss_or_losses doesn't depend on any var in var_list, so there is nothing to replace
        return

    # filter out variables that are not involved in computing the loss
    var_list = [var for var in var_list if var.op in ops]

    for var in var_list:
        output, = var.op.outputs
        read_ops = set(output.consumers()) & ops
        for read_op in read_ops:
            with tf.name_scope('/'.join(read_op.name.split('/')[:-1])):
                with tf.device(read_op.device):
                    read_t, = read_op.outputs
                    consumer_ops = set(read_t.consumers()) & ops
                    # consumer_sgv might have multiple inputs, but we only care
                    # about replacing the input that is read_t
                    consumer_sgv = ge.sgv(consumer_ops)
                    consumer_sgv = consumer_sgv.remap_inputs([list(consumer_sgv.inputs).index(read_t)])
                    ge.connect(ge.sgv(var.read_value().op), consumer_sgv)
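

# Example usage of `replace_read_ops` (a sketch; `g_loss` is a hypothetical
# generator loss that depends on discriminator variables):
#
#     d_vars = tf.trainable_variables('discriminator')
#     # force `g_loss` to read the discriminator variables anew (e.g. after
#     # they have been updated) instead of reusing cached read tensors
#     replace_read_ops(g_loss, d_vars)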


def print_loss_info(losses, *tensors):
    # note: this walks up the graph through `tensor.op.inputs`, so it returns
    # the members of `tensors` that the given tensor depends on
    def get_descendants(tensor, tensors):
        descendants = []
        for child in tensor.op.inputs:
            if child in tensors:
                descendants.append(child)
            else:
                descendants.extend(get_descendants(child, tensors))
        return descendants

    name_to_tensors = itertools.chain(*[tensor_dict.items() for tensor_dict in tensors])
    tensor_to_names = OrderedDict([(v, k) for k, v in name_to_tensors])

    print(tf.get_default_graph().get_name_scope())
    for name, (loss, weight) in losses.items():
        print('  %s (%r)' % (name, weight))
        descendant_names = []
        for descendant in set(get_descendants(loss, tensor_to_names.keys())):
            descendant_names.append(tensor_to_names[descendant])
        for descendant_name in sorted(descendant_names):
            print('    %s' % descendant_name)


def with_flat_batch(flat_batch_fn, ndims=4):
    """Wraps `flat_batch_fn` (which expects `ndims`-D inputs) so that any extra
    leading dimensions are flattened into the batch dimension before the call
    and restored in the outputs afterwards."""
    def fn(x, *args, **kwargs):
        shape = tf.shape(x)
        flat_batch_shape = tf.concat([[-1], shape[-(ndims-1):]], axis=0)
        flat_batch_shape.set_shape([ndims])
        flat_batch_x = tf.reshape(x, flat_batch_shape)
        flat_batch_r = flat_batch_fn(flat_batch_x, *args, **kwargs)
        r = nest.map_structure(lambda x: tf.reshape(x, tf.concat([shape[:-(ndims-1)], tf.shape(x)[1:]], axis=0)),
                               flat_batch_r)
        return r
    return fn
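

# Example usage of `with_flat_batch` (a sketch; the shapes are hypothetical):
#
#     videos = tf.zeros([10, 8, 64, 64, 3])  # [time, batch, height, width, channels]
#     resize = with_flat_batch(lambda x: tf.image.resize_bilinear(x, [32, 32]))
#     resized = resize(videos)  # shape [10, 8, 32, 32, 3]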


def transpose_batch_time(x):
    if isinstance(x, tf.Tensor) and x.shape.ndims is not None and x.shape.ndims >= 2:
        return tf.transpose(x, [1, 0] + list(range(2, x.shape.ndims)))
    else:
        return x


def dimension(inputs, axis=0):
    """Returns the static size of dimension `axis`, merged across all tensors in
    `inputs`, or None if it is not statically known."""
    shapes = [input_.shape for input_ in nest.flatten(inputs)]
    s = tf.TensorShape([None])
    for shape in shapes:
        s = s.merge_with(shape[axis:axis + 1])
    dim = s[0].value
    return dim
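

# Example (a sketch): for a tensor of static shape [None, 8, 64, 64, 3],
# `dimension(x, axis=1)` returns 8 and `dimension(x, axis=0)` returns None.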


def unroll_rnn(cell, inputs, scope=None, use_dynamic_rnn=True):
    """Chooses between dynamic_rnn and static_rnn if the leading time dimension is dynamic or not."""
    dim = dimension(inputs, axis=0)
    if use_dynamic_rnn or dim is None:
        return tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32,
                                 swap_memory=False, time_major=True, scope=scope)
    else:
        return static_rnn(cell, inputs, scope=scope)


def static_rnn(cell, inputs, scope=None):
    """Simple version of static_rnn."""
    with tf.variable_scope(scope or "rnn") as varscope:
        batch_size = dimension(inputs, axis=1)
        state = cell.zero_state(batch_size, tf.float32)
        flat_inputs = nest.flatten(inputs)
        flat_inputs = list(zip(*[tf.unstack(flat_input, axis=0) for flat_input in flat_inputs]))
        flat_outputs = []
        for time, flat_input in enumerate(flat_inputs):
            if time > 0:
                varscope.reuse_variables()
            input_ = nest.pack_sequence_as(inputs, flat_input)
            output, state = cell(input_, state)
            flat_output = nest.flatten(output)
            flat_outputs.append(flat_output)
        flat_outputs = [tf.stack(flat_output, axis=0) for flat_output in zip(*flat_outputs)]
        outputs = nest.pack_sequence_as(output, flat_outputs)
        return outputs, state
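

# Example usage of `unroll_rnn` (a sketch; the shapes are hypothetical):
#
#     cell = tf.nn.rnn_cell.BasicLSTMCell(128)
#     inputs = tf.zeros([10, 8, 32])  # time-major: [time, batch, features]
#     outputs, final_state = unroll_rnn(cell, inputs)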


def maybe_pad_or_slice(tensor, desired_length):
    length = tensor.shape.as_list()[0]
    if length < desired_length:
        paddings = [[0, desired_length - length]] + [[0, 0]] * (tensor.shape.ndims - 1)
        tensor = tf.pad(tensor, paddings)
    elif length > desired_length:
        tensor = tensor[:desired_length]
    assert tensor.shape.as_list()[0] == desired_length
    return tensor
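

# Example (a sketch): pad a length-5 tensor with zeros to length 8, or slice
# it down to length 3:
#
#     tensor = tf.zeros([5, 64, 64, 3])
#     padded = maybe_pad_or_slice(tensor, 8)  # shape [8, 64, 64, 3]
#     sliced = maybe_pad_or_slice(tensor, 3)  # shape [3, 64, 64, 3]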


def tensor_to_clip(tensor):
    if tensor.shape.ndims == 6:
        # concatenate last dimension vertically
        tensor = tf.concat(tf.unstack(tensor, axis=-1), axis=-3)
    if tensor.shape.ndims == 5:
        # concatenate batch dimension horizontally
        tensor = tf.concat(tf.unstack(tensor, axis=0), axis=2)
    if tensor.shape.ndims == 4:
        # keep up to the first 3 channels
        tensor = tensor[..., :3]
        tensor = tf.image.convert_image_dtype(tensor, dtype=tf.uint8, saturate=True)
    else:
        raise NotImplementedError
    return tensor


def tensor_to_image_batch(tensor):
    if tensor.shape.ndims == 6:
        # concatenate last dimension vertically
        tensor = tf.concat(tf.unstack(tensor, axis=-1), axis=-3)
    if tensor.shape.ndims == 5:
        # concatenate time dimension horizontally
        tensor = tf.concat(tf.unstack(tensor, axis=1), axis=2)
    if tensor.shape.ndims == 4:
        # keep up to the first 3 channels
        tensor = tensor[..., :3]
        tensor = tf.image.convert_image_dtype(tensor, dtype=tf.uint8, saturate=True)
    else:
        raise NotImplementedError
    return tensor
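

# Note on the two functions above (a sketch of the shape transformations):
# given a [batch, time, height, width, channels] tensor, `tensor_to_clip`
# tiles the batch horizontally and keeps time as the clip dimension, yielding
# [time, height, batch * width, channels], whereas `tensor_to_image_batch`
# tiles time horizontally, yielding [batch, height, time * width, channels].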


def _as_name_scope_map(values):
    name_scope_to_values = {}
    for name, value in values.items():
        name_scope = name.split('/')[0]
        name_scope_to_values.setdefault(name_scope, {})
        name_scope_to_values[name_scope][name] = value
    return name_scope_to_values
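

# Example (a sketch): `_as_name_scope_map` groups names by their first scope:
#
#     _as_name_scope_map({'gen/l1': a, 'gen/gan': b, 'disc/gan': c})
#     # -> {'gen': {'gen/l1': a, 'gen/gan': b}, 'disc': {'disc/gan': c}}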


def add_image_summaries(outputs, max_outputs=8, collections=None):
    if collections is None:
        collections = [tf.GraphKeys.SUMMARIES, IMAGE_SUMMARIES]
    for name_scope, outputs in _as_name_scope_map(outputs).items():
        with tf.name_scope(name_scope):
            for name, output in outputs.items():
                if max_outputs:
                    output = output[:max_outputs]
                output = tensor_to_image_batch(output)
                if output.shape[-1] not in (1, 3):
                    # these are feature maps, so just skip them
                    continue
                tf.summary.image(name, output, collections=collections)


def add_gif_summaries(outputs, max_outputs=8, collections=None):
    if collections is None:
        collections = [tf.GraphKeys.SUMMARIES, IMAGE_SUMMARIES]
    for name_scope, outputs in _as_name_scope_map(outputs).items():
        with tf.name_scope(name_scope):
            for name, output in outputs.items():
                if max_outputs:
                    output = output[:max_outputs]
                output = tensor_to_clip(output)
                if output.shape[-1] not in (1, 3):
                    # these are feature maps, so just skip them
                    continue
                gif_summary.gif_summary(name, output[None], fps=4, collections=collections)


def add_scalar_summaries(losses_or_metrics, collections=None):
    for name_scope, losses_or_metrics in _as_name_scope_map(losses_or_metrics).items():
        with tf.name_scope(name_scope):
            for name, loss_or_metric in losses_or_metrics.items():
                if isinstance(loss_or_metric, tuple):
                    loss_or_metric, _ = loss_or_metric
                tf.summary.scalar(name, loss_or_metric, collections=collections)


def add_summaries(outputs, collections=None):
    scalar_outputs = OrderedDict()
    image_outputs = OrderedDict()
    gif_outputs = OrderedDict()
    for name, output in outputs.items():
        if not isinstance(output, tf.Tensor):
            continue
        if output.shape.ndims == 0:
            scalar_outputs[name] = output
        elif output.shape.ndims == 4:
            image_outputs[name] = output
        elif output.shape.ndims > 4 and output.shape[4].value in (1, 3):
            gif_outputs[name] = output
    add_scalar_summaries(scalar_outputs, collections=collections)
    add_image_summaries(image_outputs, collections=collections)
    add_gif_summaries(gif_outputs, collections=collections)
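

# Example usage of `add_summaries` (a sketch; the output tensors are hypothetical):
#
#     outputs = OrderedDict([('loss', scalar_tensor),        # ndims == 0
#                            ('images', image_tensor),       # ndims == 4
#                            ('gen_videos', video_tensor)])  # ndims == 5, 1 or 3 channels
#     add_summaries(outputs)
#     summary_op = tf.summary.merge_all()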


def plot_buf(y):
    def _plot_buf(y):
        from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
        from matplotlib.figure import Figure
        import io
        fig = Figure(figsize=(3, 3))
        canvas = FigureCanvas(fig)
        ax = fig.add_subplot(111)
        ax.plot(y)
        ax.grid(axis='y')
        fig.tight_layout(pad=0)

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        return buf.getvalue()

    s = tf.py_func(_plot_buf, [y], tf.string)
    return s


def add_plot_image_summaries(metrics, collections=None):
    if collections is None:
        collections = [IMAGE_SUMMARIES]
    for name_scope, metrics in _as_name_scope_map(metrics).items():
        with tf.name_scope(name_scope):
            for name, metric in metrics.items():
                try:
                    buf = plot_buf(metric)
                except Exception:
                    continue
                image = tf.image.decode_png(buf, channels=4)
                image = tf.expand_dims(image, axis=0)
                tf.summary.image(name, image, max_outputs=1, collections=collections)


def plot_summary(name, x, y, display_name=None, description=None, collections=None):
    """
    Hack that uses pr_curve summaries for 2D plots.

    Args:
        x: 1-D tensor with values in increasing order.
        y: 1-D tensor with static shape.

    Note: tensorboard needs to be modified and compiled from source to disable
    default axis range [-0.05, 1.05].
    """
    from tensorboard import summary as summary_lib
    x = tf.convert_to_tensor(x)
    y = tf.convert_to_tensor(y)
    with tf.control_dependencies([
        tf.assert_equal(tf.shape(x), tf.shape(y)),
        tf.assert_equal(y.shape.ndims, 1),
    ]):
        y = tf.identity(y)
    num_thresholds = y.shape[0].value
    if num_thresholds is None:
        raise ValueError('Size of y needs to be statically defined for num_thresholds argument')
    summary = summary_lib.pr_curve_raw_data_op(
        name,
        true_positive_counts=tf.ones(num_thresholds),
        false_positive_counts=tf.ones(num_thresholds),
        true_negative_counts=tf.ones(num_thresholds),
        false_negative_counts=tf.ones(num_thresholds),
        precision=y[::-1],
        recall=x[::-1],
        num_thresholds=num_thresholds,
        display_name=display_name,
        description=description,
        collections=collections)
    return summary
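

# Example usage of `plot_summary` (a sketch; the metric values are hypothetical):
#
#     psnr_per_step = tf.constant([28.0, 27.5, 27.1, 26.8])
#     plot_summary('psnr', tf.linspace(0., 3., 4), psnr_per_step)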


def add_plot_summaries(metrics, x_offset=0, collections=None):
    for name_scope, metrics in _as_name_scope_map(metrics).items():
        with tf.name_scope(name_scope):
            for name, metric in metrics.items():
                plot_summary(name, x_offset + tf.range(tf.shape(metric)[0]), metric, collections=collections)


def add_plot_and_scalar_summaries(metrics, x_offset=0, collections=None):
    for name_scope, metrics in _as_name_scope_map(metrics).items():
        with tf.name_scope(name_scope):
            for name, metric in metrics.items():
                tf.summary.scalar(name, tf.reduce_mean(metric), collections=collections)
                plot_summary(name, x_offset + tf.range(tf.shape(metric)[0]), metric, collections=collections)


def convert_tensor_to_gif_summary(summ):
    if isinstance(summ, bytes):
        summary_proto = tf.Summary()
        summary_proto.ParseFromString(summ)
        summ = summary_proto

    summary = tf.Summary()
    for value in summ.value:
        tag = value.tag
        try:
            images_arr = tf.make_ndarray(value.tensor)
        except TypeError:
            summary.value.add(tag=tag, image=value.image)
            continue

        if len(images_arr.shape) == 5:
            images_arr = np.concatenate(list(images_arr), axis=-2)
        if len(images_arr.shape) != 4:
            raise ValueError('Tensors must be 4-D or 5-D for gif summary.')
        channels = images_arr.shape[-1]
        if channels < 1 or channels > 4:
            raise ValueError('Tensors must have 1, 2, 3, or 4 color channels for gif summary.')

        encoded_image_string = ffmpeg_gif.encode_gif(images_arr, fps=4)

        image = tf.Summary.Image()
        image.height = images_arr.shape[-3]
        image.width = images_arr.shape[-2]
        image.colorspace = channels  # 1: grayscale, 2: grayscale + alpha, 3: RGB, 4: RGBA
        image.encoded_image_string = encoded_image_string
        summary.value.add(tag=tag, image=image)
    return summary
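

# Example usage of `convert_tensor_to_gif_summary` (a sketch; `gif_summary_op`
# and `summary_writer` are hypothetical):
#
#     summ = sess.run(gif_summary_op)
#     summary_writer.add_summary(convert_tensor_to_gif_summary(summ), global_step)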


def compute_averaged_gradients(opt, tower_loss, **kwargs):
    tower_gradvars = []
    for loss in tower_loss:
        with tf.device(loss.device):
            gradvars = opt.compute_gradients(loss, **kwargs)
            tower_gradvars.append(gradvars)

    # Now compute global loss and gradients.
    gradvars = []
    with tf.name_scope('gradient_averaging'):
        all_grads = {}
        for grad, var in itertools.chain(*tower_gradvars):
            if grad is not None:
                all_grads.setdefault(var, []).append(grad)
        for var, grads in all_grads.items():
            # Average gradients on the same device as the variables
            # to which they apply.
            with tf.device(var.device):
                if len(grads) == 1:
                    avg_grad = grads[0]
                else:
                    avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
            gradvars.append((avg_grad, var))
    return gradvars
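

# Example usage of `compute_averaged_gradients` (a sketch; `tower_losses` is a
# hypothetical list of per-GPU losses):
#
#     opt = tf.train.AdamOptimizer(1e-4)
#     gradvars = compute_averaged_gradients(opt, tower_losses)
#     train_op = opt.apply_gradients(gradvars)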


# The next 3 functions are from tensorpack:
# https://github.com/tensorpack/tensorpack/blob/master/tensorpack/graph_builder/utils.py
def split_grad_list(grad_list):
    """
    Args:
        grad_list: K x N x 2 list of lists of (grad, var) pairs, one list per device.

    Returns:
        K x N: gradients
        K x N: variables
    """
    g = []
    v = []
    for tower in grad_list:
        g.append([x[0] for x in tower])
        v.append([x[1] for x in tower])
    return g, v


def merge_grad_list(all_grads, all_vars):
    """
    Args:
        all_grads (K x N): gradients
        all_vars (K x N): variables

    Returns:
        K x N x 2: list of list of (grad, var) pairs
    """
    return [list(zip(gs, vs)) for gs, vs in zip(all_grads, all_vars)]


def allreduce_grads(all_grads, average):
    """
    All-reduce the gradients across K devices. The results are broadcast to all devices.

    Args:
        all_grads (K x N): List of list of gradients. N is the number of variables.
        average (bool): average gradients or not.

    Returns:
        K x N: same as input, but each grad is replaced by the average over K devices.
    """
    from tensorflow.contrib import nccl
    nr_tower = len(all_grads)
    if nr_tower == 1:
        return all_grads
    new_all_grads = []  # N x K
    for grads in zip(*all_grads):
        summed = nccl.all_sum(grads)

        grads_for_devices = []  # K
        for g in summed:
            with tf.device(g.device):
                # tensorflow/benchmarks didn't average gradients
                if average:
                    g = tf.multiply(g, 1.0 / nr_tower)
            grads_for_devices.append(g)
        new_all_grads.append(grads_for_devices)

    # transpose to K x N
    ret = list(zip(*new_all_grads))
    return ret
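

# Example usage of the three tensorpack helpers above (a sketch; `tower_gradvars`
# is a hypothetical K x N x 2 list of (grad, var) pairs, one list per device):
#
#     all_grads, all_vars = split_grad_list(tower_gradvars)
#     all_grads = allreduce_grads(all_grads, average=True)
#     tower_gradvars = merge_grad_list(all_grads, all_vars)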


def _reduce_entries(*entries):
    num_gpus = len(entries)
    if entries[0] is None:
        assert all(entry is None for entry in entries[1:])
        reduced_entry = None
    elif isinstance(entries[0], tf.Tensor):
        if entries[0].shape.ndims == 0:
            reduced_entry = tf.add_n(entries) / tf.to_float(num_gpus)
        else:
            reduced_entry = tf.concat(entries, axis=0)
    elif np.isscalar(entries[0]) or isinstance(entries[0], np.ndarray):
        if np.isscalar(entries[0]) or entries[0].ndim == 0:
            reduced_entry = sum(entries) / float(num_gpus)
        else:
            reduced_entry = np.concatenate(entries, axis=0)
    elif isinstance(entries[0], tuple) and len(entries[0]) == 2:
        losses, weights = zip(*entries)
        loss = tf.add_n(losses) / tf.to_float(num_gpus)
        if isinstance(weights[0], tf.Tensor):
            with tf.control_dependencies([tf.assert_equal(weight, weights[0]) for weight in weights[1:]]):
                weight = tf.identity(weights[0])
        else:
            assert all(weight == weights[0] for weight in weights[1:])
            weight = weights[0]
        reduced_entry = (loss, weight)
    else:
        raise NotImplementedError
    return reduced_entry


def reduce_tensors(structures, shallow=False):
    if len(structures) == 1:
        reduced_structure = structures[0]
    else:
        if shallow:
            if isinstance(structures[0], dict):
                shallow_tree = type(structures[0])([(k, None) for k in structures[0]])
            else:
                shallow_tree = type(structures[0])([None for _ in structures[0]])
            reduced_structure = nest.map_structure_up_to(shallow_tree, _reduce_entries, *structures)
        else:
            reduced_structure = nest.map_structure(_reduce_entries, *structures)
    return reduced_structure
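

# Example usage of `reduce_tensors` (a sketch; the per-GPU output dicts are
# hypothetical):
#
#     tower_outputs = [{'loss': loss_gpu0, 'frames': frames_gpu0},
#                      {'loss': loss_gpu1, 'frames': frames_gpu1}]
#     # scalars are averaged; non-scalar tensors are concatenated along axis 0
#     outputs = reduce_tensors(tower_outputs)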


def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=False, restore_to_checkpoint_mapping=None):
    if os.path.isdir(checkpoint):
        # latest_checkpoint doesn't work when the path has special characters
        checkpoint = tf.train.latest_checkpoint(checkpoint)
    checkpoint_reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint)
    checkpoint_var_names = checkpoint_reader.get_variable_to_shape_map().keys()
    restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name, _: name.split(':')[0])
    if not var_list:
        var_list = tf.global_variables()
    restore_vars = {restore_to_checkpoint_mapping(var.name, checkpoint_var_names): var for var in var_list}
    if skip_global_step and 'global_step' in restore_vars:
        del restore_vars['global_step']
    # restore variables that are both in the global graph and in the checkpoint
    restore_and_checkpoint_vars = {name: var for name, var in restore_vars.items() if name in checkpoint_var_names}
    # print out information regarding variables that were not restored or used for restoring
    restore_not_in_checkpoint_vars = {name: var for name, var in restore_vars.items() if
                                      name not in checkpoint_var_names}
    checkpoint_not_in_restore_var_names = [name for name in checkpoint_var_names if name not in restore_vars]
    if skip_global_step and 'global_step' in checkpoint_not_in_restore_var_names:
        checkpoint_not_in_restore_var_names.remove('global_step')
    if restore_not_in_checkpoint_vars:
        print("global variables that were not restored because they are "
              "not in the checkpoint:")
        for name, _ in sorted(restore_not_in_checkpoint_vars.items()):
            print("    ", name)
    if checkpoint_not_in_restore_var_names:
        print("checkpoint variables that were not used for restoring "
              "because they are not in the graph:")
        for name in sorted(checkpoint_not_in_restore_var_names):
            print("    ", name)


    restore_saver = tf.train.Saver(max_to_keep=1, var_list=restore_and_checkpoint_vars, filename=checkpoint)

    return restore_saver, checkpoint
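

# Example usage of `get_checkpoint_restore_saver` (a sketch; the checkpoint
# directory is hypothetical):
#
#     restore_saver, checkpoint = get_checkpoint_restore_saver('logs/model')
#     with tf.Session() as sess:
#         restore_saver.restore(sess, checkpoint)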


def pixel_distribution(pos, height, width):
    """Returns a batch of (height, width) maps that bilinearly distribute a unit
    of mass around each (y, x) position in `pos`."""
    batch_size = pos.get_shape().as_list()[0]
    y, x = tf.unstack(pos, 2, axis=1)

    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1

    Ia = tf.reshape(tf.one_hot(y0 * width + x0, height * width), [batch_size, height, width])
    Ib = tf.reshape(tf.one_hot(y1 * width + x0, height * width), [batch_size, height, width])
    Ic = tf.reshape(tf.one_hot(y0 * width + x1, height * width), [batch_size, height, width])
    Id = tf.reshape(tf.one_hot(y1 * width + x1, height * width), [batch_size, height, width])

    x0_f = tf.cast(x0, 'float32')
    x1_f = tf.cast(x1, 'float32')
    y0_f = tf.cast(y0, 'float32')
    y1_f = tf.cast(y1, 'float32')
    wa = ((x1_f - x) * (y1_f - y))[:, None, None]
    wb = ((x1_f - x) * (y - y0_f))[:, None, None]
    wc = ((x - x0_f) * (y1_f - y))[:, None, None]
    wd = ((x - x0_f) * (y - y0_f))[:, None, None]

    return tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
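

# Example (a sketch): a position halfway between four pixels spreads its unit
# of mass equally among them:
#
#     pos = tf.constant([[0.5, 0.5]])  # one (y, x) position
#     dist = pixel_distribution(pos, height=2, width=2)  # every entry is 0.25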


def flow_to_rgb(flows):
    """The last axis should have dimension 2, for x and y values."""

    def cartesian_to_polar(x, y):
        magnitude = tf.sqrt(tf.square(x) + tf.square(y))
        angle = tf.atan2(y, x)
        return magnitude, angle

    mag, ang = cartesian_to_polar(*tf.unstack(flows, axis=-1))
    ang_normalized = (ang + np.pi) / (2 * np.pi)
    mag_min = tf.reduce_min(mag)
    mag_max = tf.reduce_max(mag)
    mag_normalized = (mag - mag_min) / (mag_max - mag_min)
    hsv = tf.stack([ang_normalized, tf.ones_like(ang), mag_normalized], axis=-1)
    rgb = tf.image.hsv_to_rgb(hsv)
    return rgb
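

# Example usage of `flow_to_rgb` (a sketch; the flow tensor is hypothetical):
#
#     flows = tf.random_normal([8, 64, 64, 2])  # (x, y) flow vectors
#     rgb = flow_to_rgb(flows)  # hue encodes direction, brightness encodes magnitude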