Skip to content
Snippets Groups Projects
Commit 8fb227b7 authored by lukas leufen's avatar lukas leufen
Browse files

docs for model setup

parent 0c6804cd
No related branches found
No related tags found
3 merge requests: !125 Release v0.10.0, !124 Update Master to new version v0.10.0, !91 WIP: Resolve "create sphinx docu"
Pipeline #35632 failed
...@@ -199,7 +199,9 @@ class CallbackHandler: ...@@ -199,7 +199,9 @@ class CallbackHandler:
r"""Use the CallbackHandler for better controlling of custom callbacks. r"""Use the CallbackHandler for better controlling of custom callbacks.
The callback handler will always keep your callbacks in the right order and adds a model checkpoint at last position The callback handler will always keep your callbacks in the right order and adds a model checkpoint at last position
if required. You can add an arbitrary number of callbacks to the handler. if required. You can add an arbitrary number of callbacks to the handler. First, add all callbacks and finally
create the model checkpoint. Callbacks that are added after the checkpoint has been created would not be part of it.
Therefore, the handler blocks the addition of new callbacks once the model checkpoint has been created.
.. code-block:: python .. code-block:: python
......
"""Model setup module."""
__author__ = "Lukas Leufen, Felix Kleinert" __author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-12-02' __date__ = '2019-12-02'
...@@ -16,10 +18,44 @@ from src.run_modules.run_environment import RunEnvironment ...@@ -16,10 +18,44 @@ from src.run_modules.run_environment import RunEnvironment
class ModelSetup(RunEnvironment): class ModelSetup(RunEnvironment):
"""
Set up the model.
Schedule of model setup:
#. set channels (from variables dimension)
#. build imported model
#. plot model architecture
#. load weights if enabled (e.g. to resume a training)
#. set callbacks and checkpoint
#. compile model
Required objects [scope] from data store:
* `experiment_path` [.]
* `experiment_name` [.]
* `trainable` [.]
* `create_new_model` [.]
* `generator` [train]
* `window_lead_time` [.]
* `window_history_size` [.]
Optional objects
* `lr_decay` [model]
Sets
* `channels` [model]
* `model` [model]
* `hist` [model]
* `callbacks` [model]
* `model_name` [model]
* all settings from model class like `dropout_rate`, `initial_lr`, `batch_size`, and `optimizer` [model]
Creates
* plot of model architecture in `<model_name>.pdf`
def __init__(self): """
# create run framework def __init__(self):
"""Initialise and run model setup."""
super().__init__() super().__init__()
self.model = None self.model = None
path = self.data_store.get("experiment_path") path = self.data_store.get("experiment_path")
...@@ -55,10 +91,12 @@ class ModelSetup(RunEnvironment): ...@@ -55,10 +91,12 @@ class ModelSetup(RunEnvironment):
self.compile_model() self.compile_model()
def _set_channels(self): def _set_channels(self):
"""Set channels as number of variables of train generator."""
channels = self.data_store.get("generator", "train")[0][0].shape[-1] channels = self.data_store.get("generator", "train")[0][0].shape[-1]
self.data_store.set("channels", channels, self.scope) self.data_store.set("channels", channels, self.scope)
def compile_model(self): def compile_model(self):
"""Compile model with optimizer and loss."""
optimizer = self.data_store.get("optimizer", self.scope) optimizer = self.data_store.get("optimizer", self.scope)
loss = self.model.loss loss = self.model.loss
self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"]) self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"])
...@@ -66,14 +104,15 @@ class ModelSetup(RunEnvironment): ...@@ -66,14 +104,15 @@ class ModelSetup(RunEnvironment):
def _set_callbacks(self): def _set_callbacks(self):
""" """
Set all callbacks for the training phase. Add all callbacks with the .add_callback statement. Finally, the Set all callbacks for the training phase.
advanced model checkpoint is added.
Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added.
""" """
lr = self.data_store.get_default("lr_decay", scope="model", default=None) lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
hist = HistoryAdvanced() hist = HistoryAdvanced()
self.data_store.set("hist", hist, scope="model") self.data_store.set("hist", hist, scope="model")
callbacks = CallbackHandler() callbacks = CallbackHandler()
if lr: if lr is not None:
callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
callbacks.add_callback(hist, self.callbacks_name % "hist", "hist") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss', callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
...@@ -81,6 +120,7 @@ class ModelSetup(RunEnvironment): ...@@ -81,6 +120,7 @@ class ModelSetup(RunEnvironment):
self.data_store.set("callbacks", callbacks, self.scope) self.data_store.set("callbacks", callbacks, self.scope)
def load_weights(self): def load_weights(self):
"""Try to load weights from existing model or skip if not possible."""
try: try:
self.model.load_weights(self.model_name) self.model.load_weights(self.model_name)
logging.info(f"reload weights from model {self.model_name} ...") logging.info(f"reload weights from model {self.model_name} ...")
...@@ -88,18 +128,21 @@ class ModelSetup(RunEnvironment): ...@@ -88,18 +128,21 @@ class ModelSetup(RunEnvironment):
logging.info('no weights to reload...') logging.info('no weights to reload...')
def build_model(self): def build_model(self):
"""Build model using window_history_size, window_lead_time and channels from data store."""
args_list = ["window_history_size", "window_lead_time", "channels"] args_list = ["window_history_size", "window_lead_time", "channels"]
args = self.data_store.create_args_dict(args_list, self.scope) args = self.data_store.create_args_dict(args_list, self.scope)
self.model = MyModel(**args) self.model = MyModel(**args)
self.get_model_settings() self.get_model_settings()
def get_model_settings(self): def get_model_settings(self):
"""Load all model settings and store in data store."""
model_settings = self.model.get_settings() model_settings = self.model.get_settings()
self.data_store.set_from_dict(model_settings, self.scope) self.data_store.set_from_dict(model_settings, self.scope)
self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model") self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model")
self.data_store.set("model_name", self.model_name, self.scope) self.data_store.set("model_name", self.model_name, self.scope)
def plot_model(self): # pragma: no cover def plot_model(self): # pragma: no cover
"""Plot model architecture as `<model_name>.pdf`."""
with tf.device("/cpu:0"): with tf.device("/cpu:0"):
file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf" file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment