diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py index c41d722197c2529f04f6643cc72b51f0d3fe0087..479913811a668d8330a389b2876360f096f57dbf 100644 --- a/src/model_modules/keras_extensions.py +++ b/src/model_modules/keras_extensions.py @@ -199,7 +199,9 @@ class CallbackHandler: r"""Use the CallbackHandler for better controlling of custom callbacks. The callback handler will always keep your callbacks in the right order and adds a model checkpoint at last position - if required. You can add an arbitrary number of callbacks to the handler. + if required. You can add an arbitrary number of callbacks to the handler. First, add all callbacks and finally + create the model checkpoint. Callbacks that have been added after checkpoint creation wouldn't be part of it. + Therefore, the handler blocks adding of new callbacks after creation of model checkpoint. .. code-block:: python diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index 92357ab94d79dfb57bb9ffeabde850d920a266c9..f0c42dedd1cfe38379badc185324b6c042d72cbd 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -1,3 +1,5 @@ +"""Model setup module.""" + __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2019-12-02' @@ -16,10 +18,44 @@ from src.run_modules.run_environment import RunEnvironment class ModelSetup(RunEnvironment): + """ + Set up the model. + + Schedule of model setup: + #. set channels (from variables dimension) + #. build imported model + #. plot model architecture + #. load weights if enabled (e.g. to resume a training) + #. set callbacks and checkpoint + #. compile model + + Required objects [scope] from data store: + * `experiment_path` [.] + * `experiment_name` [.] + * `trainable` [.] + * `create_new_model` [.] + * `generator` [train] + * `window_lead_time` [.] + * `window_history_size` [.] 
+ + Optional objects + * `lr_decay` [model] + + Sets + * `channels` [model] + * `model` [model] + * `hist` [model] + * `callbacks` [model] + * `model_name` [model] + * all settings from model class like `dropout_rate`, `initial_lr`, `batch_size`, and `optimizer` [model] + + Creates + * plot of model architecture in `<model_name>.pdf` + + """ def __init__(self): - - # create run framework + """Initialise and run model setup.""" super().__init__() self.model = None path = self.data_store.get("experiment_path") @@ -55,10 +91,12 @@ class ModelSetup(RunEnvironment): self.compile_model() def _set_channels(self): + """Set channels as number of variables of train generator.""" channels = self.data_store.get("generator", "train")[0][0].shape[-1] self.data_store.set("channels", channels, self.scope) def compile_model(self): + """Compile model with optimizer and loss.""" optimizer = self.data_store.get("optimizer", self.scope) loss = self.model.loss self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"]) @@ -66,14 +104,15 @@ class ModelSetup(RunEnvironment): def _set_callbacks(self): """ - Set all callbacks for the training phase. Add all callbacks with the .add_callback statement. Finally, the - advanced model checkpoint is added. + Set all callbacks for the training phase. + + Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added. 
""" - lr = self.data_store.get_default("lr_decay", scope="model", default=None) + lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None) hist = HistoryAdvanced() self.data_store.set("hist", hist, scope="model") callbacks = CallbackHandler() - if lr: + if lr is not None: callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist") callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss', @@ -81,6 +120,7 @@ class ModelSetup(RunEnvironment): self.data_store.set("callbacks", callbacks, self.scope) def load_weights(self): + """Try to load weights from existing model or skip if not possible.""" try: self.model.load_weights(self.model_name) logging.info(f"reload weights from model {self.model_name} ...") @@ -88,18 +128,21 @@ class ModelSetup(RunEnvironment): logging.info('no weights to reload...') def build_model(self): + """Build model using window_history_size, window_lead_time and channels from data store.""" args_list = ["window_history_size", "window_lead_time", "channels"] args = self.data_store.create_args_dict(args_list, self.scope) self.model = MyModel(**args) self.get_model_settings() def get_model_settings(self): + """Load all model settings and store in data store.""" model_settings = self.model.get_settings() self.data_store.set_from_dict(model_settings, self.scope) self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model") self.data_store.set("model_name", self.model_name, self.scope) def plot_model(self): # pragma: no cover + """Plot model architecture as `<model_name>.pdf`.""" with tf.device("/cpu:0"): file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf" keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)