diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py index 4dc9521abf3569eb57249286e92c1e6a259c667d..7ecaad9cf077100f3b9a34b02c99e172d141a218 100644 --- a/mlair/model_modules/abstract_model_class.py +++ b/mlair/model_modules/abstract_model_class.py @@ -38,10 +38,12 @@ class AbstractModelClass(ABC): self._input_shape = input_shape self._output_shape = self.__extract_from_tuple(output_shape) - def load_model(self, name: str): + def load_model(self, name: str, compile: bool = False): hist = self.model.history self.model = keras.models.load_model(name) self.model.history = hist + if compile is True: + self.model.compile(**self.compile_options) def __getattr__(self, name: str) -> Any: """ diff --git a/mlair/plotting/training_monitoring.py b/mlair/plotting/training_monitoring.py index b2b531b99c85bb43e4e758fd23045c9f0575cb24..39dd80651226519463d7b503fb612e43983d73cf 100644 --- a/mlair/plotting/training_monitoring.py +++ b/mlair/plotting/training_monitoring.py @@ -45,15 +45,18 @@ class PlotModelHistory: self._additional_columns = self._filter_columns(history) self._plot(filename) - @staticmethod - def _get_plot_metric(history, plot_metric, main_branch): - if plot_metric.lower() == "mse": - plot_metric = "mean_squared_error" - elif plot_metric.lower() == "mae": - plot_metric = "mean_absolute_error" + def _get_plot_metric(self, history, plot_metric, main_branch, correct_names=True): + _plot_metric = plot_metric + if correct_names is True: + if plot_metric.lower() == "mse": + plot_metric = "mean_squared_error" + elif plot_metric.lower() == "mae": + plot_metric = "mean_absolute_error" available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k.lower() if main_branch else True)] available_keys.sort(key=len) + if len(available_keys) == 0 and correct_names is True: + return self._get_plot_metric(history, _plot_metric, main_branch, correct_names=False) return available_keys[0] def _filter_columns(self, history: Dict) -> List[str]: diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 0d875766926e870349337a0597e2b3612a93ee07..c076253d92a0e24f419046805687d2a80143176c 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -149,7 +149,7 @@ class Training(RunEnvironment): logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.") self.callbacks.load_callbacks() self.callbacks.update_checkpoint() - self.model.load_model(checkpoint.filepath) + self.model.load_model(checkpoint.filepath, compile=True) hist: History = self.callbacks.get_callback_by_name("hist") initial_epoch = max(hist.epoch) + 1 _ = self.model.fit(self.train_set, @@ -190,8 +190,8 @@ class Training(RunEnvironment): """ logging.debug(f"load best model: {name}") try: - self.model.load_model(name) - logging.info('reload weights...') + self.model.load_model(name, compile=True) + logging.info('reload model...') except OSError: logging.info('no weights to reload...') @@ -236,9 +236,11 @@ class Training(RunEnvironment): if multiple_branches_used: filename = os.path.join(path, f"{name}_history_main_loss.pdf") PlotModelHistory(filename=filename, history=history, main_branch=True) - if len([e for e in history.model.metrics_names if "mean_squared_error" in e]) > 0: + mse_indicator = list(set(history.model.metrics_names).intersection(["mean_squared_error", "mse"])) + if len(mse_indicator) > 0: filename = os.path.join(path, f"{name}_history_main_mse.pdf") - PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used) + PlotModelHistory(filename=filename, history=history, plot_metric=mse_indicator[0], + main_branch=multiple_branches_used) # plot learning rate if lr_sc: diff --git a/run_mixed_sampling.py b/run_mixed_sampling.py index 784f653fbfb2eb4c78e6e858acf67cd0ae47a593..47aa9b970c0e95ccadb60e8c090136c0fa6ceea4 100644 --- a/run_mixed_sampling.py +++ b/run_mixed_sampling.py @@ -4,8 +4,8 @@ __date__ = '2019-11-14' import argparse from mlair.workflows import DefaultWorkflow -from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter, \ - DataHandlerSeparationOfScales +from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling + stats = {'o3': 'dma8eu', 'no': 'dma8eu', 'no2': 'dma8eu', 'relhum': 'average_values', 'u': 'average_values', 'v': 'average_values', @@ -20,7 +20,7 @@ data_origin = {'o3': '', 'no': '', 'no2': '', def main(parser_args): args = dict(stations=["DEBW107", "DEBW013"], network="UBA", - evaluate_feature_importance=False, plot_list=[], + evaluate_feature_importance=True, # plot_list=[], data_origin=data_origin, data_handler=DataHandlerMixedSampling, interpolation_limit=(3, 1), overwrite_local_data=False, sampling=("hourly", "daily"), @@ -28,8 +28,6 @@ def main(parser_args): create_new_model=True, train_model=False, epochs=1, window_history_size=6 * 24 + 16, window_history_offset=16, - kz_filter_length=[100 * 24, 15 * 24], - kz_filter_iter=[4, 5], start="2006-01-01", train_start="2006-01-01", end="2011-12-31",