Skip to content
Snippets Groups Projects
Commit 01d0fca0 authored by lukas leufen's avatar lukas leufen
Browse files

include docs and tests

parents 9a21d049 ffe2aa52
No related branches found
No related tags found
2 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu"
Pipeline #31861 passed
......@@ -82,13 +82,22 @@ class AbstractModelClass(ABC):
@property
def custom_objects(self) -> Dict:
"""
The custom objects property collects all non-keras utilities that are used in the model class. To load such a
customised and already compiled model (e.g. from local disk), this information is required.
:return: the custom objects in a dictionary
"""
return self.__custom_objects
@custom_objects.setter
def custom_objects(self, value) -> None:
self.__custom_objects = value
def get_settings(self):
def get_settings(self) -> Dict:
"""
Get all class attributes that are not protected in the AbstractModelClass as dictionary.
:return: all class attributes
"""
return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__"))
def set_model(self):
......@@ -97,10 +106,17 @@ class AbstractModelClass(ABC):
def set_loss(self):
pass
def set_custom_objects(self, **kwargs):
def set_custom_objects(self, **kwargs) -> None:
"""
Set custom objects that are not part of keras framework. These custom objects are needed if an already compiled
model is loaded from disk. There is a special treatment for the Padding2D class, which is a base class for
different padding types. For a correct behaviour, all supported subclasses are added as custom objects in
addition to the given ones.
:param kwargs: all custom objects, that should be saved
"""
if "Padding2D" in kwargs.keys():
kwargs.update(kwargs["Padding2D"].allowed_paddings)
self.custom_objects.update(kwargs)
self.custom_objects = kwargs
class MyLittleModel(AbstractModelClass):
......
......@@ -5,15 +5,32 @@ from src.model_modules.model_class import AbstractModelClass
from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel
class Paddings:
allowed_paddings = {"pad1": 34, "another_pad": True}
class AbstractModelSubClass(AbstractModelClass):
def __init__(self):
super().__init__()
self.test_attr = "testAttr"
class TestAbstractModelClass:
@pytest.fixture
def amc(self):
return AbstractModelClass()
@pytest.fixture
def amsc(self):
return AbstractModelSubClass()
def test_init(self, amc):
assert amc.model is None
assert amc.loss is None
assert amc.model_name == "AbstractModelClass"
assert amc.custom_objects == {}
def test_model_property(self, amc):
amc.model = keras.Model()
......@@ -29,6 +46,23 @@ class TestAbstractModelClass:
assert hasattr(amc.model, "compile") is True
assert amc.compile == amc.model.compile
def test_get_settings(self, amc, amsc):
assert amc.get_settings() == {"model_name": "AbstractModelClass"}
assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass"}
def test_custom_objects(self, amc):
amc.custom_objects = {"Test": 123}
assert amc.custom_objects == {"Test": 123}
def test_set_custom_objects(self, amc):
amc.set_custom_objects(Test=22, minor_param="minor")
assert amc.custom_objects == {"Test": 22, "minor_param": "minor"}
amc.set_custom_objects(Test=2, minor_param1="minor1")
assert amc.custom_objects == {"Test": 2, "minor_param1": "minor1"}
paddings = Paddings()
amc.set_custom_objects(Test=1, Padding2D=paddings)
assert amc.custom_objects == {"Test": 1, "Padding2D": paddings, "pad1": 34, "another_pad": True}
class TestMyPaperModel:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment