__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-12-02'


import logging
import os

import keras
import tensorflow as tf
from keras import losses

from src.helpers import l_p_loss
from src.model_modules.flatten import flatten_tail
from src.model_modules.inception_model import InceptionModelBase
from src.model_modules.keras_extensions import HistoryAdvanced, ModelCheckpointAdvanced
# from src.model_modules.model_class import MyBranchedModel as MyModel
from src.model_modules.model_class import MyLittleModel as MyModel
from src.run_modules.run_environment import RunEnvironment


class ModelSetup(RunEnvironment):
    """
    Run stage that prepares the model of an experiment.

    On construction it
      1. derives the number of input channels from the training generator,
      2. builds the model (``MyModel``) and stores its settings in the data store,
      3. plots the model graph to a pdf file,
      4. reloads stored weights if neither training nor a new model is requested,
      5. creates the checkpoint callback (including the advanced history and the learning rate decay callback), and
      6. compiles the model.

    All values set by this stage are stored in the data store under the scope ``general.model`` (``self.scope``).
    """

    def __init__(self):

        # create run framework
        super().__init__()
        self.model = None
        path = self.data_store.get("experiment_path", "general")
        exp_name = self.data_store.get("experiment_name", "general")
        self.scope = "general.model"
        # file name template "<experiment_path>/<experiment_name>_%s", filled differently for each artefact below
        self.path = os.path.join(path, f"{exp_name}_%s")
        # still contains a "%s" placeholder for the model name; it is filled in get_model_settings()
        self.model_name = self.path % "%s.h5"
        self.checkpoint_name = self.path % "model-best.h5"
        # "%s" is filled with the short name of each tracked callback (e.g. "lr", "hist")
        self.callbacks_name = self.path % "model-best-callbacks-%s.pickle"
        self._trainable = self.data_store.get("trainable", "general")
        self._create_new_model = self.data_store.get("create_new_model", "general")
        self._run()

    def _run(self):
        """Execute all setup steps in the required order."""

        # set channels depending on inputs
        self._set_channels()

        # build model graph using settings from my_model_settings()
        self.build_model()

        # plot model structure
        self.plot_model()

        # load weights if no training shall be performed
        if not self._trainable and not self._create_new_model:
            self.load_weights()

        # create checkpoint
        self._set_checkpoint()

        # compile model
        self.compile_model()

    def _set_channels(self):
        """
        Store the number of input channels (size of the last input dimension) in the data store.

        NOTE(review): assumes ``generator[0][0]`` is the input array of the first training sample/batch — confirm
        against the generator implementation.
        """
        channels = self.data_store.get("generator", "general.train")[0][0].shape[-1]
        self.data_store.set("channels", channels, self.scope)

    def compile_model(self):
        """Compile the model with the stored optimizer and the model's own loss, then store the model."""
        optimizer = self.data_store.get("optimizer", self.scope)
        loss = self.model.loss
        self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"])
        self.data_store.set("model", self.model, self.scope)

    def _set_checkpoint(self):
        """
        Must be run after all callback functions that shall be tracked during training have been created (currently this
        affects the learning rate decay and the advanced history [actually created in this method]).
        """
        lr = self.data_store.get("lr_decay", scope=self.scope)
        hist = HistoryAdvanced()
        self.data_store.set("hist", hist, scope=self.scope)
        # each tracked callback is stored next to the best model under its own pickle file
        callbacks = [{"callback": lr, "path": self.callbacks_name % "lr"},
                     {"callback": hist, "path": self.callbacks_name % "hist"}]
        checkpoint = ModelCheckpointAdvanced(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                             save_best_only=True, mode='auto', callbacks=callbacks)
        self.data_store.set("checkpoint", checkpoint, self.scope)

    def load_weights(self):
        """
        Load weights from ``self.model_name`` into the model.

        An ``OSError`` (e.g. the weights file cannot be opened) is not fatal: it is only logged and the model keeps its
        freshly initialised weights.
        """
        try:
            self.model.load_weights(self.model_name)
            logging.info(f"reload weights from model {self.model_name} ...")
        except OSError:
            logging.info('no weights to reload...')

    def build_model(self):
        """Instantiate ``MyModel`` from the data store arguments and store the model's settings."""
        args_list = ["window_history_size", "window_lead_time", "channels"]
        args = self.data_store.create_args_dict(args_list, self.scope)
        self.model = MyModel(**args)
        self.get_model_settings()

    def get_model_settings(self):
        """
        Copy the model's settings into the data store and finalise ``self.model_name``.

        The "%s" placeholder in ``self.model_name`` is replaced by the stored ``model_name`` (default "my_model"), so
        this method must only be called once.
        """
        model_settings = self.model.get_settings()
        self.data_store.set_args_from_dict(model_settings, self.scope)
        self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model")
        self.data_store.set("model_name", self.model_name, self.scope)

    def plot_model(self):  # pragma: no cover
        """Plot the model graph to a pdf file located next to the model file (same stem, ``.pdf`` extension)."""
        with tf.device("/cpu:0"):
            # splitext only strips the final extension; splitting on "." would cut at the first dot anywhere in the
            # path (e.g. "./experiment/..." or "../...") and produce a wrong or empty file name
            file_name = f"{os.path.splitext(self.model_name)[0]}.pdf"
            keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)


