"""Run module that sets up the model: build, plot, optional weight reload, callbacks and compilation."""

__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-12-02'


import logging
import os

import keras
import tensorflow as tf

from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
# from src.model_modules.model_class import MyBranchedModel as MyModel
from src.model_modules.model_class import MyLittleModel as MyModel
# from src.model_modules.model_class import MyTowerModel as MyModel
# from src.model_modules.model_class import MyPaperModel as MyModel
from src.run_modules.run_environment import RunEnvironment


class ModelSetup(RunEnvironment):
    """
    Run module that prepares the model for the following run stages.

    On instantiation the number of input channels is derived from the train generator, the model is built from
    the settings in the data store, its structure is plotted, stored weights are optionally reloaded, the training
    callbacks are created and the model is compiled. Channels, model settings, model name, history object,
    callbacks and the compiled model are written to the data store under the scope `self.scope`.
    """

    def __init__(self):
        """Read paths and flags from the data store, prepare all file name templates and start the setup."""

        # create run framework
        super().__init__()
        self.model = None
        path = self.data_store.get("experiment_path", "general")
        exp_name = self.data_store.get("experiment_name", "general")
        self.scope = "general.model"
        # template "<experiment_path>/<experiment_name>_%s"; the placeholder is filled per output file below
        self.path = os.path.join(path, f"{exp_name}_%s")
        # model_name still contains a "%s" placeholder here; it is resolved in get_model_settings()
        self.model_name = self.path % "%s.h5"
        self.checkpoint_name = self.path % "model-best.h5"
        # callbacks_name keeps a "%s" placeholder that is filled with the callback name in _set_callbacks()
        self.callbacks_name = self.path % "model-best-callbacks-%s.pickle"
        self._trainable = self.data_store.get("trainable", "general")
        self._create_new_model = self.data_store.get("create_new_model", "general")
        self._run()

    def _run(self):
        """Execute all setup steps in the required order."""

        # set channels depending on inputs
        self._set_channels()

        # build model graph using settings from my_model_settings()
        self.build_model()

        # plot model structure
        self.plot_model()

        # load weights if no training shall be performed
        if not self._trainable and not self._create_new_model:
            self.load_weights()

        # create checkpoint
        self._set_callbacks()

        # compile model
        self.compile_model()

    def _set_channels(self):
        """Store the size of the last axis of the first train generator input as `channels` in the data store."""
        channels = self.data_store.get("generator", "general.train")[0][0].shape[-1]
        self.data_store.set("channels", channels, self.scope)

    def compile_model(self):
        """
        Compile the model and store it in the data store.

        The optimizer is read from the data store, the loss is taken from the model's `loss` attribute and "mse" and
        "mae" are used as metrics.
        """
        optimizer = self.data_store.get("optimizer", self.scope)
        loss = self.model.loss
        self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"])
        self.data_store.set("model", self.model, self.scope)

    def _set_callbacks(self):
        """
        Set all callbacks for the training phase. Add all callbacks with the .add_callback statement. Finally, the
        model checkpoint is created (monitor='val_loss', save_best_only=True) and the callback handler is stored in
        the data store.

        All data store accesses use `self.scope`, so the learning rate callback, the history object and the callback
        handler always live in the same scope.
        """
        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
        hist = HistoryAdvanced()
        self.data_store.set("hist", hist, scope=self.scope)
        callbacks = CallbackHandler()
        # the learning rate callback is optional and only registered if the data store provides one
        if lr:
            callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                          save_best_only=True, mode='auto')
        self.data_store.set("callbacks", callbacks, self.scope)

    def load_weights(self):
        """
        Try to load weights from the file `self.model_name` into the model.

        An OSError raised while loading is not propagated; only an info message is logged in this case.
        """
        try:
            self.model.load_weights(self.model_name)
        except OSError:
            logging.info('no weights to reload...')
        else:
            # only reached if loading did not raise
            logging.info("reload weights from model %s ...", self.model_name)

    def build_model(self):
        """Create the model from `window_history_size`, `window_lead_time` and `channels` and publish its settings."""
        args_list = ["window_history_size", "window_lead_time", "channels"]
        args = self.data_store.create_args_dict(args_list, self.scope)
        self.model = MyModel(**args)
        self.get_model_settings()

    def get_model_settings(self):
        """
        Copy the model's settings into the data store and resolve the final model file name.

        The "%s" placeholder in `self.model_name` is replaced by the `model_name` entry of the data store (default
        "my_model"); the resulting file name is written back to the data store as `model_name`.
        """
        model_settings = self.model.get_settings()
        self.data_store.set_args_from_dict(model_settings, self.scope)
        self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model")
        self.data_store.set("model_name", self.model_name, self.scope)

    def plot_model(self):  # pragma: no cover
        """Plot the model structure to a pdf file named like the model file (extension replaced by ".pdf")."""
        with tf.device("/cpu:0"):
            file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
            keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)