From 22d23dbb96a6c9b13f369bcaa967f74cd3da7c5e Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Mon, 16 May 2022 10:11:08 +0200 Subject: [PATCH] early stopping and restore best weights are now ready to use --- mlair/configuration/defaults.py | 5 ++++ mlair/model_modules/keras_extensions.py | 25 ++++++++++++------- mlair/plotting/training_monitoring.py | 9 ++++++- mlair/run_modules/experiment_setup.py | 15 ++++++++++-- mlair/run_modules/model_setup.py | 31 +++++++++++++++++++----- mlair/run_modules/post_processing.py | 4 +-- mlair/run_modules/training.py | 11 ++++----- test/test_configuration/test_defaults.py | 2 ++ 8 files changed, 76 insertions(+), 26 deletions(-) diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index ca569720..b630261d 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -2,6 +2,9 @@ __author__ = "Lukas Leufen" __date__ = '2020-06-25' +import numpy as np + + DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', @@ -24,6 +27,8 @@ DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY = False DEFAULT_PERMUTE_DATA = False DEFAULT_BATCH_SIZE = int(256 * 2) DEFAULT_EPOCHS = 20 +DEFAULT_EARLY_STOPPING_EPOCHS = np.inf +DEFAULT_RESTORE_BEST_MODEL_WEIGHTS = True DEFAULT_TARGET_VAR = "o3" DEFAULT_TARGET_DIM = "variables" DEFAULT_WINDOW_LEAD_TIME = 3 diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index 72f40e45..39b0da5b 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -164,6 +164,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint): """Initialise ModelCheckpointAdvanced and set callbacks attribute.""" self.callbacks = kwargs.pop("callbacks") self.epoch_best = None + self.restore_best_weights = kwargs.pop("restore_best_weights", True) super().__init__(*args, **kwargs) def update_best(self, hist): @@ -177,14 +178,19 @@ class ModelCheckpointAdvanced(ModelCheckpoint): :param hist: The History object from the previous (interrupted) training. 
""" - f = np.min if self.monitor_op.__name__ == "less" else np.max - f_loc = lambda x: np.where(x == f(x))[0][-1] - _d = hist.history.get(self.monitor) - loc = f_loc(_d) - assert f(_d) == _d[loc] - self.epoch_best = loc - self.best = _d[loc] - logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}") + if self.restore_best_weights: + f = np.min if self.monitor_op.__name__ == "less" else np.max + f_loc = lambda x: np.where(x == f(x))[0][-1] + _d = hist.history.get(self.monitor) + loc = f_loc(_d) + assert f(_d) == _d[loc] + self.epoch_best = loc + self.best = _d[loc] + logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}") + else: + _d = hist.history.get(self.monitor)[-1] + self.best = _d + logging.info(f"Set only best result ({self.monitor}={self.best}) without best epoch") def update_callbacks(self, callbacks): """ @@ -205,7 +211,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint): if self.save_best_only: current = logs.get(self.monitor) if current == self.best: - self.epoch_best = epoch + if self.restore_best_weights: + self.epoch_best = epoch if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) with open(file_path, "wb") as f: diff --git a/mlair/plotting/training_monitoring.py b/mlair/plotting/training_monitoring.py index 4884dcb8..e651078f 100644 --- a/mlair/plotting/training_monitoring.py +++ b/mlair/plotting/training_monitoring.py @@ -11,6 +11,7 @@ import matplotlib.pyplot as plt import pandas as pd from mlair.model_modules.keras_extensions import LearningRateDecay +from mlair.helpers.helpers import relative_round # matplotlib.use('Agg') history_object = Union[Dict, keras.callbacks.History] @@ -98,7 +99,13 @@ class PlotModelHistory: ax.set_yscale('log') if len(self._additional_columns) > 0: self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True) - title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}" + if self._epoch_best is not None: + final_res = self._data[[f'val_{self._plot_metric}']].min().values[0] + annotation = f"best epoch {self._epoch_best}" + else: + final_res = self._data[[f'val_{self._plot_metric}']].values[-1][0] + annotation = "final" + title = f"Model {self._plot_metric} (val, {annotation}): {relative_round(final_res, 5)}" ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title) ax.axhline(y=0, color="gray", linewidth=0.5) plt.tight_layout() diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index df797ffc..d807db14 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -23,7 +23,8 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_MAX_NUMBER_MULTIPROCESSING, \ DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, DEFAULT_OVERWRITE_LAZY_DATA, \ DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, \ - DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE + DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE, DEFAULT_EARLY_STOPPING_EPOCHS, \ + DEFAULT_RESTORE_BEST_MODEL_WEIGHTS from mlair.data_handler import DefaultDataHandler from mlair.run_modules.run_environment import RunEnvironment from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel @@ -178,6 
+179,11 @@ class ExperimentSetup(RunEnvironment): (partly) trained model is lower than this parameter, training is continue. In case this number is higher than the given epochs parameter, no training is resumed. Epochs is set to 20 per default, but this value is just a placeholder that should be adjusted for a meaningful training. + :param early_stopping_epochs: number of consecutive epochs with no improvement on val loss before training is stopped. When + set to `np.inf` or not provided at all, training is not stopped before reaching `epochs`. + :param restore_best_model_weights: indicates whether to use the model state with the best val loss (if True) or the model state + at the end of training (if False). The latter depends on the parameters `epochs` and `early_stopping_epochs`, which + trigger the stopping of training. :param data_handler: :param data_origin: :param competitors: Provide names of reference models trained by MLAir that can be found in the `competitor_path`. @@ -221,7 +227,9 @@ class ExperimentSetup(RunEnvironment): feature_importance_n_boots: int = None, feature_importance_create_new_bootstraps: bool = None, feature_importance_bootstrap_method=None, feature_importance_bootstrap_type=None, data_path: str = None, batch_path: str = None, login_nodes=None, - hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, + hpc_hosts=None, model=None, batch_size=None, epochs=None, + early_stopping_epochs: int = None, restore_best_model_weights: bool = None, + data_handler=None, data_origin: Dict = None, competitors: list = None, competitor_path: str = None, use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None, max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None, @@ -255,6 +263,9 @@ class ExperimentSetup(RunEnvironment): self._set_param("permute_data", permute_data or upsampling, scope="train") self._set_param("batch_size", batch_size, default=DEFAULT_BATCH_SIZE) self._set_param("epochs", epochs, default=DEFAULT_EPOCHS) + self._set_param("early_stopping_epochs", early_stopping_epochs, default=DEFAULT_EARLY_STOPPING_EPOCHS) + self._set_param("restore_best_model_weights", restore_best_model_weights, + default=DEFAULT_RESTORE_BEST_MODEL_WEIGHTS) # set experiment name sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING) # always related to output sampling diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index 4e9f8fa4..eab8012b 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -11,6 +11,7 @@ from dill.source import getsource import tensorflow.keras as keras import pandas as pd import tensorflow as tf +import numpy as np from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler from mlair.run_modules.run_environment import RunEnvironment @@ -117,18 +118,36 @@ class ModelSetup(RunEnvironment): Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added.
""" - lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None) - hist = HistoryAdvanced() - epo_timing = EpoTimingCallback() - self.data_store.set("hist", hist, scope="model") - self.data_store.set("epo_timing", epo_timing, scope="model") + # create callback handler callbacks = CallbackHandler() + + # add callback: learning rate + lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None) if lr is not None: callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") + + # add callback: advanced history + hist = HistoryAdvanced() + self.data_store.set("hist", hist, scope="model") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist") + + # add callback: epo timing + epo_timing = EpoTimingCallback() + self.data_store.set("epo_timing", epo_timing, scope="model") callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing") + + # add callback: early stopping + patience = self.data_store.get_default("early_stopping_epochs", default=np.inf) + restore_best_weights = self.data_store.get_default("restore_best_model_weights", default=True) + assert bool(isinstance(patience, int) or np.isinf(patience)) is True + cb = tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=restore_best_weights) + callbacks.add_callback(cb, self.callbacks_name % "early_stopping", "early_stopping") + + # create model checkpoint callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss', - save_best_only=True, mode='auto') + save_best_only=True, mode='auto', restore_best_weights=restore_best_weights) + + # store callbacks self.data_store.set("callbacks", callbacks, self.scope) def load_model(self): diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 7a31f83f..8a5aa98b 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -41,7 +41,7 @@ class PostProcessing(RunEnvironment): #. create plots Required objects [scope] from data store: - * `best_model` [.] or locally saved model plus `model_name` [model] and `model` [model] + * `model` [.] or locally saved model plus `model_name` [model] and `model` [model] * `generator` [train, val, test, train_val] * `forecast_path` [.] * `plot_path` [postprocessing] @@ -479,7 +479,7 @@ class PostProcessing(RunEnvironment): :return: the model """ try: # is only available if a model was trained in training stage - model = self.data_store.get("best_model") + model = self.data_store.get("model") except NameNotFoundInDataStore: logging.info("No model was saved in data store. Try to load model from experiment path.") model_name = self.data_store.get("model_name", "model") diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index b0233eb4..cb9527ff 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -54,7 +54,7 @@ class Training(RunEnvironment): * `upsampling` [train, val, test] Sets - * `best_model` [.] + * `model` [.] Creates * `<exp_name>_model-best.h5` @@ -177,16 +177,15 @@ class Training(RunEnvironment): except IndexError: epo_timing = None self.save_callbacks_as_json(history, lr, epo_timing) - self.load_best_model(checkpoint.filepath) self.create_monitoring_plots(history, lr, epoch_best) def save_model(self) -> None: """Save model in local experiment directory. 
Model is named as `<experiment_name>_<custom_model_name>.h5`.""" model_name = self.data_store.get("model_name", "model") - logging.debug(f"save best model to {model_name}") - self.model.save(model_name, save_format='h5') - self.model.save(model_name) - self.data_store.set("best_model", self.model) + logging.debug(f"save model to {model_name}") + self.model.save(model_name, save_format="h5") + self.model.save(model_name, save_format="tf") + self.data_store.set("model", self.model) def load_best_model(self, name: str) -> None: """ diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py index f6bc6d24..07a5aa2f 100644 --- a/test/test_configuration/test_defaults.py +++ b/test/test_configuration/test_defaults.py @@ -21,6 +21,8 @@ class TestAllDefaults: assert DEFAULT_PERMUTE_DATA is False assert DEFAULT_BATCH_SIZE == int(256 * 2) assert DEFAULT_EPOCHS == 20 + assert bool(np.isinf(DEFAULT_EARLY_STOPPING_EPOCHS)) is True + assert DEFAULT_RESTORE_BEST_MODEL_WEIGHTS is True def test_data_handler_parameters(self): assert DEFAULT_STATIONS == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] -- GitLab
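
Usage note (not part of the patch): a minimal sketch of how the two new options are meant to be passed in. The values are illustrative only; in practice ExperimentSetup runs as the first stage of an MLAir workflow rather than standalone.

    from mlair.run_modules.experiment_setup import ExperimentSetup

    # Illustrative values: stop training once val_loss has not improved for 10
    # consecutive epochs and fall back to the weights of the best epoch.
    ExperimentSetup(
        epochs=150,
        early_stopping_epochs=10,
        restore_best_model_weights=True,
    )

Leaving `early_stopping_epochs` at its new default of `np.inf` keeps the previous behaviour of always training for the full number of `epochs`.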
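In plain Keras terms, the callback added in mlair/run_modules/model_setup.py boils down to roughly the following sketch. In MLAir the values come from the data store and the callback is registered via CallbackHandler; the numbers here are made up.

    import tensorflow as tf

    patience = 10                # "early_stopping_epochs" from the data store; an int or np.inf
    restore_best_weights = True  # "restore_best_model_weights" from the data store

    # EarlyStopping monitors val_loss by default; training stops once val_loss has not
    # improved for `patience` consecutive epochs, and the best weights are restored
    # afterwards if requested.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        patience=patience, restore_best_weights=restore_best_weights)
    # model.fit(..., callbacks=[early_stopping, ...])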