__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-12-02'

import logging
import os

import keras
import tensorflow as tf

from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
# from src.model_modules.model_class import MyBranchedModel as MyModel
from src.model_modules.model_class import MyLittleModel as MyModel
# from src.model_modules.model_class import MyTowerModel as MyModel
# from src.model_modules.model_class import MyPaperModel as MyModel
from src.run_modules.run_environment import RunEnvironment


class ModelSetup(RunEnvironment):
    """Build, compile and register the neural network model and all callbacks required for training."""

    def __init__(self):
        # create run framework
        super().__init__()
        self.model = None
        path = self.data_store.get("experiment_path", "general")
        exp_name = self.data_store.get("experiment_name", "general")
        self.scope = "general.model"
        self.path = os.path.join(path, f"{exp_name}_%s")
        self.model_name = self.path % "%s.h5"
        self.checkpoint_name = self.path % "model-best.h5"
        self.callbacks_name = self.path % "model-best-callbacks-%s.pickle"
        self._trainable = self.data_store.get("trainable", "general")
        self._create_new_model = self.data_store.get("create_new_model", "general")
        self._run()

    def _run(self):

        # set channels depending on inputs
        self._set_channels()

        # build model graph using settings provided by the model class
        self.build_model()

        # plot model structure
        self.plot_model()

        # load weights if neither training nor the creation of a new model is requested
        if not self._trainable and not self._create_new_model:
            self.load_weights()

        # create checkpoint and further callbacks
        self._set_callbacks()

        # compile model
        self.compile_model()

    def _set_channels(self):
        """Infer the number of input channels from the first training sample and store it in the data store."""
        channels = self.data_store.get("generator", "general.train")[0][0].shape[-1]
        self.data_store.set("channels", channels, self.scope)

    def compile_model(self):
        """Compile the model with the optimizer from the data store and the loss defined by the model class."""
        optimizer = self.data_store.get("optimizer", self.scope)
        loss = self.model.loss
        self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"])
        self.data_store.set("model", self.model, self.scope)

    def _set_callbacks(self):
        """
        Set all callbacks for the training phase. Add each callback with the .add_callback statement. Finally, the
        advanced model checkpoint is added.
        """
        lr = self.data_store.get_default("lr_decay", scope="general.model", default=None)
        hist = HistoryAdvanced()
        self.data_store.set("hist", hist, scope="general.model")
        callbacks = CallbackHandler()
        if lr:
            callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                          save_best_only=True, mode='auto')
        self.data_store.set("callbacks", callbacks, self.scope)

    def load_weights(self):
        """Try to load stored model weights; continue without weights if no stored model is found."""
        try:
            self.model.load_weights(self.model_name)
            logging.info(f"reload weights from model {self.model_name} ...")
        except OSError:
            logging.info("no weights to reload...")

    def build_model(self):
        """Build the model graph of the imported MyModel class with the arguments listed in args_list."""
        args_list = ["window_history_size", "window_lead_time", "channels"]
        args = self.data_store.create_args_dict(args_list, self.scope)
        self.model = MyModel(**args)
        self.get_model_settings()

    def get_model_settings(self):
        """Copy all settings of the model class into the data store and resolve the final model file name."""
        model_settings = self.model.get_settings()
        self.data_store.set_args_from_dict(model_settings, self.scope)
        self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model")
        self.data_store.set("model_name", self.model_name, self.scope)

    def plot_model(self):  # pragma: no cover
        """Plot the model architecture as a PDF file next to the stored model."""
        with tf.device("/cpu:0"):
            file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
            keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
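

# The sketch below is a minimal, hypothetical usage example and not part of the run module's interface. ModelSetup is
# meant to be instantiated inside a RunEnvironment context (assuming RunEnvironment can be used as a context manager
# that finalizes the workflow on exit) after preceding run modules have filled the shared data store with
# "experiment_path", "experiment_name", the "trainable" and "create_new_model" flags, the "general.train" generator
# and the optimizer. The helper _prepare_data_store() is purely illustrative and only marks where that preparation
# would happen, e.g. by the project's experiment setup and preprocessing run modules.
if __name__ == "__main__":

    def _prepare_data_store():
        """Hypothetical placeholder for the run modules that fill the shared data store before ModelSetup starts."""
        raise NotImplementedError("run the experiment setup and preprocessing steps of the workflow here")

    with RunEnvironment():
        _prepare_data_store()  # experiment setup and preprocessing (assumed to run first)
        ModelSetup()           # builds, compiles and registers the model together with its callbacks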