diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index 27e13c66c6cdb2466834dd78128f8bb2b1b44d57..d6dcea179bcfa8a6ec41518db34b186e30d908fc 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -82,13 +82,22 @@ class AbstractModelClass(ABC): @property def custom_objects(self) -> Dict: + """ + The custom objects property collects all non-keras utilities that are used in the model class. To load such a + customised and already compiled model (e.g. from local disk), this information is required. + :return: the custom objects in a dictionary + """ return self.__custom_objects @custom_objects.setter def custom_objects(self, value) -> None: self.__custom_objects = value - def get_settings(self): + def get_settings(self) -> Dict: + """ + Get all class attributes that are not protected in the AbstractModelClass as dictionary. + :return: all class attributes + """ return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__")) def set_model(self): @@ -97,10 +106,17 @@ class AbstractModelClass(ABC): def set_loss(self): pass - def set_custom_objects(self, **kwargs): + def set_custom_objects(self, **kwargs) -> None: + """ + Set custom objects that are not part of keras framework. These custom objects are needed if an already compiled + model is loaded from disk. There is a special treatment for the Padding2D class, which is a base class for + different padding types. For a correct behaviour, all supported subclasses are added as custom objects in + addition to the given ones. + :param kwargs: all custom objects, that should be saved + """ if "Padding2D" in kwargs.keys(): kwargs.update(kwargs["Padding2D"].allowed_paddings) - self.custom_objects.update(kwargs) + self.custom_objects = kwargs class MyLittleModel(AbstractModelClass): diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py index 13f982b80906d8d5d6beae7075b23f4c84d6edd1..cee031749b193b91bd1cf16c02acfb3050eaed61 100644 --- a/test/test_model_modules/test_model_class.py +++ b/test/test_model_modules/test_model_class.py @@ -5,15 +5,32 @@ from src.model_modules.model_class import AbstractModelClass from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel +class Paddings: + allowed_paddings = {"pad1": 34, "another_pad": True} + + +class AbstractModelSubClass(AbstractModelClass): + + def __init__(self): + super().__init__() + self.test_attr = "testAttr" + + class TestAbstractModelClass: @pytest.fixture def amc(self): return AbstractModelClass() + @pytest.fixture + def amsc(self): + return AbstractModelSubClass() + def test_init(self, amc): assert amc.model is None assert amc.loss is None + assert amc.model_name == "AbstractModelClass" + assert amc.custom_objects == {} def test_model_property(self, amc): amc.model = keras.Model() @@ -29,6 +46,23 @@ class TestAbstractModelClass: assert hasattr(amc.model, "compile") is True assert amc.compile == amc.model.compile + def test_get_settings(self, amc, amsc): + assert amc.get_settings() == {"model_name": "AbstractModelClass"} + assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass"} + + def test_custom_objects(self, amc): + amc.custom_objects = {"Test": 123} + assert amc.custom_objects == {"Test": 123} + + def test_set_custom_objects(self, amc): + amc.set_custom_objects(Test=22, minor_param="minor") + assert amc.custom_objects == {"Test": 22, "minor_param": "minor"} + amc.set_custom_objects(Test=2, minor_param1="minor1") + assert amc.custom_objects == {"Test": 2, "minor_param1": "minor1"} + paddings = Paddings() + amc.set_custom_objects(Test=1, Padding2D=paddings) + assert amc.custom_objects == {"Test": 1, "Padding2D": paddings, "pad1": 34, "another_pad": True} + class TestMyPaperModel: