Skip to content
Snippets Groups Projects
Commit e096b689 authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'lukas_issue343_bug_fix-bugs-caused-by-model-name-refac-and-tf-update' into 'develop'

Resolve "fix bugs caused by model name refac and tf update"

See merge request !366
parents 5afc1a70 b7dc46f1
No related branches found
No related tags found
4 merge requests!413update release branch,!412Resolve "release v2.0.0",!367Develop,!366Resolve "fix bugs caused by model name refac and tf update"
Pipeline #84912 passed
......@@ -38,9 +38,9 @@ class AbstractModelClass(ABC):
self._input_shape = input_shape
self._output_shape = self.__extract_from_tuple(output_shape)
def load_model(self, name: str, compile: bool = False):
def load_model(self, name: str, compile: bool = False) -> None:
hist = self.model.history
self.model = keras.models.load_model(name)
self.model.load_weights(name)
self.model.history = hist
if compile is True:
self.model.compile(**self.compile_options)
......
......@@ -421,7 +421,7 @@ class PostProcessing(RunEnvironment):
"""Return model name without path information."""
return self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0]
def _load_model(self) -> keras.models:
def _load_model(self) -> AbstractModelClass:
"""
Load NN model either from data store or from local path.
......@@ -907,10 +907,11 @@ class PostProcessing(RunEnvironment):
errors = {}
for station in all_stations:
external_data = self._get_external_data(station, path) # test data
external_data.coords[self.model_type_dim] = [{self.forecast_indicator: self.model_display_name}.get(n, n)
for n in external_data.coords[self.model_type_dim].values]
# test errors
if external_data is not None:
external_data.coords[self.model_type_dim] = [{self.forecast_indicator: self.model_display_name}.get(n, n)
for n in external_data.coords[self.model_type_dim].values]
model_type_list = external_data.coords[self.model_type_dim].values.tolist()
for model_type in remove_items(model_type_list, self.observation_indicator):
if model_type not in errors.keys():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment