__author__ = 'Lukas Leufen'
__date__ = '2019-10-21'


import logging
import math

import keras
import keras.backend as K


def to_list(arg):
    """Wrap a non-list argument into a single-element list; lists are returned unchanged."""
    if not isinstance(arg, list):
        arg = [arg]
    return arg
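

# A minimal usage sketch (illustrative only, not part of the module's API):
if __name__ == "__main__":
    assert to_list(1) == [1]
    assert to_list([1, 2]) == [1, 2]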


class Loss:

    def l_p_loss(self, power):
        """Return an L_p loss function: the mean of |y_pred - y_true|^power over the last axis."""
        def loss(y_true, y_pred):
            return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1)
        return loss
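

# A minimal compile sketch, assuming a hypothetical one-layer regression model;
# with power=2 the loss reduces to the mean squared error.
if __name__ == "__main__":
    model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(10,))])
    model.compile(optimizer="sgd", loss=Loss().l_p_loss(2))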


class lrDecay(keras.callbacks.History):
    """Step decay of the learning rate: starting from base_lr, multiply by drop once every epochs_drop epochs."""

    def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
        super(lrDecay, self).__init__()

        self.lr = {'lr': []}
        self.base_lr = base_lr
        self.drop = drop
        self.epochs_drop = epochs_drop

    def on_epoch_begin(self, epoch: int, logs=None):
        # step decay: lr = base_lr * drop^floor((1 + epoch) / epochs_drop)
        current_lr = self.base_lr * math.pow(self.drop, math.floor((1 + epoch) / self.epochs_drop))
        K.set_value(self.model.optimizer.lr, current_lr)
        self.lr['lr'].append(current_lr)
        logging.info(f"Set learning rate to {current_lr}")
        return K.get_value(self.model.optimizer.lr)
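

# A minimal training sketch with hypothetical toy data: using the defaults
# (base_lr=0.01, drop=0.96, epochs_drop=8), the learning rate is multiplied
# by 0.96 once every 8 epochs.
if __name__ == "__main__":
    import numpy as np
    model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(10,))])
    model.compile(optimizer=keras.optimizers.SGD(lr=0.01), loss="mse")
    decay = lrDecay(base_lr=0.01, drop=0.96, epochs_drop=8)
    model.fit(np.random.rand(32, 10), np.random.rand(32, 1), epochs=2,
              callbacks=[decay], verbose=0)
    print(decay.lr["lr"])  # one learning rate entry per epoch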


class lrCallback(keras.callbacks.History):
    """Record the optimizer's learning rate at the end of every epoch."""

    def __init__(self):
        super(lrCallback, self).__init__()
        self.lr = None

    def on_train_begin(self, logs=None):
        self.lr = []

    def on_epoch_end(self, epoch, logs=None):
        # read the current value of the learning rate variable, not the tensor itself
        self.lr.append(K.get_value(self.model.optimizer.lr))
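

# A minimal tracking sketch with hypothetical toy data; after training,
# the callback's lr attribute holds one learning rate value per epoch.
if __name__ == "__main__":
    import numpy as np
    model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(10,))])
    model.compile(optimizer="sgd", loss="mse")
    lr_history = lrCallback()
    model.fit(np.random.rand(32, 10), np.random.rand(32, 1), epochs=2,
              callbacks=[lr_history], verbose=0)
    print(lr_history.lr)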