__author__ = "Lukas Leufen"
__date__ = '2019-12-12'


from abc import ABC
from typing import Any, Callable

import keras

from src import helpers


class AbstractModelClass(ABC):

    """
    The AbstractModelClass provides a unified skeleton for any model provided to the machine learning workflow. The
    model can always be accessed by calling ModelClass.model or directly by a model method without passing the model
    attribute name (e.g. ModelClass.model.compile -> ModelClass.compile). Besides the model, this class provides the
    corresponding loss function.
    """

    def __init__(self) -> None:

        """
        Predefine internal attributes for model and loss.
        """

        self.__model = None
        self.__loss = None

    def __getattr__(self, name: str) -> Any:

        """
        Is called if __getattribute__ is not able to find requested attribute. Normally, the model class is saved into
        a variable like `model = ModelClass()`. To bypass a call like `model.model` to access the _model attribute,
        this method tries to search for the named attribute in the self.model namespace and returns this attribute if
        available. Therefore, following expression is true: `ModelClass().compile == ModelClass().model.compile` as long
        the called attribute/method is not part of the ModelClass itself.
        :param name: name of the attribute or method to call
        :return: attribute or method from self.model namespace
        :raises AttributeError: if the attribute is found neither on this class nor on the model, or if the internal
            model attribute has not been set yet (e.g. on an instance created via __new__ by copy/pickle before its
            state is restored)
        """

        # Check the backing field in the instance dict directly. If it is missing, accessing self.model would raise
        # AttributeError inside the property, re-enter __getattr__ and recurse until RecursionError. Raising a plain
        # AttributeError instead lets hasattr()/getattr(obj, name, default) (used by copy and pickle) work.
        if "_AbstractModelClass__model" not in self.__dict__:
            raise AttributeError("{!r} object has no attribute {!r}".format(type(self).__name__, name))
        return getattr(self.model, name)

    @property
    def model(self) -> "keras.Model":

        """
        The model property containing a keras.Model instance.
        :return: the keras model
        """

        return self.__model

    @model.setter
    def model(self, value) -> None:
        self.__model = value

    @property
    def loss(self) -> Callable:

        """
        The loss property containing a callable loss function (or a list of loss functions, one per model output, for
        models with multiple outputs). The loss function can be any keras loss or a customised function. If the loss
        is a customised function, it must contain the internal loss(y_true, y_pred) function:
            def customised_loss(args):
                def loss(y_true, y_pred):
                    return actual_function(y_true, y_pred, args)
                return loss
        :return: the loss function
        """

        return self.__loss

    @loss.setter
    def loss(self, value) -> None:
        self.__loss = value

    def get_settings(self) -> dict:

        """
        Collect the settings of this model class.
        :return: dictionary of all instance attributes except the internal model and loss storage (i.e. every
            attribute whose name does not start with the mangled prefix `_AbstractModelClass__`)
        """

        return {key: value for key, value in self.__dict__.items() if not key.startswith("_AbstractModelClass__")}