def my_loss():
    """
    Return the losses of a two-output model as a list.

    The first entry is ``l_p_loss(4)`` (from ``src.helpers``), the second is keras' mean squared error. Presumably
    they correspond, in order, to the two outputs of ``my_model`` — confirm against the training setup.

    :return: list ``[l_p_loss(4), losses.mean_squared_error]``
    """
    return [l_p_loss(4), losses.mean_squared_error]


def my_little_loss():
    """
    Return the loss for a single-output model.

    :return: keras' ``losses.mean_squared_error``
    """
    mse = losses.mean_squared_error
    return mse


def my_little_model(activation, window_history_size, channels, regularizer, dropout_rate, window_lead_time):
    """
    Build a small single-output model: 1x1 convolution, flatten, dropout and a stack of dense layers.

    :param activation: activation layer class, instantiated after the convolution and after every dense layer
        (including the final output layer)
    :param window_history_size: number of past time steps; the input holds ``window_history_size + 1`` steps
    :param channels: number of input channels (last input dimension)
    :param regularizer: currently not used by this model
    :param dropout_rate: rate of the dropout layer after flattening
    :param window_lead_time: number of output units of the final dense layer
    :return: keras.Model with a single output ``[out_main]``
    """
    prefix = "major"

    # add 1 to window_size to include current time step t0
    inputs = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
    x = keras.layers.Conv2D(32, (1, 1), padding='same', name=f"{prefix}_Conv_1x1")(inputs)
    x = activation(name=f"{prefix}_conv_act")(x)
    x = keras.layers.Flatten(name=prefix)(x)
    x = keras.layers.Dropout(dropout_rate, name=f"{prefix}_Dropout_1")(x)

    # shrinking dense stack, each layer followed by an (unnamed) activation
    for units in (64, 32, 16):
        x = keras.layers.Dense(units, name=f"{prefix}_Dense_{units}")(x)
        x = activation()(x)

    x = keras.layers.Dense(window_lead_time, name=f"{prefix}_Dense")(x)
    out_main = activation()(x)
    return keras.Model(inputs=inputs, outputs=[out_main])


def my_model(activation, window_history_size, channels, regularizer, dropout_rate, window_lead_time):
    """
    Build a model of three stacked inception blocks with two outputs.

    A minor output ('Minor_1') branches off directly after the first inception block. The main output ('Main') follows
    the third block. Dropout is applied before the second and before the third block. Both output tails are built
    with ``flatten_tail``.

    :param activation: activation passed to all inception towers, pooling branches and flatten tails
    :param window_history_size: number of past time steps; the input holds ``window_history_size + 1`` steps
    :param channels: number of input channels (last input dimension)
    :param regularizer: regularizer passed to every inception block
    :param dropout_rate: rate for the dropout layers between blocks and inside the flatten tails
    :param window_lead_time: forwarded to both flatten tails
    :return: keras.Model with outputs ``[out_minor, out_main]``
    """

    def towers(reduction_filter, tower_filter):
        # three parallel towers (kernels (3, 1), (5, 1), (1, 1)) sharing filter sizes and activation
        kernels = {'tower_1': (3, 1), 'tower_2': (5, 1), 'tower_3': (1, 1)}
        return {name: {'reduction_filter': reduction_filter, 'tower_filter': tower_filter, 'tower_kernel': kernel,
                       'activation': activation}
                for name, kernel in kernels.items()}

    conv_settings_dict1 = towers(8, 16)
    pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation}

    conv_settings_dict2 = towers(16, 64)
    pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation}

    conv_settings_dict3 = towers(64, 64)
    pool_settings_dict3 = {'pool_kernel': (3, 1), 'tower_filter': 32, 'activation': activation}

    inception = InceptionModelBase()

    # add 1 to window_size to include current time step t0
    x_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))

    x = inception.inception_block(x_input, conv_settings_dict1, pool_settings_dict1, regularizer=regularizer,
                                  batch_normalisation=True)

    # the minor tail is attached to the first block before any dropout is applied
    out_minor = flatten_tail(x, 'Minor_1', bound_weight=True, activation=activation, dropout_rate=dropout_rate,
                             reduction_filter=4, first_dense=32, window_lead_time=window_lead_time)

    # remaining blocks: dropout followed by an inception block, in this order
    for conv_settings, pool_settings in ((conv_settings_dict2, pool_settings_dict2),
                                         (conv_settings_dict3, pool_settings_dict3)):
        x = keras.layers.Dropout(dropout_rate)(x)
        x = inception.inception_block(x, conv_settings, pool_settings, regularizer=regularizer,
                                      batch_normalisation=True)

    out_main = flatten_tail(x, 'Main', activation=activation, bound_weight=True, dropout_rate=dropout_rate,
                            reduction_filter=64, first_dense=64, window_lead_time=window_lead_time)

    return keras.Model(inputs=x_input, outputs=[out_minor, out_main])