diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index d6dcea179bcfa8a6ec41518db34b186e30d908fc..0064c795e9bba162fafe3e9d5f60a17b95a4d57f 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -8,11 +8,13 @@ from abc import ABC from typing import Any, Callable, Dict import keras +import logging from src.model_modules.inception_model import InceptionModelBase from src.model_modules.flatten import flatten_tail from src.model_modules.advanced_paddings import PadUtils, Padding2D + class AbstractModelClass(ABC): """ @@ -32,6 +34,13 @@ class AbstractModelClass(ABC): self.__loss = None self.model_name = self.__class__.__name__ self.__custom_objects = {} + self.__allowed_compile_options = {'metrics': None, + 'loss_weights': None, + 'sample_weight_mode': None, + 'weighted_metrics': None, + 'target_tensors': None + } + self.__compile_options = self.__allowed_compile_options def __getattr__(self, name: str) -> Any: @@ -93,6 +102,26 @@ class AbstractModelClass(ABC): def custom_objects(self, value) -> None: self.__custom_objects = value + @property + def compile_options(self) -> Callable: + """ + The compile options property allows the user to use all keras.compile() arguments which are not already covered + by __loss and optimizer + :return: + """ + return self.__compile_options + + @compile_options.setter + def compile_options(self, value: Dict) -> None: + if not isinstance(value, dict): + raise TypeError(f"`value' has to be a dictionary. But it is {type(value)}") + for new_k, new_v in value.items(): + if new_k in self.__allowed_compile_options.keys(): + self.__compile_options[new_k] = new_v + else: + logging.warning( + f"`{new_k}' is not a valid additional compile option. Will be ignored in keras.compile()") + def get_settings(self) -> Dict: """ Get all class attributes that are not protected in the AbstractModelClass as dictionary. @@ -106,6 +135,21 @@ class AbstractModelClass(ABC): def set_loss(self): pass + def set_compile_options(self): + """ + This method only has to be defined in child class, when additional compile options should be used () + (other options than optimizer and loss) + Has to be set as dictionary: {'metrics': None, + 'loss_weights': None, + 'sample_weight_mode': None, + 'weighted_metrics': None, + 'target_tensors': None + } + + :return: + """ + pass + def set_custom_objects(self, **kwargs) -> None: """ Set custom objects that are not part of keras framework. These custom objects are needed if an already compiled @@ -157,6 +201,7 @@ class MyLittleModel(AbstractModelClass): # apply to model self.set_model() self.set_loss() + self.set_compile_options() self.set_custom_objects(loss=self.loss) def set_model(self): @@ -196,6 +241,9 @@ class MyLittleModel(AbstractModelClass): self.loss = keras.losses.mean_squared_error + def set_compile_options(self): + self.compile_options = {"metrics": ["mse", "mae"]} + class MyBranchedModel(AbstractModelClass): @@ -315,6 +363,7 @@ class MyTowerModel(AbstractModelClass): # apply to model self.set_model() self.set_loss() + self.set_compile_options() self.set_custom_objects(loss=self.loss) def set_model(self): @@ -392,6 +441,9 @@ class MyTowerModel(AbstractModelClass): self.loss = [keras.losses.mean_squared_error] + def set_compile_options(self): + self.compile_options = {"metrics": ["mse"]} + class MyPaperModel(AbstractModelClass): diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index c558b5fc76ff336dc6a792ec0239fa3b64eab466..de2d6a576662702128ae5e486a072904b3c3bf73 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -10,8 +10,8 @@ import tensorflow as tf from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler # from src.model_modules.model_class import MyBranchedModel as MyModel -# from src.model_modules.model_class import MyLittleModel as MyModel -from src.model_modules.model_class import MyTowerModel as MyModel +from src.model_modules.model_class import MyLittleModel as MyModel +# from src.model_modules.model_class import MyTowerModel as MyModel # from src.model_modules.model_class import MyPaperModel as MyModel from src.run_modules.run_environment import RunEnvironment @@ -62,7 +62,8 @@ class ModelSetup(RunEnvironment): def compile_model(self): optimizer = self.data_store.get("optimizer", self.scope) loss = self.model.loss - self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"]) + compile_options = self.model.compile_options + self.model.compile(optimizer=optimizer, loss=loss, **compile_options) self.data_store.set("model", self.model, self.scope) def _set_callbacks(self): diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py index cee031749b193b91bd1cf16c02acfb3050eaed61..147c92532465574b625907d13c814c5cfcbaeac9 100644 --- a/test/test_model_modules/test_model_class.py +++ b/test/test_model_modules/test_model_class.py @@ -40,6 +40,16 @@ class TestAbstractModelClass: amc.loss = keras.losses.mean_absolute_error assert amc.loss == keras.losses.mean_absolute_error + def test_compile_options_property(self, amc): + amc.compile_options = {"metrics": ["mse", "mae"]} + assert amc.compile_options == {'loss_weights': None, 'metrics': ['mse', 'mae'], 'sample_weight_mode': None, + 'target_tensors': None, 'weighted_metrics': None} + + def test_compile_options_property_type_error(self, amc): + with pytest.raises(TypeError) as einfo: + amc.compile_options = 'hello world' + assert "`value' has to be a dictionary. But it is <class 'str'>" in str(einfo.value) + def test_getattr(self, amc): amc.model = keras.Model() assert hasattr(amc, "compile") is True