diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index 85e5b6ecadb28428e1e78a00ce2da4dc690a3604..b46213e591798861fea4f0da13c9bab824200b4b 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -92,7 +92,7 @@ class AbstractModelClass(ABC): dictionary (1), as attribute, with compile_options=None (2) or as mixture of both of them (3). The method will raise an Error when the same parameter is set differently. - Example (1) + Example (1) Recommended (includes check for valid keywords which are used as args in keras.compile) .. code-block:: python def set_compile_options(self): self.compile_options = {"optimizer": keras.optimizers.SGD(), @@ -122,8 +122,10 @@ class AbstractModelClass(ABC): self.loss = keras.losses.mean_squared_error self.compile_options = {"optimizer" = keras.optimizers.Adam(), "metrics": ["mse", "mae"]} - Note: As long as the attribute and the dict value have exactly the same values, the setter method will not raise + Note: + * As long as the attribute and the dict value have exactly the same values, the setter method will not raise an error + * For example (2) there is no check implemented, if the attributes are valid compile options :return: @@ -132,6 +134,10 @@ class AbstractModelClass(ABC): @compile_options.setter def compile_options(self, value: Dict) -> None: + if isinstance(value, dict): + if not (set(value.keys()) <= set(self.__allowed_compile_options.keys())): + raise ValueError(f"Got invalid key for compile_options. {value.keys()}") + for allow_k in self.__allowed_compile_options.keys(): if hasattr(self, allow_k): new_v_attr = getattr(self, allow_k) diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py index a5dbb35ee213a8b4b1f538f6fa1d7e8dcc688dea..a8df3fe7213eef476b2ef7dbeac29d84b698f05a 100644 --- a/test/test_model_modules/test_model_class.py +++ b/test/test_model_modules/test_model_class.py @@ -153,6 +153,11 @@ class TestAbstractModelClass: assert "Got different values or arguments for same argument: self.optimizer=<class" \ " 'keras.optimizers.SGD'> and 'optimizer': <class 'keras.optimizers.SGD'>" in str(einfo.value) + def test_compile_options_setter_as_dict_invalid_keys(self, amc): + with pytest.raises(ValueError) as einfo: + amc.compile_options = {"optimizer": keras.optimizers.SGD(), "InvalidKeyword": [1, 2, 3]} + assert "Got invalid key for compile_options. dict_keys(['optimizer', 'InvalidKeyword'])" in str(einfo.value) + def test_compare_keras_optimizers_equal(self, amc): assert amc._AbstractModelClass__compare_keras_optimizers(keras.optimizers.SGD(), keras.optimizers.SGD()) is True