__author__ = 'Lukas Leufen, Felix Kleinert'
__date__ = '2020-01-31'

import logging
import math
import pickle
from typing import Union

import numpy as np
from keras import backend as K
from keras.callbacks import History, ModelCheckpoint

from src import helpers


class HistoryAdvanced(History):
    """
    This is almost an identical clone of the original History class. The only difference is that attributes epoch and
    history are instantiated during the init phase and not during on_train_begin. This is required to resume an already
    started but disrupted training from an saved state. This HistoryAdvanced callback needs to be added separately as
    additional callback. To get the full history use this object for further steps instead of the default return of
    training methods like fit_generator().

        hist = HistoryAdvanced()
        history = model.fit_generator(generator=.... , callbacks=[hist])
        history = hist

    If training was started from beginning this class is identical to the returned history class object.
    """

    def __init__(self):
        self.epoch = []
        self.history = {}
        super().__init__()

    def on_train_begin(self, logs=None):
        # In contrast to the parent class, do not reset epoch and history here, so that a resumed training continues
        # with the state set in __init__ (or restored from a pickle).
        pass
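

# A minimal resume sketch (not part of the original module): the arguments model, generator and history_path are
# assumptions for illustration only. It restores a pickled HistoryAdvanced object and continues training with the
# full history preserved.
def _resume_history_sketch(model, generator, history_path="hist.pickle"):
    with open(history_path, "rb") as f:
        hist = pickle.load(f)
    # epoch and history survive on_train_begin, so the record of the previous run is extended, not overwritten
    model.fit_generator(generator=generator, epochs=10, initial_epoch=len(hist.epoch), callbacks=[hist])
    return hist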


class LearningRateDecay(History):
    """
    Decay learning rate during model training. Start with a base learning rate and lower this rate after every
    n(=epochs_drop) epochs by drop value (0, 1], drop value = 1 means no decay in learning rate.
    """

    def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
        super().__init__()
        self.lr = {'lr': []}
        self.base_lr = self.check_param(base_lr, 'base_lr')
        self.drop = self.check_param(drop, 'drop')
        self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
        self.epoch = []
        self.history = {}

    @staticmethod
    def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
        """
        Check if given value is in interval. The left (lower) endpoint is open, right (upper) endpoint is closed. To
        only one side of the interval, set the other endpoint to None. If both ends are set to None, just return the
        value without any check.
        :param value: value to check
        :param name: name of the variable to display in error message
        :param lower: left (lower) endpoint of interval, opened
        :param upper: right (upper) endpoint of interval, closed
        :return: unchanged value or raise ValueError
        """
        if lower is None:
            lower = -np.inf
        if upper is None:
            upper = np.inf
        if lower < value <= upper:
            return value
        else:
            raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
                             f"{name}={value}")

    def on_train_begin(self, logs=None):
        # In contrast to the parent class, do not reset epoch and history, so that the learning rate record survives
        # a resumed training.
        pass

    def on_epoch_begin(self, epoch: int, logs=None):
        """
        Lower learning rate every epochs_drop epochs by factor drop.
        :param epoch: current epoch
        :param logs: ?
        :return: update keras learning rate
        """
        current_lr = self.base_lr * math.pow(self.drop, math.floor(epoch / self.epochs_drop))
        K.set_value(self.model.optimizer.lr, current_lr)
        self.lr['lr'].append(current_lr)
        logging.info(f"Set learning rate to {current_lr}")
        return K.get_value(self.model.optimizer.lr)
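

# A minimal sketch (not part of the original module) of the step-decay schedule computed in on_epoch_begin:
# lr(epoch) = base_lr * drop ** floor(epoch / epochs_drop). With the default parameters this yields 0.01 for
# epochs 0-7, 0.0096 for epochs 8-15, and so on.
def _decay_schedule_sketch(base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8, epochs: int = 24):
    return [base_lr * math.pow(drop, math.floor(epoch / epochs_drop)) for epoch in range(epochs)]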


class ModelCheckpointAdvanced(ModelCheckpoint):
    """
    Enhance the standard ModelCheckpoint class by additional saves of given callbacks. Specify this callbacks as follow:

        lr = CustomLearningRate()
        hist = CustomHistory()
        callbacks_name = "your_custom_path_%s.pickle"
        callbacks = [{"callback": lr, "path": callbacks_name % "lr"},
                 {"callback": hist, "path": callbacks_name % "hist"}]
        ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks)

    Add this ckpt_callbacks as all other additional callbacks to the callback list. IMPORTANT: Always add ckpt_callbacks
    as last callback to properly update all tracked callbacks, e.g.

        fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks])

    """
    def __init__(self, *args, **kwargs):
        self.callbacks = kwargs.pop("callbacks")
        super().__init__(*args, **kwargs)

    def update_best(self, hist):
        """
        Update internal best on resuming a training process. Otherwise best is set to +/- inf depending on the
        performance metric and the first trained model (first of the resuming training process) will always saved as
        best model because its performance will be better than infinity. To prevent this behaviour and compare the
        performance with the best model performance, call this method before resuming the training process.
        :param hist: The History object from the previous (interrupted) training.
        """
        self.best = hist.history.get(self.monitor)[-1]

    def update_callbacks(self, callbacks):
        """
        Update all stored callback objects. The argument callbacks needs to follow the same convention like described
        in the class description (list of dictionaries). Must be run before resuming a training process.
        """
        self.callbacks = helpers.to_list(callbacks)

    def on_epoch_end(self, epoch, logs=None):
        """
        Save model as usual (see ModelCheckpoint class), but also save additional callbacks.
        """
        super().on_epoch_end(epoch, logs)

        for callback in self.callbacks:
            file_path = callback["path"]
            if self.epochs_since_last_save == 0 and epoch != 0:
                if self.save_best_only:
                    current = logs.get(self.monitor)
                    # the parent class has already updated self.best in super().on_epoch_end, so equality with the
                    # current value means this epoch set a new best
                    if current == self.best:
                        if self.verbose > 0:  # pragma: no branch
                            print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                        with open(file_path, "wb") as f:
                            pickle.dump(callback["callback"], f)
                else:
                    with open(file_path, "wb") as f:
                        if self.verbose > 0:  # pragma: no branch
                            print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                        pickle.dump(callback["callback"], f)