diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py index 7ecaad9cf077100f3b9a34b02c99e172d141a218..4a323f46ff95a7ca66c157f2e4d6d3184f244a4a 100644 --- a/mlair/model_modules/abstract_model_class.py +++ b/mlair/model_modules/abstract_model_class.py @@ -38,9 +38,9 @@ class AbstractModelClass(ABC): self._input_shape = input_shape self._output_shape = self.__extract_from_tuple(output_shape) - def load_model(self, name: str, compile: bool = False): + def load_model(self, name: str, compile: bool = False) -> None: hist = self.model.history - self.model = keras.models.load_model(name) + self.model.load_weights(name) self.model.history = hist if compile is True: self.model.compile(**self.compile_options) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 5e1e585e2114fe29cf01bb89bbbccb17d7bfa4bf..9d03f47172d80b2d06e3ea6f10f44b076883c9ef 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -421,7 +421,7 @@ class PostProcessing(RunEnvironment): """Return model name without path information.""" return self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0] - def _load_model(self) -> keras.models: + def _load_model(self) -> AbstractModelClass: """ Load NN model either from data store or from local path. @@ -907,10 +907,11 @@ class PostProcessing(RunEnvironment): errors = {} for station in all_stations: external_data = self._get_external_data(station, path) # test data - external_data.coords[self.model_type_dim] = [{self.forecast_indicator: self.model_display_name}.get(n, n) - for n in external_data.coords[self.model_type_dim].values] + # test errors if external_data is not None: + external_data.coords[self.model_type_dim] = [{self.forecast_indicator: self.model_display_name}.get(n, n) + for n in external_data.coords[self.model_type_dim].values] model_type_list = external_data.coords[self.model_type_dim].values.tolist() for model_type in remove_items(model_type_list, self.observation_indicator): if model_type not in errors.keys():