Skip to content
Snippets Groups Projects
Commit 5cb07ed5 authored by leufen1's avatar leufen1
Browse files

first fix to make training resumption possible, but this will slow down the...

first fix to make training resumption possible, but this will slow down the training by 5-10s per epoch (if model improves)
parent b9c5cf74
No related branches found
No related tags found
4 merge requests!253include current develop,!252Resolve "release v1.3.0",!245update #275 branch,!242Resolve "BUG: loading of custom objects not working"
Pipeline #60054 failed
...@@ -9,9 +9,12 @@ import pickle ...@@ -9,9 +9,12 @@ import pickle
from typing import Union, List from typing import Union, List
from typing_extensions import TypedDict from typing_extensions import TypedDict
from mlair.helpers import TimeTracking
import numpy as np import numpy as np
from keras import backend as K from keras import backend as K
from keras.callbacks import History, ModelCheckpoint, Callback from keras.callbacks import History, ModelCheckpoint, Callback
import keras
from mlair import helpers from mlair import helpers
...@@ -147,6 +150,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint): ...@@ -147,6 +150,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialise ModelCheckpointAdvanced and set callbacks attribute.""" """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
self.callbacks = kwargs.pop("callbacks") self.callbacks = kwargs.pop("callbacks")
self.custom_objects = kwargs.pop("custom_objects")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def update_best(self, hist): def update_best(self, hist):
...@@ -174,17 +178,29 @@ class ModelCheckpointAdvanced(ModelCheckpoint): ...@@ -174,17 +178,29 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
def on_epoch_end(self, epoch, logs=None): def on_epoch_end(self, epoch, logs=None):
"""Save model as usual (see ModelCheckpoint class), but also save additional callbacks.""" """Save model as usual (see ModelCheckpoint class), but also save additional callbacks."""
super().on_epoch_end(epoch, logs) super().on_epoch_end(epoch, logs)
for callback in self.callbacks: for callback in self.callbacks:
print(callback.keys())
file_path = callback["path"] file_path = callback["path"]
if self.epochs_since_last_save == 0 and epoch != 0: if self.epochs_since_last_save == 0 and epoch != 0:
if self.save_best_only: if self.save_best_only:
current = logs.get(self.monitor) current = logs.get(self.monitor)
model_save = None
if hasattr(callback["callback"], "model"):
# ToDo: store model in cache?
callback["callback"].model.save(
callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5")
callback["callback"].model = None
if current == self.best: if current == self.best:
if self.verbose > 0: # pragma: no branch if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
# ToDo: create "save" method
pickle.dump(callback["callback"], f) pickle.dump(callback["callback"], f)
if callback["callback"].model is None:
with TimeTracking("load_model"):
callback["callback"].model = keras.models.load_model(
callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5",
custom_objects=callback["custom_objects"])
else: else:
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
if self.verbose > 0: # pragma: no branch if self.verbose > 0: # pragma: no branch
...@@ -253,12 +269,17 @@ class CallbackHandler: ...@@ -253,12 +269,17 @@ class CallbackHandler:
def __init__(self): def __init__(self):
"""Initialise CallbackHandler.""" """Initialise CallbackHandler."""
self.__callbacks: List[clbk_type] = [] self.__callbacks: List[clbk_type] = []
self.custom_objects = {}
self._checkpoint = None self._checkpoint = None
self.editable = True self.editable = True
def add_custom_objects(self, custom_objects):
self.custom_objects = custom_objects
@property @property
def _callbacks(self): def _callbacks(self):
return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks] return [{"callback": clbk[clbk["name"]], "path": clbk["path"], "custom_objects": self.custom_objects} for clbk
in self.__callbacks]
@_callbacks.setter @_callbacks.setter
def _callbacks(self, value): def _callbacks(self, value):
...@@ -313,7 +334,8 @@ class CallbackHandler: ...@@ -313,7 +334,8 @@ class CallbackHandler:
"""Return all callbacks and append checkpoint if available on last position.""" """Return all callbacks and append checkpoint if available on last position."""
clbks = self._callbacks clbks = self._callbacks
if self._checkpoint is not None: if self._checkpoint is not None:
clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}] clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath,
"custom_objects": self.custom_objects}]
return clbks return clbks
def get_checkpoint(self) -> ModelCheckpointAdvanced: def get_checkpoint(self) -> ModelCheckpointAdvanced:
...@@ -323,7 +345,8 @@ class CallbackHandler: ...@@ -323,7 +345,8 @@ class CallbackHandler:
def create_model_checkpoint(self, **kwargs): def create_model_checkpoint(self, **kwargs):
"""Create a model checkpoint and enable edit.""" """Create a model checkpoint and enable edit."""
self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs) self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, custom_objects=self.custom_objects,
**kwargs)
self.editable = False self.editable = False
def load_callbacks(self) -> None: def load_callbacks(self) -> None:
......
...@@ -669,10 +669,10 @@ class MyPaperModel(AbstractModelClass): ...@@ -669,10 +669,10 @@ class MyPaperModel(AbstractModelClass):
conv_settings_dict1 = { conv_settings_dict1 = {
'tower_1': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (3, 1), 'tower_1': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (3, 1),
'activation': activation}, 'activation': activation},
'tower_2': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (5, 1), # 'tower_2': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (5, 1),
'activation': activation}, # 'activation': activation},
'tower_3': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (1, 1), # 'tower_3': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (1, 1),
'activation': activation}, # 'activation': activation},
# 'tower_4':{'reduction_filter':8, 'tower_filter':8*2, 'tower_kernel':(7,1), 'activation':activation}, # 'tower_4':{'reduction_filter':8, 'tower_filter':8*2, 'tower_kernel':(7,1), 'activation':activation},
} }
pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation} pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation}
...@@ -680,10 +680,10 @@ class MyPaperModel(AbstractModelClass): ...@@ -680,10 +680,10 @@ class MyPaperModel(AbstractModelClass):
conv_settings_dict2 = { conv_settings_dict2 = {
'tower_1': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (3, 1), 'tower_1': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (3, 1),
'activation': activation}, 'activation': activation},
'tower_2': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1), # 'tower_2': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1),
'activation': activation}, # 'activation': activation},
'tower_3': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1), # 'tower_3': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1),
'activation': activation}, # 'activation': activation},
# 'tower_4':{'reduction_filter':8*2, 'tower_filter':16*2, 'tower_kernel':(7,1), 'activation':activation}, # 'tower_4':{'reduction_filter':8*2, 'tower_filter':16*2, 'tower_kernel':(7,1), 'activation':activation},
} }
pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 32, 'activation': activation} pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 32, 'activation': activation}
......
...@@ -118,6 +118,7 @@ class ModelSetup(RunEnvironment): ...@@ -118,6 +118,7 @@ class ModelSetup(RunEnvironment):
hist = HistoryAdvanced() hist = HistoryAdvanced()
self.data_store.set("hist", hist, scope="model") self.data_store.set("hist", hist, scope="model")
callbacks = CallbackHandler() callbacks = CallbackHandler()
callbacks.add_custom_objects(self.model.custom_objects)
if lr is not None: if lr is not None:
callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
callbacks.add_callback(hist, self.callbacks_name % "hist", "hist") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment