From 33e294b06d667804ac2237b1317690fbdc771b27 Mon Sep 17 00:00:00 2001 From: Felix Kleinert <f.kleinert@fz-juelich.de> Date: Tue, 12 May 2020 14:44:47 +0200 Subject: [PATCH] check if all dict keys passed to compile_options are valid arguments for kears.compile --- src/model_modules/model_class.py | 10 ++++++++-- test/test_model_modules/test_model_class.py | 5 +++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index 85e5b6ec..b46213e5 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -92,7 +92,7 @@ class AbstractModelClass(ABC): dictionary (1), as attribute, with compile_options=None (2) or as mixture of both of them (3). The method will raise an Error when the same parameter is set differently. - Example (1) + Example (1) Recommended (includes check for valid keywords which are used as args in keras.compile) .. code-block:: python def set_compile_options(self): self.compile_options = {"optimizer": keras.optimizers.SGD(), @@ -122,8 +122,10 @@ class AbstractModelClass(ABC): self.loss = keras.losses.mean_squared_error self.compile_options = {"optimizer" = keras.optimizers.Adam(), "metrics": ["mse", "mae"]} - Note: As long as the attribute and the dict value have exactly the same values, the setter method will not raise + Note: + * As long as the attribute and the dict value have exactly the same values, the setter method will not raise an error + * For example (2) there is no check implemented, if the attributes are valid compile options :return: @@ -132,6 +134,10 @@ class AbstractModelClass(ABC): @compile_options.setter def compile_options(self, value: Dict) -> None: + if isinstance(value, dict): + if not (set(value.keys()) <= set(self.__allowed_compile_options.keys())): + raise ValueError(f"Got invalid key for compile_options. {value.keys()}") + for allow_k in self.__allowed_compile_options.keys(): if hasattr(self, allow_k): new_v_attr = getattr(self, allow_k) diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py index a5dbb35e..a8df3fe7 100644 --- a/test/test_model_modules/test_model_class.py +++ b/test/test_model_modules/test_model_class.py @@ -153,6 +153,11 @@ class TestAbstractModelClass: assert "Got different values or arguments for same argument: self.optimizer=<class" \ " 'keras.optimizers.SGD'> and 'optimizer': <class 'keras.optimizers.SGD'>" in str(einfo.value) + def test_compile_options_setter_as_dict_invalid_keys(self, amc): + with pytest.raises(ValueError) as einfo: + amc.compile_options = {"optimizer": keras.optimizers.SGD(), "InvalidKeyword": [1, 2, 3]} + assert "Got invalid key for compile_options. dict_keys(['optimizer', 'InvalidKeyword'])" in str(einfo.value) + def test_compare_keras_optimizers_equal(self, amc): assert amc._AbstractModelClass__compare_keras_optimizers(keras.optimizers.SGD(), keras.optimizers.SGD()) is True -- GitLab