From 27b393f5ec967c6a4dc2c0d36903cdfb58c505e6 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Mon, 11 May 2020 14:24:41 +0200
Subject: [PATCH] fist compile setups

---
 src/model_modules/model_class.py | 57 ++++++++++++++++++++++----------
 src/run_modules/model_setup.py   |  7 ++--
 2 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 290a527b..a240507c 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -34,11 +34,13 @@ class AbstractModelClass(ABC):
         self.__loss = None
         self.model_name = self.__class__.__name__
         self.__custom_objects = {}
-        self.__allowed_compile_options = {'metrics': None,
-                                         'loss_weights': None,
-                                         'sample_weight_mode': None,
-                                         'weighted_metrics': None,
-                                         'target_tensors': None
+        self.__allowed_compile_options = {'optimizer': None,
+                                          'loss': None,
+                                          'metrics': None,
+                                          'loss_weights': None,
+                                          'sample_weight_mode': None,
+                                          'weighted_metrics': None,
+                                          'target_tensors': None
                                           }
         self.__compile_options = self.__allowed_compile_options
 
@@ -113,14 +115,34 @@ class AbstractModelClass(ABC):
 
     @compile_options.setter
     def compile_options(self, value: Dict) -> None:
-        if not isinstance(value, dict):
-            raise TypeError(f"`value' has to be a dictionary. But it is {type(value)}")
-        for new_k, new_v in value.items():
-            if new_k in self.__allowed_compile_options.keys():
-                self.__compile_options[new_k] = new_v
+        for allow_k in self.__allowed_compile_options.keys():
+            if hasattr(self, allow_k):
+                new_v_attr = getattr(self, allow_k)
             else:
-                logging.warning(
-                    f"`{new_k}' is not a valid additional compile option. Will be ignored in keras.compile()")
+                new_v_attr = None
+            if isinstance(value, dict):
+                new_v_dic = value.pop(allow_k, None)
+            elif value is None:
+                new_v_dic = None
+            else:
+                raise TypeError(f'compile_options must be dict or None, but is {type(value)}.')
+            if (new_v_attr == new_v_dic) or ((new_v_attr is None) ^ (new_v_dic is None)):
+                if new_v_attr is not None:
+                    self.__compile_options[allow_k] = new_v_attr
+                else:
+                    self.__compile_options[allow_k] = new_v_dic
+            else:
+                raise SyntaxError(
+                    f"Got different values for same argument: self.{allow_k}={new_v_attr} and '{allow_k}': {new_v_dic}")
+
+        # if not isinstance(value, dict):
+        #     raise TypeError(f"`value' has to be a dictionary. But it is {type(value)}")
+        # for new_k, new_v in value.items():
+        #     if new_k in self.__allowed_compile_options.keys():
+        #         self.__compile_options[new_k] = new_v
+        #     else:
+        #         logging.warning(
+        #             f"`{new_k}' is not a valid additional compile option. Will be ignored in keras.compile()")
 
     def get_settings(self) -> Dict:
         """
@@ -191,9 +213,6 @@ class MyLittleModel(AbstractModelClass):
         self.channels = channels
         self.dropout_rate = 0.1
         self.regularizer = keras.regularizers.l2(0.1)
-        self.initial_lr = 1e-2
-        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
-        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
         self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
@@ -239,10 +258,14 @@ class MyLittleModel(AbstractModelClass):
         :return: loss function
         """
 
-        self.loss = keras.losses.mean_squared_error
+        # self.loss = keras.losses.mean_squared_error
 
     def set_compile_options(self):
-        self.compile_options = {"metrics": ["mse", "mae"]}
+        self.initial_lr = 1e-2
+        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
+        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
+                                                                             epochs_drop=10)
+        self.compile_options = {"loss": keras.losses.mean_squared_error, "metrics": ["mse", "mae"]}
 
 
 class MyBranchedModel(AbstractModelClass):
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index de2d6a57..0563fc88 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -60,10 +60,11 @@ class ModelSetup(RunEnvironment):
         self.data_store.set("channels", channels, self.scope)
 
     def compile_model(self):
-        optimizer = self.data_store.get("optimizer", self.scope)
-        loss = self.model.loss
+        # optimizer = self.data_store.get("optimizer", self.scope)
+        # loss = self.model.loss
         compile_options = self.model.compile_options
-        self.model.compile(optimizer=optimizer, loss=loss, **compile_options)
+        # self.model.compile(optimizer=optimizer, loss=loss, **compile_options)
+        self.model.compile(**compile_options)
         self.data_store.set("model", self.model, self.scope)
 
     def _set_callbacks(self):
-- 
GitLab