diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 338120591ce4d05b86c208dc6672a5e51a48d86f..1a8237e31a6782cf7e53a088012cb67a1a125747 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -62,7 +62,7 @@ def plot_climsum_boxplot(): return -def station_map(generators, plot_folder="."): +def plot_station_map(generators, plot_folder="."): logging.debug("run station_map()") fig = plt.figure(figsize=(10, 5)) diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index 43625bdf29a644e7cc4f14ec8b6d73ebbbd117bd..2f1949d1e5dad47339265adec3cd25c721071cef 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -71,7 +71,7 @@ class ModelSetup(RunEnvironment): def build_model(self): args_list = ["activation", "window_history_size", "channels", "regularizer", "dropout_rate", "window_lead_time"] args = self.data_store.create_args_dict(args_list, self.scope) - self.model = my_little_model(**args) + self.model: keras.Model = my_little_model(**args) def plot_model(self): # pragma: no cover with tf.device("/cpu:0"): diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 7ab6abf9857c647617ec42716d471412e9c81392..b87c637055b594005cecbfa391e137ce103e0f32 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd import xarray as xr import statsmodels.api as sm +import keras from src.run_modules.run_environment import RunEnvironment from src.data_handling.data_distributor import Distributor @@ -17,20 +18,21 @@ from src.model_modules.linear_model import OrdinaryLeastSquaredModel from src import statistics from src import helpers from src.helpers import TimeTracking -from src.plotting.postprocessing_plotting import plot_monthly_summary, plot_climsum_boxplot, station_map, plot_conditional_quantiles +from src.plotting.postprocessing_plotting import plot_monthly_summary, plot_climsum_boxplot, plot_station_map, plot_conditional_quantiles +from src.datastore import NameNotFoundInDataStore class PostProcessing(RunEnvironment): def __init__(self): super().__init__() - self.model = self.data_store.get("best_model", "general") + self.model: keras.Model = self._load_model() self.ols_model = None - self.batch_size = self.data_store.get("batch_size", "general.model") + self.batch_size: int = self.data_store.get("batch_size", "general.model") self.test_data: DataGenerator = self.data_store.get("generator", "general.test") self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size) self.train_data: DataGenerator = self.data_store.get("generator", "general.train") - self.plot_path = self.data_store.get("plot_path", "general") + self.plot_path: str = self.data_store.get("plot_path", "general") self._run() def _run(self): @@ -38,6 +40,17 @@ class PostProcessing(RunEnvironment): preds_for_all_stations = self.make_prediction() self.plot() + def _load_model(self): + try: + model = self.data_store.get("best_model", "general") + except NameNotFoundInDataStore: + logging.info("no model saved in data store. trying to load model from experiment") + path = self.data_store.get("experiment_path", "general") + name = f"{self.data_store.get('experiment_name', 'general')}_my_model.h5" + model_name = os.path.join(path, name) + model = keras.models.load_model(model_name) + return model + def plot(self): logging.debug("Run plotting routines...") path = self.data_store.get("forecast_path", "general") @@ -45,14 +58,14 @@ class PostProcessing(RunEnvironment): target_var = self.data_store.get("target_var", "general") plot_conditional_quantiles(self.test_data.stations, plot_folder=self.plot_path, forecast_path=self.data_store.get("forecast_path", "general")) - # station_map(generators={'b': self.test_data}, plot_folder=self.plot_path) - # plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", window_lead_time, target_var, - # plot_folder=self.plot_path) + plot_station_map(generators={'b': self.test_data}, plot_folder=self.plot_path) + plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", window_lead_time, target_var, + plot_folder=self.plot_path) # plot_climsum_boxplot() def calculate_test_score(self): - test_score = self.model.evaluate(generator=self.test_data_distributed.distribute_on_batches(), - use_multiprocessing=False, verbose=0, steps=1) + test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(), + use_multiprocessing=False, verbose=0, steps=1) logging.info(f"test score = {test_score}") self._save_test_score(test_score) diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 53a505484f4e3d59ef37a37f3d0c6996059485d9..aa4665bf592071cce49c25eebba609bdc69fdd5b 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -14,7 +14,7 @@ class Training(RunEnvironment): def __init__(self): super().__init__() - self.model = self.data_store.get("model", "general.model") + self.model: keras.Model = self.data_store.get("model", "general.model") self.train_set = None self.val_set = None self.test_set = None