Skip to content
Snippets Groups Projects
Commit ffe2aa52 authored by lukas leufen's avatar lukas leufen
Browse files

supplement on testing and docs

parent febd62e6
No related branches found
No related tags found
2 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu"
Pipeline #31857 passed
...@@ -82,13 +82,22 @@ class AbstractModelClass(ABC): ...@@ -82,13 +82,22 @@ class AbstractModelClass(ABC):
@property @property
def custom_objects(self) -> Dict: def custom_objects(self) -> Dict:
"""
The custom objects property collects all non-keras utilities that are used in the model class. To load such a
customised and already compiled model (e.g. from local disk), this information is required.
:return: the custom objects in a dictionary
"""
return self.__custom_objects return self.__custom_objects
@custom_objects.setter @custom_objects.setter
def custom_objects(self, value) -> None: def custom_objects(self, value) -> None:
self.__custom_objects = value self.__custom_objects = value
def get_settings(self): def get_settings(self) -> Dict:
"""
Get all class attributes that are not protected in the AbstractModelClass as dictionary.
:return: all class attributes
"""
return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__")) return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__"))
def set_model(self): def set_model(self):
...@@ -97,10 +106,17 @@ class AbstractModelClass(ABC): ...@@ -97,10 +106,17 @@ class AbstractModelClass(ABC):
def set_loss(self): def set_loss(self):
pass pass
def set_custom_objects(self, **kwargs): def set_custom_objects(self, **kwargs) -> None:
"""
Set custom objects that are not part of keras framework. These custom objects are needed if an already compiled
model is loaded from disk. There is a special treatment for the Padding2D class, which is a base class for
different padding types. For a correct behaviour, all supported subclasses are added as custom objects in
addition to the given ones.
:param kwargs: all custom objects, that should be saved
"""
if "Padding2D" in kwargs.keys(): if "Padding2D" in kwargs.keys():
kwargs.update(kwargs["Padding2D"].allowed_paddings) kwargs.update(kwargs["Padding2D"].allowed_paddings)
self.custom_objects.update(kwargs) self.custom_objects = kwargs
class MyLittleModel(AbstractModelClass): class MyLittleModel(AbstractModelClass):
......
...@@ -5,15 +5,32 @@ from src.model_modules.model_class import AbstractModelClass ...@@ -5,15 +5,32 @@ from src.model_modules.model_class import AbstractModelClass
from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel
class Paddings:
allowed_paddings = {"pad1": 34, "another_pad": True}
class AbstractModelSubClass(AbstractModelClass):
def __init__(self):
super().__init__()
self.test_attr = "testAttr"
class TestAbstractModelClass: class TestAbstractModelClass:
@pytest.fixture @pytest.fixture
def amc(self): def amc(self):
return AbstractModelClass() return AbstractModelClass()
@pytest.fixture
def amsc(self):
return AbstractModelSubClass()
def test_init(self, amc): def test_init(self, amc):
assert amc.model is None assert amc.model is None
assert amc.loss is None assert amc.loss is None
assert amc.model_name == "AbstractModelClass"
assert amc.custom_objects == {}
def test_model_property(self, amc): def test_model_property(self, amc):
amc.model = keras.Model() amc.model = keras.Model()
...@@ -29,6 +46,23 @@ class TestAbstractModelClass: ...@@ -29,6 +46,23 @@ class TestAbstractModelClass:
assert hasattr(amc.model, "compile") is True assert hasattr(amc.model, "compile") is True
assert amc.compile == amc.model.compile assert amc.compile == amc.model.compile
def test_get_settings(self, amc, amsc):
assert amc.get_settings() == {"model_name": "AbstractModelClass"}
assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass"}
def test_custom_objects(self, amc):
amc.custom_objects = {"Test": 123}
assert amc.custom_objects == {"Test": 123}
def test_set_custom_objects(self, amc):
amc.set_custom_objects(Test=22, minor_param="minor")
assert amc.custom_objects == {"Test": 22, "minor_param": "minor"}
amc.set_custom_objects(Test=2, minor_param1="minor1")
assert amc.custom_objects == {"Test": 2, "minor_param1": "minor1"}
paddings = Paddings()
amc.set_custom_objects(Test=1, Padding2D=paddings)
assert amc.custom_objects == {"Test": 1, "Padding2D": paddings, "pad1": 34, "another_pad": True}
class TestMyPaperModel: class TestMyPaperModel:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment