training.py

    """Training module."""
    
    __author__ = "Lukas Leufen, Felix Kleinert"
    __date__ = '2019-12-05'
    
    import json
    import logging
    import os
    from typing import Union
    
    import keras
    from keras.callbacks import Callback, History
    
    from src.data_handling.data_distributor import Distributor
    from src.model_modules.keras_extensions import CallbackHandler
    from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
    from src.run_modules.run_environment import RunEnvironment
    
    
    class Training(RunEnvironment):
        """
        Perform training.

        Schedule of training:
            #. set_generators(): set generators for training, validation, and testing and distribute them according
               to the batch size
            #. make_predict_function(): create the predict function before distribution on multiple nodes (detailed
               information in the method description)
            #. train(): start or resume training of the model and save callbacks
            #. save_model(): save the best model from training as the final model
    
        Required objects [scope] from data store:
            * `model` [model]
            * `batch_size` [model]
            * `epochs` [model]
            * `callbacks` [model]
            * `model_name` [model]
            * `experiment_name` [.]
            * `experiment_path` [.]
            * `trainable` [.]
            * `create_new_model` [.]
            * `generator` [train, val, test]
            * `plot_path` [.]
    
        Optional objects
            * `permute_data` [train, val, test]
            * `upsampling` [train, val, test]
    
        Sets
            * `best_model` [.]
    
        Creates
            * `<exp_name>_model-best.h5`
            * `<exp_name>_model-best-callbacks-<name>.h5` (all callbacks from CallbackHandler)
            * `history.json`
            * `history_lr.json` (optional)
            * `<exp_name>_history_<name>.pdf` (different monitoring plots depending on loss metrics and callbacks)
    
        """
    
        def __init__(self):
            """Set up training."""
            super().__init__()
            self.model: keras.Model = self.data_store.get("model", "model")
            self.train_set: Union[Distributor, None] = None
            self.val_set: Union[Distributor, None] = None
            self.test_set: Union[Distributor, None] = None
            self.batch_size = self.data_store.get("batch_size", "model")
            self.epochs = self.data_store.get("epochs", "model")
            self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model")
            self.experiment_name = self.data_store.get("experiment_name")
            self._trainable = self.data_store.get("trainable")
            self._create_new_model = self.data_store.get("create_new_model")
            self._run()
    
        def _run(self) -> None:
            """Run training. Details in class description."""
            self.set_generators()
            self.make_predict_function()
            if self._trainable:
                self.train()
                self.save_model()
            else:
                logging.info("No training has started, because trainable parameter was false.")
    
        def make_predict_function(self) -> None:
            """
            Create predict function.
    
            Must be called before distributing the model. This is necessary because tf compiles the predict function
            only at the moment it is used for the first time, which can cause problems if the model is distributed
            across different workers. To prevent this, the function is pre-compiled here. See the discussion at
            https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252
            """
            self.model._make_predict_function()
    
        def _set_gen(self, mode: str) -> None:
            """
            Set and distribute the generator for the given mode according to the batch size.

            :param mode: name of the subset, one of ["train", "val", "test"]
            """
            gen = self.data_store.get("generator", mode)
            kwargs = self.data_store.create_args_dict(["permute_data", "upsampling"], scope=mode)
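            # store the distributed generator as self.train_set / self.val_set / self.test_set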
            setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, **kwargs))
    
        def set_generators(self) -> None:
            """
            Set all generators for training, validation, and testing subsets.
    
            The called sub-method will automatically distribute the data according to the batch size. The subsets can be
            accessed as class variables train_set, val_set, and test_set.
            """
            for mode in ["train", "val", "test"]:
                self._set_gen(mode)
    
        def train(self) -> None:
            """
            Perform training using keras fit_generator().
    
            Callbacks are stored locally in the experiment directory. The best model from training is kept in the
            class variable model. If the checkpoint file path already exists, this method assumes that this is not a
            new training starting from the very beginning, but the resumption of a previously started and interrupted
            (or stopped and now continued) training. In that case, train() automatically loads the locally stored
            callbacks and the corresponding model and proceeds with the already started training.
            """
            logging.info(f"Train with {len(self.train_set)} mini batches.")
            logging.info(f"Train with option upsampling={self.train_set.upsampling}.")
            logging.info(f"Train with option data_permutation={self.train_set.do_data_permutation}.")
    
            checkpoint = self.callbacks.get_checkpoint()
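            # train from scratch if no checkpoint has been written yet or a new model is explicitly requested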
            if not os.path.exists(checkpoint.filepath) or self._create_new_model:
                history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                                   steps_per_epoch=len(self.train_set),
                                                   epochs=self.epochs,
                                                   verbose=2,
                                                   validation_data=self.val_set.distribute_on_batches(),
                                                   validation_steps=len(self.val_set),
                                                   callbacks=self.callbacks.get_callbacks(as_dict=False))
            else:
                logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
                self.callbacks.load_callbacks()
                self.callbacks.update_checkpoint()
                self.model = keras.models.load_model(checkpoint.filepath)
                hist: History = self.callbacks.get_callback_by_name("hist")
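                # hist.epoch lists all completed epochs; resume with the following one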
                initial_epoch = max(hist.epoch) + 1
                _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                             steps_per_epoch=len(self.train_set),
                                             epochs=self.epochs,
                                             verbose=2,
                                             validation_data=self.val_set.distribute_on_batches(),
                                             validation_steps=len(self.val_set),
                                             callbacks=self.callbacks.get_callbacks(as_dict=False),
                                             initial_epoch=initial_epoch)
                history = hist
            try:
                lr = self.callbacks.get_callback_by_name("lr")
            except IndexError:
                lr = None
            self.save_callbacks_as_json(history, lr)
            self.load_best_model(checkpoint.filepath)
            self.create_monitoring_plots(history, lr)
    
        def save_model(self) -> None:
            """Save model in local experiment directory. Model is named as `<experiment_name>_<custom_model_name>.h5`."""
            model_name = self.data_store.get("model_name", "model")
            logging.debug(f"save best model to {model_name}")
            self.model.save(model_name)
            self.data_store.set("best_model", self.model)
    
        def load_best_model(self, name: str) -> None:
            """
            Load model weights from the given file path. Skip if no weights are available.

            :param name: file path of the model to load weights from
            """
            logging.debug(f"load best model: {name}")
            try:
                self.model.load_weights(name)
                logging.info('reload weights...')
            except OSError:
                logging.info('no weights to reload...')
    
        def save_callbacks_as_json(self, history: Callback, lr_sc: Callback) -> None:
            """
            Save callbacks (history, learning rate) of training.
    
            * history.history -> history.json
            * lr_sc.lr -> history_lr.json
    
            :param history: history object of training
            :param lr_sc: learning rate object
            """
            logging.debug("saving callbacks")
            path = self.data_store.get("experiment_path")
            with open(os.path.join(path, "history.json"), "w") as f:
                json.dump(history.history, f)
            if lr_sc:
                with open(os.path.join(path, "history_lr.json"), "w") as f:
                    json.dump(lr_sc.lr, f)
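
        # The dumped files are plain JSON and can be re-read for offline inspection,
        # e.g. (sketch, assuming the default "loss" / "val_loss" metric keys):
        #
        #     with open(os.path.join(path, "history.json")) as f:
        #         hist = json.load(f)
        #     print(min(hist["val_loss"]))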
    
        def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None:
            """
            Create plots of the history and the learning rate as a function of the number of epochs.

            The plots are saved in the experiment's plot_path. The history plot is named `<exp_name>_history_loss.pdf`,
            the learning rate plot `<exp_name>_history_learning_rate.pdf`.

            :param history: keras history object with losses to plot (must at least include `loss` and `val_loss`)
            :param lr_sc: learning rate decay object with an `lr` attribute
            """
            path = self.data_store.get("plot_path")
            name = self.data_store.get("experiment_name")
    
            # plot history of loss and mse (if available)
            filename = os.path.join(path, f"{name}_history_loss.pdf")
            PlotModelHistory(filename=filename, history=history)
            multiple_branches_used = len(history.model.output_names) > 1  # means that there are multiple output branches
            if multiple_branches_used:
                filename = os.path.join(path, f"{name}_history_main_loss.pdf")
                PlotModelHistory(filename=filename, history=history, main_branch=True)
            if "mean_squared_error" in history.model.metrics_names:
                filename = os.path.join(path, f"{name}_history_main_mse.pdf")
                PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used)
    
            # plot learning rate
            if lr_sc:
                PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)