class MyLittleModel(AbstractModelClass):

    """
    A customised model with a 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the
    output layer depending on the window_lead_time parameter. Dropout is used between the Convolution and the first
    Dense layer. Each Conv/Dense layer (including the output layer) is followed by an activation layer.
    """

    def __init__(self, window_history_size, window_lead_time, channels):

        """
        Sets model and loss depending on the given arguments. All other hyperparameters (dropout rate, regularizer,
        learning rate, optimizer, learning rate decay, epochs, batch size and activation) are fixed in this method.
        :param window_history_size: number of historical time steps included in the input data
        :param window_lead_time: number of time steps to forecast in the output layer
        :param channels: number of variables used in input data
        """

        super().__init__()

        # settings
        self.window_history_size = window_history_size
        self.window_lead_time = window_lead_time
        self.channels = channels
        self.dropout_rate = 0.1
        self.regularizer = keras.regularizers.l2(0.1)  # stored as a setting only; not applied in set_model
        self.initial_lr = 1e-2
        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
        self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
        self.epochs = 2
        self.batch_size = 256
        self.activation = keras.layers.PReLU

        # apply to model
        self.set_model()
        self.set_loss()

    def set_model(self):

        """
        Build the model and store it in self.model. Uses the instance settings window_history_size, channels,
        dropout_rate, window_lead_time and activation.
        :return: None (the built keras model is assigned to self.model)
        """

        # add 1 to window_size to include current time step t0
        x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels))
        x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input)
        x_in = self.activation(name='{}_conv_act'.format("major"))(x_in)
        x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in)
        x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in)
        x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in)
        x_in = self.activation()(x_in)
        x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in)
        x_in = self.activation()(x_in)
        x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in)
        x_in = self.activation()(x_in)
        # output layer: one unit per forecast time step
        x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in)
        out_main = self.activation()(x_in)
        self.model = keras.Model(inputs=x_input, outputs=[out_main])

    def set_loss(self):

        """
        Set the loss (mean squared error) and store it in self.loss.
        :return: None (the loss function is assigned to self.loss)
        """

        self.loss = keras.losses.mean_squared_error


class MyBranchedModel(AbstractModelClass):

    """
    A customised branched model with a 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time) on the main branch,
    where the last layer is the main output layer depending on the window_lead_time parameter. Dropout is used between
    the Convolution and the first Dense layer. Two additional minor output layers (each with window_lead_time units)
    branch off after the Dense 64 layer (minor_1) and after the Dense 32 layer (minor_2). The model outputs are
    ordered [minor_1, minor_2, main].
    """

    def __init__(self, window_history_size, window_lead_time, channels):

        """
        Sets model and loss depending on the given arguments. All other hyperparameters (dropout rate, regularizer,
        learning rate, optimizer, learning rate decay, epochs, batch size and activation) are fixed in this method.
        :param window_history_size: number of historical time steps included in the input data
        :param window_lead_time: number of time steps to forecast in each output layer
        :param channels: number of variables used in input data
        """

        super().__init__()

        # settings
        self.window_history_size = window_history_size
        self.window_lead_time = window_lead_time
        self.channels = channels
        self.dropout_rate = 0.1
        self.regularizer = keras.regularizers.l2(0.1)  # stored as a setting only; not applied in set_model
        self.initial_lr = 1e-2
        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
        self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
        self.epochs = 2
        self.batch_size = 256
        self.activation = keras.layers.PReLU

        # apply to model
        self.set_model()
        self.set_loss()

    def set_model(self):

        """
        Build the branched model and store it in self.model. Uses the instance settings window_history_size,
        channels, dropout_rate, window_lead_time and activation.
        :return: None (the built keras model with outputs [minor_1, minor_2, main] is assigned to self.model)
        """

        # add 1 to window_size to include current time step t0
        x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels))
        x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input)
        x_in = self.activation(name='{}_conv_act'.format("major"))(x_in)
        x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in)
        x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in)
        x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in)
        x_in = self.activation()(x_in)
        # first minor output branches off after the Dense 64 block
        out_minor_1 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_1"))(x_in)
        out_minor_1 = self.activation()(out_minor_1)
        x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in)
        x_in = self.activation()(x_in)
        # second minor output branches off after the Dense 32 block
        out_minor_2 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_2"))(x_in)
        out_minor_2 = self.activation()(out_minor_2)
        x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in)
        x_in = self.activation()(x_in)
        # main output layer: one unit per forecast time step
        x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in)
        out_main = self.activation()(x_in)
        self.model = keras.Model(inputs=x_input, outputs=[out_minor_1, out_minor_2, out_main])

    def set_loss(self):

        """
        Set one loss per model output, in the output order [minor_1, minor_2, main]: mean absolute error for minor_1
        and mean squared error for minor_2 and the main output. The list is stored in self.loss.
        :return: None (the list of loss functions is assigned to self.loss)
        """

        self.loss = [keras.losses.mean_absolute_error,
                     keras.losses.mean_squared_error,
                     keras.losses.mean_squared_error]