diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index ca569720dc41d95621d0613a2170cc4d9d46c082..b630261dbf58d7402f8c3cacaee153347ad4f1e3 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -2,6 +2,9 @@ __author__ = "Lukas Leufen" __date__ = '2020-06-25' +import numpy as np + + DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', @@ -24,6 +27,8 @@ DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY = False DEFAULT_PERMUTE_DATA = False DEFAULT_BATCH_SIZE = int(256 * 2) DEFAULT_EPOCHS = 20 +DEFAULT_EARLY_STOPPING_EPOCHS = np.inf +DEFAULT_RESTORE_BEST_MODEL_WEIGHTS = True DEFAULT_TARGET_VAR = "o3" DEFAULT_TARGET_DIM = "variables" DEFAULT_WINDOW_LEAD_TIME = 3 diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index 8104c7c50517e05be14b05aaa9cea8d0e5ba32f4..b583cf7dc473db96181f88b0ab26e60ee225240d 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -122,6 +122,21 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce return round_type(number * multiplier) / multiplier +def relative_round(x: float, sig: int) -> float: + """ + Round small numbers according to given "significance". + + Example: relative_round(0.03112, 2) -> 0.031, relative_round(0.03112, 1) -> 0.03 + + :params x: number to round + :params sig: "significance" to determine number of decimals + + :return: rounded number + """ + assert sig >= 1 + return round(x, sig-int(np.floor(np.log10(abs(x))))-1) + + def remove_items(obj: Union[List, Dict, Tuple], items: Any): """ Remove item(s) from either list, tuple or dictionary. diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index 8b99acd0f5723d3b00ec1bd0098712753da21b52..39b0da5b49f470d11ea64b9ddd344b9ad11e2b7f 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -163,6 +163,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint): def __init__(self, *args, **kwargs): """Initialise ModelCheckpointAdvanced and set callbacks attribute.""" self.callbacks = kwargs.pop("callbacks") + self.epoch_best = None + self.restore_best_weights = kwargs.pop("restore_best_weights", True) super().__init__(*args, **kwargs) def update_best(self, hist): @@ -176,7 +178,19 @@ class ModelCheckpointAdvanced(ModelCheckpoint): :param hist: The History object from the previous (interrupted) training. """ - self.best = hist.history.get(self.monitor)[-1] + if self.restore_best_weights: + f = np.min if self.monitor_op.__name__ == "less" else np.max + f_loc = lambda x: np.where(x == f(x))[0][-1] + _d = hist.history.get(self.monitor) + loc = f_loc(_d) + assert f(_d) == _d[loc] + self.epoch_best = loc + self.best = _d[loc] + logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}") + else: + _d = hist.history.get(self.monitor)[-1] + self.best = _d + logging.info(f"Set only best result ({self.monitor}={self.best}) without best epoch") def update_callbacks(self, callbacks): """ @@ -197,6 +211,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint): if self.save_best_only: current = logs.get(self.monitor) if current == self.best: + if self.restore_best_weights: + self.epoch_best = epoch if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) with open(file_path, "wb") as f: diff --git a/mlair/plotting/training_monitoring.py b/mlair/plotting/training_monitoring.py index 39dd80651226519463d7b503fb612e43983d73cf..e651078fe66a5d95e6e01d5384865eee621ba6f4 100644 --- a/mlair/plotting/training_monitoring.py +++ b/mlair/plotting/training_monitoring.py @@ -11,6 +11,7 @@ import matplotlib.pyplot as plt import pandas as pd from mlair.model_modules.keras_extensions import LearningRateDecay +from mlair.helpers.helpers import relative_round # matplotlib.use('Agg') history_object = Union[Dict, keras.callbacks.History] @@ -27,7 +28,8 @@ class PlotModelHistory: parameter filename must include the absolute path for the plot. """ - def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False): + def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False, + epoch_best: int = None): """ Set attributes and create plot. @@ -37,12 +39,15 @@ class PlotModelHistory: :param plot_metric: the metric to plot (e.b. mean_squared_error, mse, mean_absolute_error, loss, default: loss) :param main_branch: switch between only looking for metrics that go with 'main' or for all occurrences (default: False -> look for losses from all branches, not only from main) + :param epoch_best: indicator at which epoch the best train result was achieved (should start counting at 0) """ if isinstance(history, keras.callbacks.History): history = history.history self._data = pd.DataFrame.from_dict(history) + self._data.index += 1 self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch) self._additional_columns = self._filter_columns(history) + self._epoch_best = epoch_best self._plot(filename) def _get_plot_metric(self, history, plot_metric, main_branch, correct_names=True): @@ -88,10 +93,19 @@ class PlotModelHistory: :param filename: name (including total path) of the plot to save. """ ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7) + if self._epoch_best is not None: + ax.scatter(self._epoch_best+1, self._data[[f"val_{self._plot_metric}"]].iloc[self._epoch_best], + s=100, marker="*", c="black") ax.set_yscale('log') if len(self._additional_columns) > 0: self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True) - title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}" + if self._epoch_best is not None: + final_res = self._data[[f'val_{self._plot_metric}']].min().values[0] + annotation = f"best epoch {self._epoch_best}" + else: + final_res = self._data[[f'val_{self._plot_metric}']].values[-1][0] + annotation = "final" + title = f"Model {self._plot_metric} (val, {annotation}): {relative_round(final_res, 5)}" ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title) ax.axhline(y=0, color="gray", linewidth=0.5) plt.tight_layout() diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index df797ffc23370bf4f45bb2b4f76e5f71e9bd030f..d807db14c96a4a30fde791e54c8b1b32e519fb9c 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -23,7 +23,8 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_MAX_NUMBER_MULTIPROCESSING, \ DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, DEFAULT_OVERWRITE_LAZY_DATA, \ DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, \ - DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE + DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE, DEFAULT_EARLY_STOPPING_EPOCHS, \ + DEFAULT_RESTORE_BEST_MODEL_WEIGHTS from mlair.data_handler import DefaultDataHandler from mlair.run_modules.run_environment import RunEnvironment from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel @@ -178,6 +179,11 @@ class ExperimentSetup(RunEnvironment): (partly) trained model is lower than this parameter, training is continue. In case this number is higher than the given epochs parameter, no training is resumed. Epochs is set to 20 per default, but this value is just a placeholder that should be adjusted for a meaningful training. + :param early_stopping_epochs: number of consecutive epochs with no improvement on val loss to stop training. When + set to `np.inf` or not providing at all, training is not stopped before reaching `epochs`. + :param restore_best_model_weights: indicates whether to use model state with best val loss (if True) or model state + on ending of training (if False). The later depends on the parameters `epochs` and `early_stopping_epochs` which + trigger stopping of training. :param data_handler: :param data_origin: :param competitors: Provide names of reference models trained by MLAir that can be found in the `competitor_path`. @@ -221,7 +227,9 @@ class ExperimentSetup(RunEnvironment): feature_importance_n_boots: int = None, feature_importance_create_new_bootstraps: bool = None, feature_importance_bootstrap_method=None, feature_importance_bootstrap_type=None, data_path: str = None, batch_path: str = None, login_nodes=None, - hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, + hpc_hosts=None, model=None, batch_size=None, epochs=None, + early_stopping_epochs: int = None, restore_best_model_weights: bool = None, + data_handler=None, data_origin: Dict = None, competitors: list = None, competitor_path: str = None, use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None, max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None, @@ -255,6 +263,9 @@ class ExperimentSetup(RunEnvironment): self._set_param("permute_data", permute_data or upsampling, scope="train") self._set_param("batch_size", batch_size, default=DEFAULT_BATCH_SIZE) self._set_param("epochs", epochs, default=DEFAULT_EPOCHS) + self._set_param("early_stopping_epochs", early_stopping_epochs, default=DEFAULT_EARLY_STOPPING_EPOCHS) + self._set_param("restore_best_model_weights", restore_best_model_weights, + default=DEFAULT_RESTORE_BEST_MODEL_WEIGHTS) # set experiment name sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING) # always related to output sampling diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index 4e9f8fa4439e9885a6c16c2b2eccfee2c97fd936..eab8012b983a0676620bbc66f65ff79b31165aeb 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -11,6 +11,7 @@ from dill.source import getsource import tensorflow.keras as keras import pandas as pd import tensorflow as tf +import numpy as np from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler from mlair.run_modules.run_environment import RunEnvironment @@ -117,18 +118,36 @@ class ModelSetup(RunEnvironment): Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added. """ - lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None) - hist = HistoryAdvanced() - epo_timing = EpoTimingCallback() - self.data_store.set("hist", hist, scope="model") - self.data_store.set("epo_timing", epo_timing, scope="model") + # create callback handler callbacks = CallbackHandler() + + # add callback: learning rate + lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None) if lr is not None: callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") + + # add callback: advanced history + hist = HistoryAdvanced() + self.data_store.set("hist", hist, scope="model") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist") + + # add callback: epo timing + epo_timing = EpoTimingCallback() + self.data_store.set("epo_timing", epo_timing, scope="model") callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing") + + # add callback: early stopping + patience = self.data_store.get_default("early_stopping_epochs", default=np.inf) + restore_best_weights = self.data_store.get_default("restore_best_model_weights", default=True) + assert bool(isinstance(patience, int) or np.isinf(patience)) is True + cb = tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=restore_best_weights) + callbacks.add_callback(cb, self.callbacks_name % "early_stopping", "early_stopping") + + # create model checkpoint callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss', - save_best_only=True, mode='auto') + save_best_only=True, mode='auto', restore_best_weights=restore_best_weights) + + # store callbacks self.data_store.set("callbacks", callbacks, self.scope) def load_model(self): diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 7a31f83fdf09ceed649aa02d4eecb74a9165eba0..8a5aa98b22d3f3c2cc4e1f32b8f816a14146b716 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -41,7 +41,7 @@ class PostProcessing(RunEnvironment): #. create plots Required objects [scope] from data store: - * `best_model` [.] or locally saved model plus `model_name` [model] and `model` [model] + * `model` [.] or locally saved model plus `model_name` [model] and `model` [model] * `generator` [train, val, test, train_val] * `forecast_path` [.] * `plot_path` [postprocessing] @@ -479,7 +479,7 @@ class PostProcessing(RunEnvironment): :return: the model """ try: # is only available if a model was trained in training stage - model = self.data_store.get("best_model") + model = self.data_store.get("model") except NameNotFoundInDataStore: logging.info("No model was saved in data store. Try to load model from experiment path.") model_name = self.data_store.get("model_name", "model") diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index a38837dce041295d37fae1ea86ef2a215d51dc89..5ce906122ef184d6dcad5527e923e44f04028fe5 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -54,7 +54,7 @@ class Training(RunEnvironment): * `upsampling` [train, val, test] Sets - * `best_model` [.] + * `model` [.] Creates * `<exp_name>_model-best.h5` @@ -165,6 +165,9 @@ class Training(RunEnvironment): initial_epoch=initial_epoch, workers=psutil.cpu_count(logical=False)) history = hist + epoch_best = checkpoint.epoch_best + if epoch_best is not None: + logging.info(f"best epoch: {epoch_best + 1}") try: lr = self.callbacks.get_callback_by_name("lr") except IndexError: @@ -174,29 +177,15 @@ class Training(RunEnvironment): except IndexError: epo_timing = None self.save_callbacks_as_json(history, lr, epo_timing) - self.load_best_model(checkpoint.filepath) - self.create_monitoring_plots(history, lr) + self.create_monitoring_plots(history, lr, epoch_best) def save_model(self) -> None: """Save model in local experiment directory. Model is named as `<experiment_name>_<custom_model_name>.h5`.""" model_name = self.data_store.get("model_name", "model") - logging.debug(f"save best model to {model_name}") - self.model.save(model_name, save_format='h5') - self.model.save(model_name) - self.data_store.set("best_model", self.model) - - def load_best_model(self, name: str) -> None: - """ - Load model weights for model with name. Skip if no weights are available. - - :param name: name of the model to load weights for - """ - logging.debug(f"load best model: {name}") - try: - self.model.load_model(name, compile=True) - logging.info('reload model...') - except OSError: - logging.info('no weights to reload...') + logging.debug(f"save model to {model_name}") + self.model.save(model_name, save_format="h5") + self.model.save(model_name, save_format="tf") + self.data_store.set("model", self.model) def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None: """ @@ -219,7 +208,7 @@ class Training(RunEnvironment): with open(os.path.join(path, "epo_timing.json"), "w") as f: json.dump(epo_timing.epo_timing, f) - def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None: + def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int = None) -> None: """ Create plot of history and learning rate in dependence of the number of epochs. @@ -228,22 +217,23 @@ class Training(RunEnvironment): :param history: keras history object with losses to plot (must at least include `loss` and `val_loss`) :param lr_sc: learning rate decay object with 'lr' attribute + :param epoch_best: number of best epoch (starts counting as 0) """ path = self.data_store.get("plot_path") name = self.data_store.get("experiment_name") # plot history of loss and mse (if available) filename = os.path.join(path, f"{name}_history_loss.pdf") - PlotModelHistory(filename=filename, history=history) + PlotModelHistory(filename=filename, history=history, epoch_best=epoch_best) multiple_branches_used = len(history.model.output_names) > 1 # means that there are multiple output branches if multiple_branches_used: filename = os.path.join(path, f"{name}_history_main_loss.pdf") - PlotModelHistory(filename=filename, history=history, main_branch=True) + PlotModelHistory(filename=filename, history=history, main_branch=True, epoch_best=epoch_best) mse_indicator = list(set(history.model.metrics_names).intersection(["mean_squared_error", "mse"])) if len(mse_indicator) > 0: filename = os.path.join(path, f"{name}_history_main_mse.pdf") PlotModelHistory(filename=filename, history=history, plot_metric=mse_indicator[0], - main_branch=multiple_branches_used) + main_branch=multiple_branches_used, epoch_best=epoch_best) # plot learning rate if lr_sc: diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py index f6bc6d24724c2620083602d3864bcbca0a709681..07a5aa2f543b1992baf10421de4b28133feb0eac 100644 --- a/test/test_configuration/test_defaults.py +++ b/test/test_configuration/test_defaults.py @@ -21,6 +21,8 @@ class TestAllDefaults: assert DEFAULT_PERMUTE_DATA is False assert DEFAULT_BATCH_SIZE == int(256 * 2) assert DEFAULT_EPOCHS == 20 + assert bool(np.isinf(DEFAULT_EARLY_STOPPING_EPOCHS)) is True + assert DEFAULT_RESTORE_BEST_MODEL_WEIGHTS is True def test_data_handler_parameters(self): assert DEFAULT_STATIONS == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py index b850b361b09a8d180c5c70c2257d2d7be27c6cc0..70640be9d56d71e4f68145b3bb68fb835e1e27a5 100644 --- a/test/test_helpers/test_helpers.py +++ b/test/test_helpers/test_helpers.py @@ -15,7 +15,7 @@ import string from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict, sort_like from mlair.helpers import PyTestRegex from mlair.helpers import Logger, TimeTracking -from mlair.helpers.helpers import is_xarray, convert2xrda +from mlair.helpers.helpers import is_xarray, convert2xrda, relative_round class TestToList: @@ -171,6 +171,39 @@ class TestFloatRound: assert float_round(-34.9221, 0) == -34. +class TestRelativeRound: + + def test_relative_round_big_numbers(self): + assert relative_round(101, 1) == 100 + assert relative_round(99, 1) == 100 + assert relative_round(105, 2) == 100 + assert relative_round(106, 2) == 110 + assert relative_round(106, 3) == 106 + + def test_relative_round_float_numbers(self): + assert relative_round(101.2033, 4) == 101.2 + assert relative_round(101.2033, 5) == 101.2 + assert relative_round(101.2033, 6) == 101.203 + + def test_relative_round_small_numbers(self): + assert relative_round(0.03112, 2) == 0.031 + assert relative_round(0.03112, 1) == 0.03 + assert relative_round(0.031126, 4) == 0.03113 + + def test_relative_round_negative_numbers(self): + assert relative_round(-101.2033, 5) == -101.2 + assert relative_round(-106, 2) == -110 + assert relative_round(-0.03112, 2) == -0.031 + assert relative_round(-0.03112, 1) == -0.03 + assert relative_round(-0.031126, 4) == -0.03113 + + def test_relative_round_wrong_significance(self): + with pytest.raises(AssertionError): + relative_round(300, -1) + with pytest.raises(TypeError): + relative_round(300, 1.1) + + class TestSelectFromDict: @pytest.fixture diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py index 60b37207ceefc4088b33fa002dac9db7c6c35399..962287e09aacd3c44961a827c86b331d643ec401 100644 --- a/test/test_run_modules/test_model_setup.py +++ b/test/test_run_modules/test_model_setup.py @@ -80,7 +80,7 @@ class TestModelSetup: setup._set_callbacks() assert "general.model" in setup.data_store.search_name("callbacks") callbacks = setup.data_store.get("callbacks", "general.model") - assert len(callbacks.get_callbacks()) == 4 + assert len(callbacks.get_callbacks()) == 5 def test_set_callbacks_no_lr_decay(self, setup): setup.data_store.set("lr_decay", None, "general.model") @@ -88,7 +88,7 @@ class TestModelSetup: setup.checkpoint_name = "TestName" setup._set_callbacks() callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model") - assert len(callbacks.get_callbacks()) == 3 + assert len(callbacks.get_callbacks()) == 4 with pytest.raises(IndexError): callbacks.get_callback_by_name("lr_decay") diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py index 1b83b3823519d63d5dcbc10f0e31fc3433f98f34..8f1fcd1943f9f203e738053017e00f8c269afef1 100644 --- a/test/test_run_modules/test_training.py +++ b/test/test_run_modules/test_training.py @@ -326,16 +326,10 @@ class TestTraining: model_name = "test_model.h5" assert model_name not in os.listdir(model_path) init_without_run.save_model() - message = PyTestRegex(f"save best model to {os.path.join(model_path, model_name)}") + message = PyTestRegex(f"save model to {os.path.join(model_path, model_name)}") assert caplog.record_tuples[1] == ("root", 10, message) assert model_name in os.listdir(model_path) - def test_load_best_model_no_weights(self, init_without_run, caplog): - caplog.set_level(logging.DEBUG) - init_without_run.load_best_model("notExisting.h5") - assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5")) - assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload...")) - def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path): init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing) assert "history.json" in os.listdir(model_path) @@ -360,7 +354,7 @@ class TestTraining: assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0 history.model.output_names = mock.MagicMock(return_value=["Main"]) history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"]) - init_without_run.create_monitoring_plots(history, learning_rate) + init_without_run.create_monitoring_plots(history, learning_rate, epoch_best=1) assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2 def test_resume_training1(self, path: str, model_path, batch_path, data_collection, statistics_per_var,