diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 0a61ee4f07d0c6eccf698aa16d3de9d7275e75f6..07c257d7216fc340ec3493b7c2ff7bae7895e356 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -17,6 +17,7 @@ from src.data_handling.bootstraps import BootStraps from src.datastore import NameNotFoundInDataStore from src.helpers import TimeTracking from src.model_modules.linear_model import OrdinaryLeastSquaredModel +from src.model_modules.model_class import AbstractModelClass from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \ PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore from src.plotting.postprocessing_plotting import plot_conditional_quantiles @@ -117,7 +118,8 @@ class PostProcessing(RunEnvironment): except NameNotFoundInDataStore: logging.info("no model saved in data store. trying to load model from experiment path") model_name = self.data_store.get("model_name", "general.model") - model = keras.models.load_model(model_name) + model_class: AbstractModelClass = self.data_store.get("model", "general.model") + model = keras.models.load_model(model_name, custom_objects=model_class.custom_objects) return model def plot(self):