Skip to content
Snippets Groups Projects
Commit de864eae authored by Felix Kleinert's avatar Felix Kleinert
Browse files

introduce 'compile_options' in model_class which combines all allowed args of...

introduce 'compile_options' in model_class which combines all allowed args of keras' .compile() methods. The method 'set_compile_options' has to be defined in child class if additional compile options should be used (optimizer and loss, have to be set anyway in child.__init__ #110
parent 243ca29e
No related branches found
No related tags found
3 merge requests!125Release v0.10.0,!124Update Master to new version v0.10.0,!92introduce 'compile_options' in model_class which combines all allowed args of...
Pipeline #34409 passed
...@@ -8,11 +8,13 @@ from abc import ABC ...@@ -8,11 +8,13 @@ from abc import ABC
from typing import Any, Callable, Dict from typing import Any, Callable, Dict
import keras import keras
import logging
from src.model_modules.inception_model import InceptionModelBase from src.model_modules.inception_model import InceptionModelBase
from src.model_modules.flatten import flatten_tail from src.model_modules.flatten import flatten_tail
from src.model_modules.advanced_paddings import PadUtils, Padding2D from src.model_modules.advanced_paddings import PadUtils, Padding2D
class AbstractModelClass(ABC): class AbstractModelClass(ABC):
""" """
...@@ -32,6 +34,13 @@ class AbstractModelClass(ABC): ...@@ -32,6 +34,13 @@ class AbstractModelClass(ABC):
self.__loss = None self.__loss = None
self.model_name = self.__class__.__name__ self.model_name = self.__class__.__name__
self.__custom_objects = {} self.__custom_objects = {}
self.__allowed_compile_options = {'metrics': None,
'loss_weights': None,
'sample_weight_mode': None,
'weighted_metrics': None,
'target_tensors': None
}
self.__compile_options = self.__allowed_compile_options
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
...@@ -93,6 +102,26 @@ class AbstractModelClass(ABC): ...@@ -93,6 +102,26 @@ class AbstractModelClass(ABC):
def custom_objects(self, value) -> None: def custom_objects(self, value) -> None:
self.__custom_objects = value self.__custom_objects = value
@property
def compile_options(self) -> Callable:
"""
The compile options property allows the user to use all keras.compile() arguments which are not already covered
by __loss and optimizer
:return:
"""
return self.__compile_options
@compile_options.setter
def compile_options(self, value: Dict) -> None:
if not isinstance(value, dict):
raise TypeError(f"`value' has to be a dictionary. But it is {type(value)}")
for new_k, new_v in value.items():
if new_k in self.__allowed_compile_options.keys():
self.__compile_options[new_k] = new_v
else:
logging.warning(
f"`{new_k}' is not a valid additional compile option. Will be ignored in keras.compile()")
def get_settings(self) -> Dict: def get_settings(self) -> Dict:
""" """
Get all class attributes that are not protected in the AbstractModelClass as dictionary. Get all class attributes that are not protected in the AbstractModelClass as dictionary.
...@@ -106,6 +135,21 @@ class AbstractModelClass(ABC): ...@@ -106,6 +135,21 @@ class AbstractModelClass(ABC):
def set_loss(self): def set_loss(self):
pass pass
def set_compile_options(self):
"""
This method only has to be defined in child class, when additional compile options should be used ()
(other options than optimizer and loss)
Has to be set as dictionary: {'metrics': None,
'loss_weights': None,
'sample_weight_mode': None,
'weighted_metrics': None,
'target_tensors': None
}
:return:
"""
pass
def set_custom_objects(self, **kwargs) -> None: def set_custom_objects(self, **kwargs) -> None:
""" """
Set custom objects that are not part of keras framework. These custom objects are needed if an already compiled Set custom objects that are not part of keras framework. These custom objects are needed if an already compiled
...@@ -157,6 +201,7 @@ class MyLittleModel(AbstractModelClass): ...@@ -157,6 +201,7 @@ class MyLittleModel(AbstractModelClass):
# apply to model # apply to model
self.set_model() self.set_model()
self.set_loss() self.set_loss()
self.set_compile_options()
self.set_custom_objects(loss=self.loss) self.set_custom_objects(loss=self.loss)
def set_model(self): def set_model(self):
...@@ -196,6 +241,9 @@ class MyLittleModel(AbstractModelClass): ...@@ -196,6 +241,9 @@ class MyLittleModel(AbstractModelClass):
self.loss = keras.losses.mean_squared_error self.loss = keras.losses.mean_squared_error
def set_compile_options(self):
self.compile_options = {"metrics": ["mse", "mae"]}
class MyBranchedModel(AbstractModelClass): class MyBranchedModel(AbstractModelClass):
...@@ -315,6 +363,7 @@ class MyTowerModel(AbstractModelClass): ...@@ -315,6 +363,7 @@ class MyTowerModel(AbstractModelClass):
# apply to model # apply to model
self.set_model() self.set_model()
self.set_loss() self.set_loss()
self.set_compile_options()
self.set_custom_objects(loss=self.loss) self.set_custom_objects(loss=self.loss)
def set_model(self): def set_model(self):
...@@ -392,6 +441,9 @@ class MyTowerModel(AbstractModelClass): ...@@ -392,6 +441,9 @@ class MyTowerModel(AbstractModelClass):
self.loss = [keras.losses.mean_squared_error] self.loss = [keras.losses.mean_squared_error]
def set_compile_options(self):
self.compile_options = {"metrics": ["mse"]}
class MyPaperModel(AbstractModelClass): class MyPaperModel(AbstractModelClass):
......
...@@ -10,8 +10,8 @@ import tensorflow as tf ...@@ -10,8 +10,8 @@ import tensorflow as tf
from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
# from src.model_modules.model_class import MyBranchedModel as MyModel # from src.model_modules.model_class import MyBranchedModel as MyModel
# from src.model_modules.model_class import MyLittleModel as MyModel from src.model_modules.model_class import MyLittleModel as MyModel
from src.model_modules.model_class import MyTowerModel as MyModel # from src.model_modules.model_class import MyTowerModel as MyModel
# from src.model_modules.model_class import MyPaperModel as MyModel # from src.model_modules.model_class import MyPaperModel as MyModel
from src.run_modules.run_environment import RunEnvironment from src.run_modules.run_environment import RunEnvironment
...@@ -62,7 +62,8 @@ class ModelSetup(RunEnvironment): ...@@ -62,7 +62,8 @@ class ModelSetup(RunEnvironment):
def compile_model(self): def compile_model(self):
optimizer = self.data_store.get("optimizer", self.scope) optimizer = self.data_store.get("optimizer", self.scope)
loss = self.model.loss loss = self.model.loss
self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"]) compile_options = self.model.compile_options
self.model.compile(optimizer=optimizer, loss=loss, **compile_options)
self.data_store.set("model", self.model, self.scope) self.data_store.set("model", self.model, self.scope)
def _set_callbacks(self): def _set_callbacks(self):
......
...@@ -40,6 +40,16 @@ class TestAbstractModelClass: ...@@ -40,6 +40,16 @@ class TestAbstractModelClass:
amc.loss = keras.losses.mean_absolute_error amc.loss = keras.losses.mean_absolute_error
assert amc.loss == keras.losses.mean_absolute_error assert amc.loss == keras.losses.mean_absolute_error
def test_compile_options_property(self, amc):
amc.compile_options = {"metrics": ["mse", "mae"]}
assert amc.compile_options == {'loss_weights': None, 'metrics': ['mse', 'mae'], 'sample_weight_mode': None,
'target_tensors': None, 'weighted_metrics': None}
def test_compile_options_property_type_error(self, amc):
with pytest.raises(TypeError) as einfo:
amc.compile_options = 'hello world'
assert "`value' has to be a dictionary. But it is <class 'str'>" in str(einfo.value)
def test_getattr(self, amc): def test_getattr(self, amc):
amc.model = keras.Model() amc.model = keras.Model()
assert hasattr(amc, "compile") is True assert hasattr(amc, "compile") is True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment