From 22d23dbb96a6c9b13f369bcaa967f74cd3da7c5e Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Mon, 16 May 2022 10:11:08 +0200
Subject: [PATCH] early stopping and restore best weights are now ready to use
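
Both options are wired through ExperimentSetup and can be passed like
any other experiment parameter. A minimal usage sketch (assuming the
DefaultWorkflow entry point, which forwards its keyword arguments to
ExperimentSetup; the concrete values are illustrative only):

    import mlair

    # stop after 10 consecutive epochs without improvement on val loss
    # and roll back to the weights of the best epoch; the previous
    # behaviour (no early stopping) corresponds to
    # early_stopping_epochs=np.inf, which is also the new default
    workflow = mlair.DefaultWorkflow(
        epochs=100,
        early_stopping_epochs=10,
        restore_best_model_weights=True,
    )
    workflow.run()

Internally both parameters map onto
tf.keras.callbacks.EarlyStopping(patience=..., restore_best_weights=...),
registered in ModelSetup.set_callbacks via the CallbackHandler. The
ModelCheckpointAdvanced callback tracks the best epoch only when
restore_best_model_weights is enabled.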

---
 mlair/configuration/defaults.py          |  5 ++++
 mlair/model_modules/keras_extensions.py  | 25 ++++++++++++-------
 mlair/plotting/training_monitoring.py    |  9 ++++++-
 mlair/run_modules/experiment_setup.py    | 15 ++++++++++--
 mlair/run_modules/model_setup.py         | 31 +++++++++++++++++++-----
 mlair/run_modules/post_processing.py     |  4 +--
 mlair/run_modules/training.py            | 11 ++++-----
 test/test_configuration/test_defaults.py |  2 ++
 8 files changed, 76 insertions(+), 26 deletions(-)

diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index ca569720..b630261d 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -2,6 +2,9 @@ __author__ = "Lukas Leufen"
 __date__ = '2020-06-25'
 
 
+import numpy as np
+
+
 DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
 DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                         'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
@@ -24,6 +27,8 @@ DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY = False
 DEFAULT_PERMUTE_DATA = False
 DEFAULT_BATCH_SIZE = int(256 * 2)
 DEFAULT_EPOCHS = 20
+DEFAULT_EARLY_STOPPING_EPOCHS = np.inf
+DEFAULT_RESTORE_BEST_MODEL_WEIGHTS = True
 DEFAULT_TARGET_VAR = "o3"
 DEFAULT_TARGET_DIM = "variables"
 DEFAULT_WINDOW_LEAD_TIME = 3
diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index 72f40e45..39b0da5b 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -164,6 +164,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
         """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
         self.epoch_best = None
+        self.restore_best_weights = kwargs.pop("restore_best_weights", True)
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
@@ -177,14 +178,19 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
 
         :param hist: The History object from the previous (interrupted) training.
         """
-        f = np.min if self.monitor_op.__name__ == "less" else np.max
-        f_loc = lambda x: np.where(x == f(x))[0][-1]
-        _d = hist.history.get(self.monitor)
-        loc = f_loc(_d)
-        assert f(_d) == _d[loc]
-        self.epoch_best = loc
-        self.best = _d[loc]
-        logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}")
+        if self.restore_best_weights:
+            f = np.min if self.monitor_op.__name__ == "less" else np.max
+            f_loc = lambda x: np.where(x == f(x))[0][-1]
+            _d = hist.history.get(self.monitor)
+            loc = f_loc(_d)
+            assert f(_d) == _d[loc]
+            self.epoch_best = loc
+            self.best = _d[loc]
+            logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}")
+        else:
+            _d = hist.history.get(self.monitor)[-1]
+            self.best = _d
+            logging.info(f"Set best result ({self.monitor}={self.best}) without best epoch")
 
     def update_callbacks(self, callbacks):
         """
@@ -205,7 +211,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                 if self.save_best_only:
                     current = logs.get(self.monitor)
                     if current == self.best:
-                        self.epoch_best = epoch
+                        if self.restore_best_weights:
+                            self.epoch_best = epoch
                         if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
diff --git a/mlair/plotting/training_monitoring.py b/mlair/plotting/training_monitoring.py
index 4884dcb8..e651078f 100644
--- a/mlair/plotting/training_monitoring.py
+++ b/mlair/plotting/training_monitoring.py
@@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
 import pandas as pd
 
 from mlair.model_modules.keras_extensions import LearningRateDecay
+from mlair.helpers.helpers import relative_round
 
 # matplotlib.use('Agg')
 history_object = Union[Dict, keras.callbacks.History]
@@ -98,7 +99,13 @@ class PlotModelHistory:
         ax.set_yscale('log')
         if len(self._additional_columns) > 0:
             self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True)
-        title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}"
+        if self._epoch_best is not None:
+            final_res = self._data[[f'val_{self._plot_metric}']].min().values[0]
+            annotation = f"best epoch {self._epoch_best}"
+        else:
+            final_res = self._data[[f'val_{self._plot_metric}']].values[-1][0]
+            annotation = "final"
+        title = f"Model {self._plot_metric} (val, {annotation}): {relative_round(final_res, 5)}"
         ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title)
         ax.axhline(y=0, color="gray", linewidth=0.5)
         plt.tight_layout()
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index df797ffc..d807db14 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -23,7 +23,8 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_MAX_NUMBER_MULTIPROCESSING, \
     DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, DEFAULT_OVERWRITE_LAZY_DATA, \
     DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, \
-    DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE
+    DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE, DEFAULT_EARLY_STOPPING_EPOCHS, \
+    DEFAULT_RESTORE_BEST_MODEL_WEIGHTS
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
@@ -178,6 +179,11 @@ class ExperimentSetup(RunEnvironment):
         (partly) trained model is lower than this parameter, training is continued. In case this number is higher than
         the given epochs parameter, no training is resumed. Epochs is set to 20 per default, but this value is just a
         placeholder that should be adjusted for a meaningful training.
+    :param early_stopping_epochs: number of consecutive epochs without improvement on val loss after which training
+        is stopped. When set to `np.inf` or not provided at all, training is not stopped before reaching `epochs`.
+    :param restore_best_model_weights: indicates whether to use the model state with the best val loss (if True) or
+        the model state at the end of training (if False). The latter depends on the parameters `epochs` and
+        `early_stopping_epochs`, which both can trigger the end of training.
     :param data_handler:
     :param data_origin:
     :param competitors: Provide names of reference models trained by MLAir that can be found in the `competitor_path`.
@@ -221,7 +227,9 @@ class ExperimentSetup(RunEnvironment):
                  feature_importance_n_boots: int = None, feature_importance_create_new_bootstraps: bool = None,
                  feature_importance_bootstrap_method=None, feature_importance_bootstrap_type=None,
                  data_path: str = None, batch_path: str = None, login_nodes=None,
-                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
+                 hpc_hosts=None, model=None, batch_size=None, epochs=None,
+                 early_stopping_epochs: int = None, restore_best_model_weights: bool = None,
+                 data_handler=None,
                  data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
                  use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
                  max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None,
@@ -255,6 +263,9 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("permute_data", permute_data or upsampling, scope="train")
         self._set_param("batch_size", batch_size, default=DEFAULT_BATCH_SIZE)
         self._set_param("epochs", epochs, default=DEFAULT_EPOCHS)
+        self._set_param("early_stopping_epochs", early_stopping_epochs, default=DEFAULT_EARLY_STOPPING_EPOCHS)
+        self._set_param("restore_best_model_weights", restore_best_model_weights,
+                        default=DEFAULT_RESTORE_BEST_MODEL_WEIGHTS)
 
         # set experiment name
         sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING)  # always related to output sampling
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 4e9f8fa4..eab8012b 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -11,6 +11,7 @@ from dill.source import getsource
 import tensorflow.keras as keras
 import pandas as pd
 import tensorflow as tf
+import numpy as np
 
 from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler
 from mlair.run_modules.run_environment import RunEnvironment
@@ -117,18 +118,36 @@ class ModelSetup(RunEnvironment):
 
         Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added.
         """
-        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
-        hist = HistoryAdvanced()
-        epo_timing = EpoTimingCallback()
-        self.data_store.set("hist", hist, scope="model")
-        self.data_store.set("epo_timing", epo_timing, scope="model")
+        # create callback handler
         callbacks = CallbackHandler()
+
+        # add callback: learning rate
+        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
         if lr is not None:
             callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
+
+        # add callback: advanced history
+        hist = HistoryAdvanced()
+        self.data_store.set("hist", hist, scope="model")
         callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
+
+        # add callback: epo timing
+        epo_timing = EpoTimingCallback()
+        self.data_store.set("epo_timing", epo_timing, scope="model")
         callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing")
+
+        # add callback: early stopping
+        patience = self.data_store.get_default("early_stopping_epochs", default=np.inf)
+        restore_best_weights = self.data_store.get_default("restore_best_model_weights", default=True)
+        assert isinstance(patience, int) or np.isinf(patience)
+        cb = tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=restore_best_weights)
+        callbacks.add_callback(cb, self.callbacks_name % "early_stopping", "early_stopping")
+
+        # create model checkpoint
         callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
-                                          save_best_only=True, mode='auto')
+                                          save_best_only=True, mode='auto', restore_best_weights=restore_best_weights)
+
+        # store callbacks
         self.data_store.set("callbacks", callbacks, self.scope)
 
     def load_model(self):
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 7a31f83f..8a5aa98b 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -41,7 +41,7 @@ class PostProcessing(RunEnvironment):
         #. create plots
 
     Required objects [scope] from data store:
-        * `best_model` [.] or locally saved model plus `model_name` [model] and `model` [model]
+        * `model` [.] or locally saved model plus `model_name` [model] and `model` [model]
         * `generator` [train, val, test, train_val]
         * `forecast_path` [.]
         * `plot_path` [postprocessing]
@@ -479,7 +479,7 @@ class PostProcessing(RunEnvironment):
         :return: the model
         """
         try:  # is only available if a model was trained in training stage
-            model = self.data_store.get("best_model")
+            model = self.data_store.get("model")
         except NameNotFoundInDataStore:
             logging.info("No model was saved in data store. Try to load model from experiment path.")
             model_name = self.data_store.get("model_name", "model")
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index b0233eb4..cb9527ff 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -54,7 +54,7 @@ class Training(RunEnvironment):
         * `upsampling` [train, val, test]
 
     Sets
-        * `best_model` [.]
+        * `model` [.]
 
     Creates
         * `<exp_name>_model-best.h5`
@@ -177,16 +177,15 @@ class Training(RunEnvironment):
         except IndexError:
             epo_timing = None
         self.save_callbacks_as_json(history, lr, epo_timing)
-        self.load_best_model(checkpoint.filepath)
         self.create_monitoring_plots(history, lr, epoch_best)
 
     def save_model(self) -> None:
         """Save model in local experiment directory. Model is named as `<experiment_name>_<custom_model_name>.h5`."""
         model_name = self.data_store.get("model_name", "model")
-        logging.debug(f"save best model to {model_name}")
-        self.model.save(model_name, save_format='h5')
-        self.model.save(model_name)
-        self.data_store.set("best_model", self.model)
+        logging.debug(f"save model to {model_name}")
+        self.model.save(model_name, save_format="h5")
+        self.model.save(model_name, save_format="tf")
+        self.data_store.set("model", self.model)
 
     def load_best_model(self, name: str) -> None:
         """
diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py
index f6bc6d24..07a5aa2f 100644
--- a/test/test_configuration/test_defaults.py
+++ b/test/test_configuration/test_defaults.py
@@ -21,6 +21,8 @@ class TestAllDefaults:
         assert DEFAULT_PERMUTE_DATA is False
         assert DEFAULT_BATCH_SIZE == int(256 * 2)
         assert DEFAULT_EPOCHS == 20
+        assert bool(np.isinf(DEFAULT_EARLY_STOPPING_EPOCHS)) is True
+        assert DEFAULT_RESTORE_BEST_MODEL_WEIGHTS is True
 
     def test_data_handler_parameters(self):
         assert DEFAULT_STATIONS == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
-- 
GitLab