From f6ad4736623818096b7dbac759565338bfc3a668 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Mon, 20 Dec 2021 09:31:54 +0100
Subject: [PATCH] update model load fkts

---
 mlair/model_modules/abstract_model_class.py |  8 +++++++-
 mlair/model_modules/model_class.py          | 17 +++++++++++------
 mlair/run_modules/model_setup.py            | 12 ++++++------
 3 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py
index e7d0437f..4a323f46 100644
--- a/mlair/model_modules/abstract_model_class.py
+++ b/mlair/model_modules/abstract_model_class.py
@@ -37,7 +37,13 @@ class AbstractModelClass(ABC):
         self.__compile_options_is_set = False
         self._input_shape = input_shape
         self._output_shape = self.__extract_from_tuple(output_shape)
-        # self.avail_gpus = len(K.tensorflow_backend._get_available_gpus())
+
+    def load_model(self, name: str, compile: bool = False) -> None:
+        hist = self.model.history
+        self.model.load_weights(name)
+        self.model.history = hist
+        if compile is True:
+            self.model.compile(**self.compile_options)
 
     def __getattr__(self, name: str) -> Any:
         """
diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py
index 96cfdccf..a3291dab 100644
--- a/mlair/model_modules/model_class.py
+++ b/mlair/model_modules/model_class.py
@@ -454,12 +454,12 @@ class IntelliO3TsArchitecture(AbstractModelClass):
                                 kernel_regularizer=self.regularizer
                                 )
 
-        model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main])
-        if self.avail_gpus <= 1:
-            self.model = model
-        else:
-            self.model = keras.utils.multi_gpu_model(model, self.avail_gpus)
-            print(f"Set multi_gpu model with {self.avail_gpus} GPUs")
+        self.model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main])
+        # if self.avail_gpus <= 1:
+        #     self.model = model
+        # else:
+        #     self.model = keras.utils.multi_gpu_model(model, self.avail_gpus)
+        #     print(f"Set multi_gpu model with {self.avail_gpus} GPUs")
 
     def set_compile_options(self):
         self.compile_options = {"optimizer": keras.optimizers.Adam(lr=self.initial_lr, amsgrad=True),
@@ -762,6 +762,11 @@ class MyUnet(AbstractModelClass):
         self.compile_options = {"metrics": ["mse", "mae"]}
 
 
+class NN3s(MyUnet):
+    def __init__(self, input_shape: list, output_shape: list):
+        super().__init__(input_shape, output_shape)
+
+
 class MySimpleConv2D(AbstractModelClass):
     """
     Example adopted from https://www.kaggle.com/dimitreoliveira/deep-learning-for-time-series-forecasting
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 0b9e8ec5..98263eb7 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -84,7 +84,7 @@ class ModelSetup(RunEnvironment):
 
         # load weights if no training shall be performed
         if not self._train_model and not self._create_new_model:
-            self.load_weights()
+            self.load_model()
 
         # create checkpoint
         self._set_callbacks()
@@ -131,13 +131,13 @@ class ModelSetup(RunEnvironment):
                                           save_best_only=True, mode='auto')
         self.data_store.set("callbacks", callbacks, self.scope)
 
-    def load_weights(self):
-        """Try to load weights from existing model or skip if not possible."""
+    def load_model(self):
+        """Try to load model from disk or skip if not possible."""
         try:
-            self.model.load_weights(self.model_name)
-            logging.info(f"reload weights from model {self.model_name} ...")
+            self.model.load_model(self.model_name)
+            logging.info(f"reload model {self.model_name} from disk ...")
         except OSError:
-            logging.info('no weights to reload...')
+            logging.info('no local model to load...')
 
     def build_model(self):
         """Build model using input and output shapes from data store."""
-- 
GitLab