From 27b393f5ec967c6a4dc2c0d36903cdfb58c505e6 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Mon, 11 May 2020 14:24:41 +0200
Subject: [PATCH] first compile setups

---
 src/model_modules/model_class.py | 57 ++++++++++++++++++++++----------
 src/run_modules/model_setup.py   |  7 ++--
 2 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 290a527b..a240507c 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -34,11 +34,13 @@ class AbstractModelClass(ABC):
         self.__loss = None
         self.model_name = self.__class__.__name__
         self.__custom_objects = {}
-        self.__allowed_compile_options = {'metrics': None,
-                                          'loss_weights': None,
-                                          'sample_weight_mode': None,
-                                          'weighted_metrics': None,
-                                          'target_tensors': None
+        self.__allowed_compile_options = {'optimizer': None,
+                                          'loss': None,
+                                          'metrics': None,
+                                          'loss_weights': None,
+                                          'sample_weight_mode': None,
+                                          'weighted_metrics': None,
+                                          'target_tensors': None
                                           }
         self.__compile_options = self.__allowed_compile_options
 
@@ -113,14 +115,34 @@ class AbstractModelClass(ABC):
 
     @compile_options.setter
     def compile_options(self, value: Dict) -> None:
-        if not isinstance(value, dict):
-            raise TypeError(f"`value' has to be a dictionary. But it is {type(value)}")
-        for new_k, new_v in value.items():
-            if new_k in self.__allowed_compile_options.keys():
-                self.__compile_options[new_k] = new_v
+        for allow_k in self.__allowed_compile_options.keys():
+            if hasattr(self, allow_k):
+                new_v_attr = getattr(self, allow_k)
             else:
-                logging.warning(
-                    f"`{new_k}' is not a valid additional compile option. Will be ignored in keras.compile()")
+                new_v_attr = None
+            if isinstance(value, dict):
+                new_v_dic = value.pop(allow_k, None)
+            elif value is None:
+                new_v_dic = None
+            else:
+                raise TypeError(f'compile_options must be dict or None, but is {type(value)}.')
+            if (new_v_attr == new_v_dic) or ((new_v_attr is None) ^ (new_v_dic is None)):
+                if new_v_attr is not None:
+                    self.__compile_options[allow_k] = new_v_attr
+                else:
+                    self.__compile_options[allow_k] = new_v_dic
+            else:
+                raise SyntaxError(
+                    f"Got different values for same argument: self.{allow_k}={new_v_attr} and '{allow_k}': {new_v_dic}")
+
+        # if not isinstance(value, dict):
+        #     raise TypeError(f"`value' has to be a dictionary. But it is {type(value)}")
+        # for new_k, new_v in value.items():
+        #     if new_k in self.__allowed_compile_options.keys():
+        #         self.__compile_options[new_k] = new_v
+        #     else:
+        #         logging.warning(
+        #             f"`{new_k}' is not a valid additional compile option. Will be ignored in keras.compile()")
 
     def get_settings(self) -> Dict:
         """
@@ -191,9 +213,6 @@ class MyLittleModel(AbstractModelClass):
         self.channels = channels
         self.dropout_rate = 0.1
         self.regularizer = keras.regularizers.l2(0.1)
-        self.initial_lr = 1e-2
-        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
-        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
         self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
@@ -239,10 +258,14 @@ class MyLittleModel(AbstractModelClass):
 
         :return: loss function
         """
-        self.loss = keras.losses.mean_squared_error
+        # self.loss = keras.losses.mean_squared_error
 
     def set_compile_options(self):
-        self.compile_options = {"metrics": ["mse", "mae"]}
+        self.initial_lr = 1e-2
+        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
+        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
+                                                                             epochs_drop=10)
+        self.compile_options = {"loss": keras.losses.mean_squared_error, "metrics": ["mse", "mae"]}
 
 
 class MyBranchedModel(AbstractModelClass):
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index de2d6a57..0563fc88 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -60,10 +60,11 @@ class ModelSetup(RunEnvironment):
         self.data_store.set("channels", channels, self.scope)
 
     def compile_model(self):
-        optimizer = self.data_store.get("optimizer", self.scope)
-        loss = self.model.loss
+        # optimizer = self.data_store.get("optimizer", self.scope)
+        # loss = self.model.loss
         compile_options = self.model.compile_options
-        self.model.compile(optimizer=optimizer, loss=loss, **compile_options)
+        # self.model.compile(optimizer=optimizer, loss=loss, **compile_options)
+        self.model.compile(**compile_options)
        self.data_store.set("model", self.model, self.scope)
 
     def _set_callbacks(self):
-- 
GitLab