"""Model setup module.

Provides :class:`ModelSetup`, a run stage that builds the model class imported as ``MyModel``, plots its
architecture, optionally reloads previously stored weights, registers the training callbacks and finally
compiles the model. All results are written to the ``model`` scope of the data store.
"""

__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-12-02'

import logging
import os

import keras
import tensorflow as tf

from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
# from src.model_modules.model_class import MyBranchedModel as MyModel
from src.model_modules.model_class import MyLittleModel as MyModel
# from src.model_modules.model_class import MyTowerModel as MyModel
# from src.model_modules.model_class import MyPaperModel as MyModel
from src.run_modules.run_environment import RunEnvironment


class ModelSetup(RunEnvironment):
    """
    Set up the model.

    Schedule of model setup:
        #. set channels (from variables dimension)
        #. build imported model
        #. plot model architecture
        #. load weights if enabled (e.g. to resume a training)
        #. set callbacks and checkpoint
        #. compile model

    Required objects [scope] from data store:
        * `experiment_path` [.]
        * `experiment_name` [.]
        * `trainable` [.]
        * `create_new_model` [.]
        * `generator` [train]
        * `window_lead_time` [.]
        * `window_history_size` [.]

    Optional objects:
        * `lr_decay` [model]

    Sets:
        * `channels` [model]
        * `model` [model]
        * `hist` [model]
        * `callbacks` [model]
        * `model_name` [model]
        * all settings from model class like `dropout_rate`, `initial_lr`, `batch_size`, and `optimizer` [model]

    Creates:
        * plot of model architecture in `<model_name>.pdf`

    """

    def __init__(self):
        """Initialise and run model setup."""
        super().__init__()
        self.model = None
        path = self.data_store.get("experiment_path")
        exp_name = self.data_store.get("experiment_name")
        self.scope = "model"
        # `path`, `model_name` and `callbacks_name` are %-format templates that are filled in later
        # (`model_name` in get_model_settings, `callbacks_name` in _set_callbacks). A literal "%" in the
        # experiment path or name would otherwise be interpreted as a format directive there, so escape it.
        # Without any "%" in path/name the resulting attributes are identical to the unescaped version.
        prefix = os.path.join(path, f"{exp_name}_").replace("%", "%%")
        self.path = f"{prefix}%s"
        self.model_name = f"{prefix}%s.h5"
        self.checkpoint_name = self.path % "model-best.h5"
        self.callbacks_name = f"{prefix}model-best-callbacks-%s.pickle"
        self._trainable = self.data_store.get("trainable")
        self._create_new_model = self.data_store.get("create_new_model")
        self._run()

    def _run(self):
        """Execute all setup steps in the order listed in the class docstring."""
        # set channels depending on inputs
        self._set_channels()

        # build model graph using settings from my_model_settings()
        self.build_model()

        # plot model structure
        self.plot_model()

        # load weights only if neither a training shall be performed nor a new model is requested
        if not self._trainable and not self._create_new_model:
            self.load_weights()

        # create checkpoint
        self._set_callbacks()

        # compile model
        self.compile_model()

    def _set_channels(self):
        """
        Set channels as number of variables of train generator.

        Uses the size of the last axis of the first element of the first train generator item.
        NOTE(review): this assumes generator[0] is an (inputs, targets) pair with the variables on the last
        axis of the inputs -- confirm against the generator implementation.
        """
        channels = self.data_store.get("generator", "train")[0][0].shape[-1]
        self.data_store.set("channels", channels, self.scope)

    def compile_model(self):
        """
        Compile model with optimizer and loss.

        The optimizer is read from the data store (stored there by get_model_settings), the loss is taken from
        the model class itself. The compiled model is stored as `model` in the model scope.
        """
        optimizer = self.data_store.get("optimizer", self.scope)
        loss = self.model.loss
        self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"])
        self.data_store.set("model", self.model, self.scope)

    def _set_callbacks(self):
        """
        Set all callbacks for the training phase.

        Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added.
        The learning rate callback is only registered if `lr_decay` is available in the model scope.
        """
        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
        hist = HistoryAdvanced()
        # store history in the same scope as all other model objects
        self.data_store.set("hist", hist, scope=self.scope)
        callbacks = CallbackHandler()
        if lr is not None:
            callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                          save_best_only=True, mode='auto')
        self.data_store.set("callbacks", callbacks, self.scope)

    def load_weights(self):
        """
        Try to load weights from existing model or skip if not possible.

        A missing or unreadable weights file (OSError) is not an error here: the model then keeps the
        weights it was built with and only an info message is logged.
        """
        try:
            self.model.load_weights(self.model_name)
            logging.info(f"reload weights from model {self.model_name} ...")
        except OSError:
            logging.info('no weights to reload...')

    def build_model(self):
        """Build model using window_history_size, window_lead_time and channels from data store."""
        args_list = ["window_history_size", "window_lead_time", "channels"]
        args = self.data_store.create_args_dict(args_list, self.scope)
        self.model = MyModel(**args)
        self.get_model_settings()

    def get_model_settings(self):
        """
        Load all model settings and store in data store.

        Also resolves `model_name`: the `%s` placeholder in `self.model_name` is filled with the model class
        setting `model_name` (default "my_model") and the resulting file name is stored in the data store.
        NOTE: this consumes the placeholder, so this method is meant to be called only once per instance.
        """
        model_settings = self.model.get_settings()
        self.data_store.set_from_dict(model_settings, self.scope, log=True)
        self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model")
        self.data_store.set("model_name", self.model_name, self.scope)

    def plot_model(self):  # pragma: no cover
        """Plot model architecture as `<model_name>.pdf` (file extension of model_name is replaced)."""
        with tf.device("/cpu:0"):
            file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
            keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)