diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index cb5b22452f0fdd698f6ca5a6d192e8671b61377d..2adb352428a317922571b96afdcab4dfa966ab04 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -68,7 +68,7 @@ class PostProcessing(RunEnvironment): def __init__(self): """Initialise and run post-processing.""" super().__init__() - self.model: keras.Model = self._load_model() + self.model: AbstractModelClass = self._load_model() self.model_name = self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0] self.ols_model = None self.batch_size: int = self.data_store.get_default("batch_size", "model", 64) @@ -430,8 +430,8 @@ class PostProcessing(RunEnvironment): except NameNotFoundInDataStore: logging.info("No model was saved in data store. Try to load model from experiment path.") model_name = self.data_store.get("model_name", "model") - model_class: AbstractModelClass = self.data_store.get("model", "model") - model = keras.models.load_model(model_name, custom_objects=model_class.custom_objects) + model: AbstractModelClass = self.data_store.get("model", "model") + model.load_model(model_name) return model # noinspection PyBroadException