Skip to content
Snippets Groups Projects
Commit 01d0fca0 authored by lukas leufen's avatar lukas leufen
Browse files

include docs and tests

parents 9a21d049 ffe2aa52
Branches
Tags
2 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu"
Pipeline #31861 passed
...@@ -82,13 +82,22 @@ class AbstractModelClass(ABC): ...@@ -82,13 +82,22 @@ class AbstractModelClass(ABC):
@property @property
def custom_objects(self) -> Dict: def custom_objects(self) -> Dict:
"""
The custom objects property collects all non-keras utilities that are used in the model class. To load such a
customised and already compiled model (e.g. from local disk), this information is required.
:return: the custom objects in a dictionary
"""
return self.__custom_objects return self.__custom_objects
@custom_objects.setter @custom_objects.setter
def custom_objects(self, value) -> None: def custom_objects(self, value) -> None:
self.__custom_objects = value self.__custom_objects = value
def get_settings(self): def get_settings(self) -> Dict:
"""
Get all class attributes that are not protected in the AbstractModelClass as dictionary.
:return: all class attributes
"""
return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__")) return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__"))
def set_model(self): def set_model(self):
...@@ -97,10 +106,17 @@ class AbstractModelClass(ABC): ...@@ -97,10 +106,17 @@ class AbstractModelClass(ABC):
def set_loss(self): def set_loss(self):
pass pass
def set_custom_objects(self, **kwargs): def set_custom_objects(self, **kwargs) -> None:
"""
Set custom objects that are not part of keras framework. These custom objects are needed if an already compiled
model is loaded from disk. There is a special treatment for the Padding2D class, which is a base class for
different padding types. For a correct behaviour, all supported subclasses are added as custom objects in
addition to the given ones.
:param kwargs: all custom objects, that should be saved
"""
if "Padding2D" in kwargs.keys(): if "Padding2D" in kwargs.keys():
kwargs.update(kwargs["Padding2D"].allowed_paddings) kwargs.update(kwargs["Padding2D"].allowed_paddings)
self.custom_objects.update(kwargs) self.custom_objects = kwargs
class MyLittleModel(AbstractModelClass): class MyLittleModel(AbstractModelClass):
......
...@@ -5,15 +5,32 @@ from src.model_modules.model_class import AbstractModelClass ...@@ -5,15 +5,32 @@ from src.model_modules.model_class import AbstractModelClass
from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel
class Paddings:
allowed_paddings = {"pad1": 34, "another_pad": True}
class AbstractModelSubClass(AbstractModelClass):
def __init__(self):
super().__init__()
self.test_attr = "testAttr"
class TestAbstractModelClass: class TestAbstractModelClass:
@pytest.fixture @pytest.fixture
def amc(self): def amc(self):
return AbstractModelClass() return AbstractModelClass()
@pytest.fixture
def amsc(self):
return AbstractModelSubClass()
def test_init(self, amc): def test_init(self, amc):
assert amc.model is None assert amc.model is None
assert amc.loss is None assert amc.loss is None
assert amc.model_name == "AbstractModelClass"
assert amc.custom_objects == {}
def test_model_property(self, amc): def test_model_property(self, amc):
amc.model = keras.Model() amc.model = keras.Model()
...@@ -29,6 +46,23 @@ class TestAbstractModelClass: ...@@ -29,6 +46,23 @@ class TestAbstractModelClass:
assert hasattr(amc.model, "compile") is True assert hasattr(amc.model, "compile") is True
assert amc.compile == amc.model.compile assert amc.compile == amc.model.compile
def test_get_settings(self, amc, amsc):
assert amc.get_settings() == {"model_name": "AbstractModelClass"}
assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass"}
def test_custom_objects(self, amc):
amc.custom_objects = {"Test": 123}
assert amc.custom_objects == {"Test": 123}
def test_set_custom_objects(self, amc):
amc.set_custom_objects(Test=22, minor_param="minor")
assert amc.custom_objects == {"Test": 22, "minor_param": "minor"}
amc.set_custom_objects(Test=2, minor_param1="minor1")
assert amc.custom_objects == {"Test": 2, "minor_param1": "minor1"}
paddings = Paddings()
amc.set_custom_objects(Test=1, Padding2D=paddings)
assert amc.custom_objects == {"Test": 1, "Padding2D": paddings, "pad1": 34, "another_pad": True}
class TestMyPaperModel: class TestMyPaperModel:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment