Skip to content
Snippets Groups Projects
Commit 5cb07ed5 authored by leufen1's avatar leufen1
Browse files

first fix to make training resumption possible, but this will slow down the...

First fix to make training resumption possible; note that this slows down training by 5-10 s per epoch (whenever the model improves).
parent b9c5cf74
No related branches found
No related tags found
4 merge requests!253include current develop,!252Resolve "release v1.3.0",!245update #275 branch,!242Resolve "BUG: loading of custom objects not working"
Pipeline #60054 failed
......@@ -9,9 +9,12 @@ import pickle
from typing import Union, List
from typing_extensions import TypedDict
from mlair.helpers import TimeTracking
import numpy as np
from keras import backend as K
from keras.callbacks import History, ModelCheckpoint, Callback
import keras
from mlair import helpers
......@@ -147,6 +150,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
def __init__(self, *args, **kwargs):
    """
    Initialise ModelCheckpointAdvanced and set the callbacks and custom_objects attributes.

    The keywords ``callbacks`` and ``custom_objects`` are removed from kwargs before all remaining arguments are
    forwarded to the parent constructor.

    :param callbacks: (keyword only, required) stored as given in ``self.callbacks``
    :param custom_objects: (keyword only, optional) stored in ``self.custom_objects``; an empty dictionary is used
        if the keyword is missing or None
    :raises KeyError: if ``callbacks`` is not given
    """
    self.callbacks = kwargs.pop("callbacks")
    # custom_objects is a later addition: keep it optional so that callers which only pass callbacks still work
    custom_objects = kwargs.pop("custom_objects", None)
    self.custom_objects = {} if custom_objects is None else custom_objects
    super().__init__(*args, **kwargs)
def update_best(self, hist):
......@@ -174,17 +178,29 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
def on_epoch_end(self, epoch, logs=None):
"""Save model as usual (see ModelCheckpoint class), but also save additional callbacks."""
super().on_epoch_end(epoch, logs)
for callback in self.callbacks:
print(callback.keys())
file_path = callback["path"]
if self.epochs_since_last_save == 0 and epoch != 0:
if self.save_best_only:
current = logs.get(self.monitor)
model_save = None
if hasattr(callback["callback"], "model"):
# ToDo: store model in cache?
callback["callback"].model.save(
callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5")
callback["callback"].model = None
if current == self.best:
if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
with open(file_path, "wb") as f:
# ToDo: create "save" method
pickle.dump(callback["callback"], f)
if callback["callback"].model is None:
with TimeTracking("load_model"):
callback["callback"].model = keras.models.load_model(
callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5",
custom_objects=callback["custom_objects"])
else:
with open(file_path, "wb") as f:
if self.verbose > 0: # pragma: no branch
......@@ -253,12 +269,17 @@ class CallbackHandler:
def __init__(self):
    """Set up an empty CallbackHandler: no callbacks, no custom objects, no checkpoint, and open for edits."""
    # editable is switched off again as soon as a model checkpoint is created
    self.editable = True
    self._checkpoint = None
    self.custom_objects = dict()
    self.__callbacks: List[clbk_type] = list()
def add_custom_objects(self, custom_objects):
    """
    Set the custom objects that are handed over together with each callback.

    :param custom_objects: dictionary of custom objects (stored by reference, not copied); None is treated as
        "no custom objects" and replaced by an empty dictionary so the attribute always holds a dictionary
    """
    self.custom_objects = {} if custom_objects is None else custom_objects
@property
def _callbacks(self):
    """
    Return all registered callbacks as a list of dictionaries.

    Each entry has the keys "callback" (the callback object itself), "path" (the path registered for this
    callback), and "custom_objects" (the custom objects of this handler). The stale return statement without
    "custom_objects" is removed, because it made the complete entry below unreachable.
    """
    return [{"callback": clbk[clbk["name"]], "path": clbk["path"], "custom_objects": self.custom_objects}
            for clbk in self.__callbacks]
@_callbacks.setter
def _callbacks(self, value):
......@@ -313,7 +334,8 @@ class CallbackHandler:
"""Return all callbacks and append checkpoint if available on last position."""
clbks = self._callbacks
if self._checkpoint is not None:
clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath,
"custom_objects": self.custom_objects}]
return clbks
def get_checkpoint(self) -> ModelCheckpointAdvanced:
......@@ -323,7 +345,8 @@ class CallbackHandler:
def create_model_checkpoint(self, **kwargs):
    """
    Create a model checkpoint and disable further edits of this handler.

    The checkpoint receives the currently registered callbacks and the custom objects of this handler. It is
    constructed exactly once; the former additional construction without custom_objects is dropped, because
    it raised a KeyError inside ModelCheckpointAdvanced before the complete call could run.

    :param kwargs: all further keyword arguments are forwarded to ModelCheckpointAdvanced
    """
    self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, custom_objects=self.custom_objects,
                                               **kwargs)
    # the handler is not editable any more once the checkpoint exists
    self.editable = False
def load_callbacks(self) -> None:
......
......@@ -669,10 +669,10 @@ class MyPaperModel(AbstractModelClass):
conv_settings_dict1 = {
'tower_1': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (3, 1),
'activation': activation},
'tower_2': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (5, 1),
'activation': activation},
'tower_3': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (1, 1),
'activation': activation},
# 'tower_2': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (5, 1),
# 'activation': activation},
# 'tower_3': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (1, 1),
# 'activation': activation},
# 'tower_4':{'reduction_filter':8, 'tower_filter':8*2, 'tower_kernel':(7,1), 'activation':activation},
}
pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation}
......@@ -680,10 +680,10 @@ class MyPaperModel(AbstractModelClass):
conv_settings_dict2 = {
'tower_1': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (3, 1),
'activation': activation},
'tower_2': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1),
'activation': activation},
'tower_3': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1),
'activation': activation},
# 'tower_2': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1),
# 'activation': activation},
# 'tower_3': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1),
# 'activation': activation},
# 'tower_4':{'reduction_filter':8*2, 'tower_filter':16*2, 'tower_kernel':(7,1), 'activation':activation},
}
pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 32, 'activation': activation}
......
......@@ -118,6 +118,7 @@ class ModelSetup(RunEnvironment):
hist = HistoryAdvanced()
self.data_store.set("hist", hist, scope="model")
callbacks = CallbackHandler()
callbacks.add_custom_objects(self.model.custom_objects)
if lr is not None:
callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment