From b3bd487ccbc458bf8f8f1d53e0a708821ac0cc1e Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 26 Nov 2021 10:26:48 +0100
Subject: [PATCH] load model in postprocessing did not work properly

---
 mlair/run_modules/post_processing.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index cb5b2245..2adb3524 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -68,7 +68,7 @@ class PostProcessing(RunEnvironment):
     def __init__(self):
         """Initialise and run post-processing."""
         super().__init__()
-        self.model: keras.Model = self._load_model()
+        self.model: AbstractModelClass = self._load_model()
         self.model_name = self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0]
         self.ols_model = None
         self.batch_size: int = self.data_store.get_default("batch_size", "model", 64)
@@ -430,8 +430,8 @@ class PostProcessing(RunEnvironment):
         except NameNotFoundInDataStore:
             logging.info("No model was saved in data store. Try to load model from experiment path.")
             model_name = self.data_store.get("model_name", "model")
-            model_class: AbstractModelClass = self.data_store.get("model", "model")
-            model = keras.models.load_model(model_name, custom_objects=model_class.custom_objects)
+            model: AbstractModelClass = self.data_store.get("model", "model")
+            model.load_model(model_name)
         return model
 
     # noinspection PyBroadException
-- 
GitLab