diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index e1b823fbf750d5381e06163aae262c39eb434c82..c12664e356d30bfc23333c3888d750bf245f4564 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -299,7 +299,8 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data"))
 
         # set model path
-        self._set_param("model_path", model_path, default=os.path.join(experiment_path, "model"))
+        self._set_param("model_load_path", model_path)
+        self._set_param("model_path", None, default=os.path.join(experiment_path, "model"))
         path_config.check_path_and_create(self.data_store.get("model_path"))
 
         # set plot path
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 819b6e0e53636140f1490314e82ae83e65aa9da0..3164db789e79a92364b300b5d8724861031afeb8 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -6,6 +6,8 @@ __date__ = '2019-12-02'
 import logging
 import os
 import re
+import shutil
+
 from dill.source import getsource
 
 import tensorflow.keras as keras
@@ -58,16 +60,16 @@ class ModelSetup(RunEnvironment):
         """Initialise and run model setup."""
         super().__init__()
         self.model = None
-        exp_name = self.data_store.get("experiment_name")
+        self.scope = "model"
+        self._train_model = self.data_store.get("train_model")
+        self._create_new_model = self.data_store.get("create_new_model")
         self.path = self.data_store.get("model_path")
         self.model_display_name = self.data_store.get_default("model_display_name", default=None)
-        self.scope = "model"
-        path = os.path.join(self.path, f"{self.model_display_name or exp_name}%s")
+        self.model_load_path = None
+        path = self._set_model_path()
         self.model_path = path % "%s.h5"
         self.checkpoint_name = path % "_model-best.h5"
         self.callbacks_name = path % "_model-best-callbacks-%s.pickle"
-        self._train_model = self.data_store.get("train_model")
-        self._create_new_model = self.data_store.get("create_new_model")
         self._run()
 
     def _run(self):
@@ -97,6 +99,30 @@ class ModelSetup(RunEnvironment):
         # report settings
         self.report_model()
 
+    def _set_model_path(self):
+        """
+        Check an optional external model file and return the internal model path template.
+
+        If `model_load_path` is set in the data store, it has to end with `.h5` and neither `train_model` nor
+        `create_new_model` may be enabled.
+
+        :return: path inside `self.path` named after the display or experiment name, ending with a `%s` placeholder
+        :raises FileNotFoundError: if the external path does not end with `.h5`
+        :raises ValueError: if an external model is given together with `train_model` or `create_new_model`
+        """
+        exp_name = self.data_store.get("experiment_name")
+        self.model_load_path = self.data_store.get_default("model_load_path", default=None)
+        if self.model_load_path is not None:
+            if not self.model_load_path.endswith(".h5"):
+                raise FileNotFoundError(f"When providing external models, you need to provide full path including the "
+                                        f".h5 file. Given path is not valid: {self.model_load_path}")
+            if any([self._train_model, self._create_new_model]):
+                raise ValueError(f"Providing `model_path` along with parameters train_model={self._train_model} and "
+                                 f"create_new_model={self._create_new_model} is not possible. Either set both "
+                                 f"parameters to False or remove `model_path` parameter. Given was: model_path = "
+                                 f"{self.model_load_path}")
+        return os.path.join(self.path, f"{self.model_display_name or exp_name}%s")
+
     def _set_shapes(self):
         """Set input and output shapes from train collection."""
         shape = list(map(lambda x: x.shape[1:], self.data_store.get("data_collection", "train")[0].get_X()))
@@ -160,8 +186,15 @@ class ModelSetup(RunEnvironment):
         # store callbacks
         self.data_store.set("callbacks", callbacks, self.scope)
 
+    def copy_model(self):
+        """Copy external model to internal experiment structure."""
+        if self.model_load_path is not None:
+            logging.info(f"Copy external model file: {self.model_load_path} -> {self.model_path}")
+            shutil.copyfile(self.model_load_path, self.model_path)
+
     def load_model(self):
         """Try to load model from disk or skip if not possible."""
+        self.copy_model()
         try:
             self.model.load_model(self.model_path)
             logging.info(f"reload model {self.model_path} from disk ...")