From febd62e60e546be868e4c6b78ee3d4642d8279a4 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 13 Mar 2020 11:21:21 +0100
Subject: [PATCH] added custom_object loading to post-processing

---
 src/run_modules/post_processing.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 0a61ee4f..07c257d7 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -17,6 +17,7 @@ from src.data_handling.bootstraps import BootStraps
 from src.datastore import NameNotFoundInDataStore
 from src.helpers import TimeTracking
 from src.model_modules.linear_model import OrdinaryLeastSquaredModel
+from src.model_modules.model_class import AbstractModelClass
 from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
     PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore
 from src.plotting.postprocessing_plotting import plot_conditional_quantiles
@@ -117,7 +118,8 @@ class PostProcessing(RunEnvironment):
         except NameNotFoundInDataStore:
             logging.info("no model saved in data store. trying to load model from experiment path")
             model_name = self.data_store.get("model_name", "general.model")
-            model = keras.models.load_model(model_name)
+            model_class: AbstractModelClass = self.data_store.get("model", "general.model")
+            model = keras.models.load_model(model_name, custom_objects=model_class.custom_objects)
         return model
 
     def plot(self):
-- 
GitLab