diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py index e7d0437f2ff62b635146047496f09db9e7fcdd5c..4a323f46ff95a7ca66c157f2e4d6d3184f244a4a 100644 --- a/mlair/model_modules/abstract_model_class.py +++ b/mlair/model_modules/abstract_model_class.py @@ -37,7 +37,13 @@ class AbstractModelClass(ABC): self.__compile_options_is_set = False self._input_shape = input_shape self._output_shape = self.__extract_from_tuple(output_shape) - # self.avail_gpus = len(K.tensorflow_backend._get_available_gpus()) + + def load_model(self, name: str, compile: bool = False) -> None: + hist = self.model.history + self.model.load_weights(name) + self.model.history = hist + if compile is True: + self.model.compile(**self.compile_options) def __getattr__(self, name: str) -> Any: """ diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py index 96cfdccf9e29f0b4cf780432f2a14a4406a60580..a3291dabd5b1684e64aa5beec88aba2caeadde7c 100644 --- a/mlair/model_modules/model_class.py +++ b/mlair/model_modules/model_class.py @@ -454,12 +454,12 @@ class IntelliO3TsArchitecture(AbstractModelClass): kernel_regularizer=self.regularizer ) - model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main]) - if self.avail_gpus <= 1: - self.model = model - else: - self.model = keras.utils.multi_gpu_model(model, self.avail_gpus) - print(f"Set multi_gpu model with {self.avail_gpus} GPUs") + self.model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main]) + # if self.avail_gpus <= 1: + # self.model = model + # else: + # self.model = keras.utils.multi_gpu_model(model, self.avail_gpus) + # print(f"Set multi_gpu model with {self.avail_gpus} GPUs") def set_compile_options(self): self.compile_options = {"optimizer": keras.optimizers.Adam(lr=self.initial_lr, amsgrad=True), @@ -762,6 +762,11 @@ class MyUnet(AbstractModelClass): self.compile_options = {"metrics": ["mse", "mae"]} +class NN3s(MyUnet): + def __init__(self, input_shape: list, output_shape: list): + super().__init__(input_shape, output_shape) + + class MySimpleConv2D(AbstractModelClass): """ Example adopted from https://www.kaggle.com/dimitreoliveira/deep-learning-for-time-series-forecasting diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index 0b9e8ec56592901d9feba15eb50b6b21a0c53560..98263eb732d8067fba0950c7a4882fb3ef020995 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -84,7 +84,7 @@ class ModelSetup(RunEnvironment): # load weights if no training shall be performed if not self._train_model and not self._create_new_model: - self.load_weights() + self.load_model() # create checkpoint self._set_callbacks() @@ -131,13 +131,13 @@ class ModelSetup(RunEnvironment): save_best_only=True, mode='auto') self.data_store.set("callbacks", callbacks, self.scope) - def load_weights(self): - """Try to load weights from existing model or skip if not possible.""" + def load_model(self): + """Try to load model from disk or skip if not possible.""" try: - self.model.load_weights(self.model_name) - logging.info(f"reload weights from model {self.model_name} ...") + self.model.load_model(self.model_name) + logging.info(f"reload model {self.model_name} from disk ...") except OSError: - logging.info('no weights to reload...') + logging.info('no local model to load...') def build_model(self): """Build model using input and output shapes from data store."""