"""Model setup module."""

__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-12-02'

import logging
import os

import keras
import tensorflow as tf

from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
# from src.model_modules.model_class import MyBranchedModel as MyModel
from src.model_modules.model_class import MyLittleModel as MyModel
# from src.model_modules.model_class import MyTowerModel as MyModel
# from src.model_modules.model_class import MyPaperModel as MyModel
from src.run_modules.run_environment import RunEnvironment


class ModelSetup(RunEnvironment):
    """
    Set up the model.

    Schedule of model setup:
        #. set channels (from variables dimension)
        #. build imported model
        #. plot model architecture
        #. load weights if enabled (e.g. to resume a training)
        #. set callbacks and checkpoint
        #. compile model

    Required objects [scope] from data store:
        * `experiment_path` [.]
        * `experiment_name` [.]
        * `trainable` [.]
        * `create_new_model` [.]
        * `generator` [train]
        * `window_lead_time` [.]
        * `window_history_size` [.]

    Optional objects
        * `lr_decay` [model]

    Sets
        * `channels` [model]
        * `model` [model]
        * `hist` [model]
        * `callbacks` [model]
        * `model_name` [model]
        * all settings from model class like `dropout_rate`, `initial_lr`, `batch_size`, and `optimizer` [model]

    Creates
        * plot of model architecture in `<model_name>.pdf`
    """

    def __init__(self):
        """Initialise and run model setup."""
        super().__init__()
        self.model = None
        path = self.data_store.get("experiment_path")
        exp_name = self.data_store.get("experiment_name")
        self.scope = "model"
        # All file names below are %-style templates. Literal "%" characters in the user-supplied experiment path or
        # name are escaped as "%%" so that only the trailing placeholder gets substituted. Without escaping, a path
        # like ".../50%_runs" made the %-formatting fail. For paths without "%" the results are unchanged.
        prefix = os.path.join(path, f"{exp_name}_").replace("%", "%%")
        self.path = f"{prefix}%s"
        # keeps one "%s" placeholder, filled with the model name in get_model_settings()
        self.model_name = f"{prefix}%s.h5"
        self.checkpoint_name = self.path % "model-best.h5"
        # keeps one "%s" placeholder, filled with the callback name in _set_callbacks()
        self.callbacks_name = f"{prefix}model-best-callbacks-%s.pickle"
        self._trainable = self.data_store.get("trainable")
        self._create_new_model = self.data_store.get("create_new_model")
        self._run()

    def _run(self):
        """Execute all steps of the model setup schedule in order."""
        # set channels depending on inputs
        self._set_channels()

        # build model graph using settings from my_model_settings()
        self.build_model()

        # plot model structure
        self.plot_model()

        # load weights only if neither training is enabled nor a new model shall be created
        if not self._trainable and not self._create_new_model:
            self.load_weights()

        # create checkpoint
        self._set_callbacks()

        # compile model
        self.compile_model()

    def _set_channels(self):
        """
        Set channels as number of variables of train generator.

        The value is the size of the last axis of the first input batch of the train generator and is stored as
        `channels` in the model scope.
        """
        # NOTE(review): assumes generator[0] yields (inputs, targets) and that the last input axis holds the variables
        channels = self.data_store.get("generator", "train")[0][0].shape[-1]
        self.data_store.set("channels", channels, self.scope)

    def compile_model(self):
        """
        Compile model with optimizer and loss.

        The optimizer is read from the data store (model scope), the loss from the model object itself. The model is
        compiled with additional "mse" and "mae" metrics and stored as `model` in the model scope afterwards.
        """
        optimizer = self.data_store.get("optimizer", self.scope)
        loss = self.model.loss
        self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"])
        self.data_store.set("model", self.model, self.scope)

    def _set_callbacks(self):
        """
        Set all callbacks for the training phase.

        Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added. A learning
        rate decay callback is only added if `lr_decay` is present in the data store; the history callback is always
        added and also stored as `hist`. The checkpoint monitors `val_loss` and keeps only the best model.
        """
        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
        hist = HistoryAdvanced()
        # store in self.scope like every other data store write of this class (was a duplicated "model" literal)
        self.data_store.set("hist", hist, scope=self.scope)
        callbacks = CallbackHandler()
        if lr is not None:
            callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                          save_best_only=True, mode='auto')
        self.data_store.set("callbacks", callbacks, self.scope)

    def load_weights(self):
        """
        Try to load weights from existing model or skip if not possible.

        Weights are read from `self.model_name`. An OSError raised while loading (e.g. missing file) is logged and
        ignored; any other exception propagates.
        """
        try:
            self.model.load_weights(self.model_name)
            logging.info(f"reload weights from model {self.model_name} ...")
        except OSError:
            logging.info('no weights to reload...')

    def build_model(self):
        """Build model using window_history_size, window_lead_time and channels from data store."""
        args_list = ["window_history_size", "window_lead_time", "channels"]
        args = self.data_store.create_args_dict(args_list, self.scope)
        self.model = MyModel(**args)
        self.get_model_settings()

    def get_model_settings(self):
        """
        Load all model settings and store in data store.

        All settings returned by the model's `get_settings()` are written to the model scope. Afterwards the
        `self.model_name` template is completed with the `model_name` setting (default "my_model") and stored as
        `model_name` in the model scope.
        """
        model_settings = self.model.get_settings()
        self.data_store.set_from_dict(model_settings, self.scope, log=True)
        # requires self.model_name to still contain its single "%s" placeholder (as set in __init__)
        self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model")
        self.data_store.set("model_name", self.model_name, self.scope)

    def plot_model(self):  # pragma: no cover
        """Plot model architecture as `<model_name>.pdf`."""
        with tf.device("/cpu:0"):
            # replace the file extension of the model name (e.g. ".h5") by ".pdf"
            file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
            keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)