From f6ad4736623818096b7dbac759565338bfc3a668 Mon Sep 17 00:00:00 2001 From: Felix Kleinert <f.kleinert@fz-juelich.de> Date: Mon, 20 Dec 2021 09:31:54 +0100 Subject: [PATCH] update model load fkts --- mlair/model_modules/abstract_model_class.py | 8 +++++++- mlair/model_modules/model_class.py | 17 +++++++++++------ mlair/run_modules/model_setup.py | 12 ++++++------ 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py index e7d0437f..4a323f46 100644 --- a/mlair/model_modules/abstract_model_class.py +++ b/mlair/model_modules/abstract_model_class.py @@ -37,7 +37,13 @@ class AbstractModelClass(ABC): self.__compile_options_is_set = False self._input_shape = input_shape self._output_shape = self.__extract_from_tuple(output_shape) - # self.avail_gpus = len(K.tensorflow_backend._get_available_gpus()) + + def load_model(self, name: str, compile: bool = False) -> None: + hist = self.model.history + self.model.load_weights(name) + self.model.history = hist + if compile is True: + self.model.compile(**self.compile_options) def __getattr__(self, name: str) -> Any: """ diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py index 96cfdccf..a3291dab 100644 --- a/mlair/model_modules/model_class.py +++ b/mlair/model_modules/model_class.py @@ -454,12 +454,12 @@ class IntelliO3TsArchitecture(AbstractModelClass): kernel_regularizer=self.regularizer ) - model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main]) - if self.avail_gpus <= 1: - self.model = model - else: - self.model = keras.utils.multi_gpu_model(model, self.avail_gpus) - print(f"Set multi_gpu model with {self.avail_gpus} GPUs") + self.model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main]) + # if self.avail_gpus <= 1: + # self.model = model + # else: + # self.model = keras.utils.multi_gpu_model(model, self.avail_gpus) + # print(f"Set multi_gpu model with {self.avail_gpus} GPUs") def set_compile_options(self): self.compile_options = {"optimizer": keras.optimizers.Adam(lr=self.initial_lr, amsgrad=True), @@ -762,6 +762,11 @@ class MyUnet(AbstractModelClass): self.compile_options = {"metrics": ["mse", "mae"]} +class NN3s(MyUnet): + def __init__(self, input_shape: list, output_shape: list): + super().__init__(input_shape, output_shape) + + class MySimpleConv2D(AbstractModelClass): """ Example adopted from https://www.kaggle.com/dimitreoliveira/deep-learning-for-time-series-forecasting diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index 0b9e8ec5..98263eb7 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -84,7 +84,7 @@ class ModelSetup(RunEnvironment): # load weights if no training shall be performed if not self._train_model and not self._create_new_model: - self.load_weights() + self.load_model() # create checkpoint self._set_callbacks() @@ -131,13 +131,13 @@ class ModelSetup(RunEnvironment): save_best_only=True, mode='auto') self.data_store.set("callbacks", callbacks, self.scope) - def load_weights(self): - """Try to load weights from existing model or skip if not possible.""" + def load_model(self): + """Try to load model from disk or skip if not possible.""" try: - self.model.load_weights(self.model_name) - logging.info(f"reload weights from model {self.model_name} ...") + self.model.load_model(self.model_name) + logging.info(f"reload model {self.model_name} from disk ...") except OSError: - logging.info('no weights to reload...') + logging.info('no local model to load...') def build_model(self): """Build model using input and output shapes from data store.""" -- GitLab