diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 290a527b0ae2ccc9e17ddf5ed49098a4a55a173b..a240507ce962230e3818d746c2a682cf82364b95 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -34,11 +34,13 @@ class AbstractModelClass(ABC):
         self.__loss = None
         self.model_name = self.__class__.__name__
         self.__custom_objects = {}
-        self.__allowed_compile_options = {'metrics': None,
-                                          'loss_weights': None,
-                                          'sample_weight_mode': None,
-                                          'weighted_metrics': None,
-                                          'target_tensors': None
+        self.__allowed_compile_options = {'optimizer': None,
+                                          'loss': None,
+                                          'metrics': None,
+                                          'loss_weights': None,
+                                          'sample_weight_mode': None,
+                                          'weighted_metrics': None,
+                                          'target_tensors': None
                                           }
         self.__compile_options = self.__allowed_compile_options
 
@@ -113,14 +115,34 @@ class AbstractModelClass(ABC):
 
     @compile_options.setter
     def compile_options(self, value: Dict) -> None:
-        if not isinstance(value, dict):
-            raise TypeError(f"`value' has to be a dictionary. But it is {type(value)}")
-        for new_k, new_v in value.items():
-            if new_k in self.__allowed_compile_options.keys():
-                self.__compile_options[new_k] = new_v
+        for allow_k in self.__allowed_compile_options.keys():
+            if hasattr(self, allow_k):
+                new_v_attr = getattr(self, allow_k)
             else:
-                logging.warning(
-                    f"`{new_k}' is not a valid additional compile option. Will be ignored in keras.compile()")
+                new_v_attr = None
+            if isinstance(value, dict):
+                new_v_dic = value.pop(allow_k, None)
+            elif value is None:
+                new_v_dic = None
+            else:
+                raise TypeError(f'compile_options must be dict or None, but is {type(value)}.')
+            if (new_v_attr == new_v_dic) or ((new_v_attr is None) ^ (new_v_dic is None)):
+                if new_v_attr is not None:
+                    self.__compile_options[allow_k] = new_v_attr
+                else:
+                    self.__compile_options[allow_k] = new_v_dic
+            else:
+                raise SyntaxError(
+                    f"Got different values for same argument: self.{allow_k}={new_v_attr} and '{allow_k}': {new_v_dic}")
+
+        # if not isinstance(value, dict):
+        #     raise TypeError(f"`value' has to be a dictionary. But it is {type(value)}")
+        # for new_k, new_v in value.items():
+        #     if new_k in self.__allowed_compile_options.keys():
+        #         self.__compile_options[new_k] = new_v
+        #     else:
+        #         logging.warning(
+        #             f"`{new_k}' is not a valid additional compile option. Will be ignored in keras.compile()")
 
     def get_settings(self) -> Dict:
         """
@@ -191,9 +213,6 @@ class MyLittleModel(AbstractModelClass):
         self.channels = channels
         self.dropout_rate = 0.1
         self.regularizer = keras.regularizers.l2(0.1)
-        self.initial_lr = 1e-2
-        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
-        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
         self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
@@ -239,10 +258,14 @@ class MyLittleModel(AbstractModelClass):
 
         :return: loss function
         """
-        self.loss = keras.losses.mean_squared_error
+        # self.loss = keras.losses.mean_squared_error
 
     def set_compile_options(self):
-        self.compile_options = {"metrics": ["mse", "mae"]}
+        self.initial_lr = 1e-2
+        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
+        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
+                                                                             epochs_drop=10)
+        self.compile_options = {"loss": keras.losses.mean_squared_error, "metrics": ["mse", "mae"]}
 
 
 class MyBranchedModel(AbstractModelClass):
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index de2d6a576662702128ae5e486a072904b3c3bf73..0563fc88948c7281f389ec300b294116abd7a491 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -60,10 +60,11 @@ class ModelSetup(RunEnvironment):
         self.data_store.set("channels", channels, self.scope)
 
     def compile_model(self):
-        optimizer = self.data_store.get("optimizer", self.scope)
-        loss = self.model.loss
+        # optimizer = self.data_store.get("optimizer", self.scope)
+        # loss = self.model.loss
         compile_options = self.model.compile_options
-        self.model.compile(optimizer=optimizer, loss=loss, **compile_options)
+        # self.model.compile(optimizer=optimizer, loss=loss, **compile_options)
+        self.model.compile(**compile_options)
         self.data_store.set("model", self.model, self.scope)
 
     def _set_callbacks(self):
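
Note on the new compile_options semantics (a minimal usage sketch, not part of the patch; `MyModel` is a hypothetical subclass, assuming the attribute/dict merge behavior implemented in the setter above): a subclass may now declare each compile argument either as an instance attribute or as a key in the compile_options dict. The setter merges both sources and rejects conflicting values.

    # Hypothetical subclass for illustration only; assumes the merge semantics
    # of the compile_options setter introduced in this patch.
    class MyModel(AbstractModelClass):

        def set_compile_options(self):
            # Attribute route: picked up via hasattr/getattr in the setter.
            self.optimizer = keras.optimizers.SGD(lr=1e-2, momentum=0.9)
            # Dict route: remaining options passed as a dict.
            self.compile_options = {"loss": keras.losses.mean_squared_error,
                                    "metrics": ["mse", "mae"]}

    # After set_compile_options() runs, model.compile_options holds optimizer,
    # loss and metrics together, which is why ModelSetup.compile_model() can
    # now call self.model.compile(**compile_options) without fetching the
    # optimizer from the data store.
    # Setting e.g. self.loss AND compile_options["loss"] to different values
    # raises SyntaxError("Got different values for same argument: ...").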