diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 8bbcfddf741c657af0d0e26f7c3937058e83f821..c12664e356d30bfc23333c3888d750bf245f4564 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -244,7 +244,7 @@ class ExperimentSetup(RunEnvironment): do_uncertainty_estimate: bool = None, do_bias_free_evaluation: bool = None, model_display_name: str = None, transformation_file: str = None, calculate_fresh_transformation: bool = None, snapshot_load_path: str = None, - create_snapshot: bool = None, snapshot_path: str = None, **kwargs): + create_snapshot: bool = None, snapshot_path: str = None, model_path: str = None, **kwargs): # create run framework super().__init__() @@ -299,6 +299,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data")) # set model path + self._set_param("model_load_path", model_path) self._set_param("model_path", None, os.path.join(experiment_path, "model")) path_config.check_path_and_create(self.data_store.get("model_path")) diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index efeff06260d1df41d4a83592c225f316574e77eb..3164db789e79a92364b300b5d8724861031afeb8 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -6,6 +6,8 @@ __date__ = '2019-12-02' import logging import os import re +import shutil + from dill.source import getsource import tensorflow.keras as keras @@ -58,15 +60,16 @@ class ModelSetup(RunEnvironment): """Initialise and run model setup.""" super().__init__() self.model = None - exp_name = self.data_store.get("experiment_name") - self.path = self.data_store.get("model_path") self.scope = "model" - path = os.path.join(self.path, f"{exp_name}_%s") - self.model_name = path % "%s.h5" - self.checkpoint_name = path % "model-best.h5" - self.callbacks_name = path % "model-best-callbacks-%s.pickle" self._train_model = self.data_store.get("train_model") self._create_new_model = self.data_store.get("create_new_model") + self.path = self.data_store.get("model_path") + self.model_display_name = self.data_store.get_default("model_display_name", default=None) + self.model_load_path = None + path = self._set_model_path() + self.model_path = path % "%s.h5" + self.checkpoint_name = path % "_model-best.h5" + self.callbacks_name = path % "_model-best-callbacks-%s.pickle" self._run() def _run(self): @@ -96,6 +99,20 @@ class ModelSetup(RunEnvironment): # report settings self.report_model() + def _set_model_path(self): + exp_name = self.data_store.get("experiment_name") + self.model_load_path = self.data_store.get_default("model_load_path", default=None) + if self.model_load_path is not None: + if not self.model_load_path.endswith(".h5"): + raise FileNotFoundError(f"When providing external models, you need to provide full path including the " + f".h5 file. Given path is not valid: {self.model_load_path}") + if any([self._train_model, self._create_new_model]): + raise ValueError(f"Providing `model_path` along with parameters train_model={self._train_model} and " + f"create_new_model={self._create_new_model} is not possible. Either set both " + f"parameters to False or remove `model_path` parameter. Given was: model_path = " + f"{self.model_load_path}") + return os.path.join(self.path, f"{self.model_display_name or exp_name}%s") + def _set_shapes(self): """Set input and output shapes from train collection.""" shape = list(map(lambda x: x.shape[1:], self.data_store.get("data_collection", "train")[0].get_X())) @@ -159,11 +176,18 @@ class ModelSetup(RunEnvironment): # store callbacks self.data_store.set("callbacks", callbacks, self.scope) + def copy_model(self): + """Copy external model to internal experiment structure.""" + if self.model_load_path is not None: + logging.info(f"Copy external model file: {self.model_load_path} -> {self.model_path}") + shutil.copyfile(self.model_load_path, self.model_path) + def load_model(self): """Try to load model from disk or skip if not possible.""" + self.copy_model() try: - self.model.load_model(self.model_name) - logging.info(f"reload model {self.model_name} from disk ...") + self.model.load_model(self.model_path) + logging.info(f"reload model {self.model_path} from disk ...") except OSError: logging.info('no local model to load...') @@ -195,14 +219,16 @@ class ModelSetup(RunEnvironment): """Load all model settings and store in data store.""" model_settings = self.model.get_settings() self.data_store.set_from_dict(model_settings, self.scope, log=True) - self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model") - self.data_store.set("model_name", self.model_name, self.scope) + generic_model_name = self.data_store.get_default("model_name", self.scope, "my_model") + model_annotation = generic_model_name if self.model_display_name is None else "" + self.model_path = self.model_path % model_annotation + self.data_store.set("model_name", self.model_path, self.scope) def plot_model(self): # pragma: no cover """Plot model architecture as `<model_name>.pdf`.""" try: with tf.device("/cpu:0"): - file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf" + file_name = f"{self.model_path.rsplit('.', 1)[0]}.pdf" keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) except Exception as e: logging.info(f"Can not plot model due to: {e}")