Skip to content
Snippets Groups Projects
Select Git revision
  • 4a4ba794128bb2a42b4b9d1e1620f28d58a065e2
  • master default protected
  • enxhi_issue460_remove_TOAR-I_access
  • michael_issue459_preprocess_german_stations
  • sh_pollutants
  • develop protected
  • release_v2.4.0
  • michael_issue450_feat_load-ifs-data
  • lukas_issue457_feat_set-config-paths-as-parameter
  • lukas_issue454_feat_use-toar-statistics-api-v2
  • lukas_issue453_refac_advanced-retry-strategy
  • lukas_issue452_bug_update-proj-version
  • lukas_issue449_refac_load-era5-data-from-toar-db
  • lukas_issue451_feat_robust-apriori-estimate-for-short-timeseries
  • lukas_issue448_feat_load-model-from-path
  • lukas_issue447_feat_store-and-load-local-clim-apriori-data
  • lukas_issue445_feat_data-insight-plot-monthly-distribution
  • lukas_issue442_feat_bias-free-evaluation
  • lukas_issue444_feat_choose-interp-method-cams
  • 414-include-crps-analysis-and-other-ens-verif-methods-or-plots
  • lukas_issue384_feat_aqw-data-handler
  • v2.4.0 protected
  • v2.3.0 protected
  • v2.2.0 protected
  • v2.1.0 protected
  • Kleinert_etal_2022_initial_submission
  • v2.0.0 protected
  • v1.5.0 protected
  • v1.4.0 protected
  • v1.3.0 protected
  • v1.2.1 protected
  • v1.2.0 protected
  • v1.1.0 protected
  • IntelliO3-ts-v1.0_R1-submit
  • v1.0.0 protected
  • v0.12.2 protected
  • v0.12.1 protected
  • v0.12.0 protected
  • v0.11.0 protected
  • v0.10.0 protected
  • IntelliO3-ts-v1.0_initial-submit
41 results

training.py

Blame
  • training.py 5.21 KiB
    __author__ = "Lukas Leufen, Felix Kleinert"
    __date__ = '2019-12-05'
    
    import logging
    import os
    import json
    import keras
    
    from src.run_modules.run_environment import RunEnvironment
    from src.data_handling.data_distributor import Distributor
    
    
    class Training(RunEnvironment):
        """
        Run module that performs model training and persists the best result.

        The full training workflow is executed immediately on instantiation (see
        :py:meth:`_run`). Expects the data store to provide "model", "batch_size",
        "epochs", "checkpoint" and "lr_decay" in scope "general.model", the
        train/val/test generators in "general.<mode>", and "experiment_name" /
        "experiment_path" in scope "general".
        """

        def __init__(self):
            """Load all training prerequisites from the data store and start training."""
            super().__init__()
            self.model: keras.Model = self.data_store.get("model", "general.model")
            # subsets are filled by set_generators() during _run()
            self.train_set = None
            self.val_set = None
            self.test_set = None
            self.batch_size = self.data_store.get("batch_size", "general.model")
            self.epochs = self.data_store.get("epochs", "general.model")
            self.checkpoint = self.data_store.get("checkpoint", "general.model")
            self.lr_sc = self.data_store.get("lr_decay", "general.model")
            self.experiment_name = self.data_store.get("experiment_name", "general")
            self._run()

        def _run(self) -> None:
            """
            Perform training
            1) set_generators():
                set generators for training, validation and testing and distribute according to batch size
            2) make_predict_function():
                create predict function before distribution on multiple nodes (detailed information in method description)
            3) train():
                train model and save callbacks
            4) save_model():
                save best model from training as final model
            """
            self.set_generators()
            self.make_predict_function()
            self.train()
            self.save_model()

        def make_predict_function(self) -> None:
            """
            Creates the predict function. Must be called before distributing. This is necessary, because tf will compile
            the predict function just in the moment it is used the first time. This can cause problems, if the model is
            distributed on different workers. To prevent this, the function is pre-compiled. See discussion @
            https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252
            """
            # NOTE: _make_predict_function is a private keras API but the documented
            # workaround for the threading issue linked above.
            self.model._make_predict_function()

        def _set_gen(self, mode: str) -> None:
            """
            Set and distribute the generators for given mode regarding batch size
            :param mode: name of set, should be from ["train", "val", "test"]
            """
            gen = self.data_store.get("generator", f"general.{mode}")
            setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size))

        def set_generators(self) -> None:
            """
            Set all generators for training, validation, and testing subsets. The called sub-method will automatically
            distribute the data according to the batch size. The subsets can be accessed as class variables train_set,
            val_set, and test_set .
            """
            for mode in ["train", "val", "test"]:
                self._set_gen(mode)

        def train(self) -> None:
            """
            Perform training using keras fit_generator(). Callbacks are stored locally in the experiment directory. Best
            model from training is saved for class variable model.
            """
            logging.info(f"Train with {len(self.train_set)} mini batches.")
            history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                               steps_per_epoch=len(self.train_set),
                                               epochs=self.epochs,
                                               verbose=2,
                                               validation_data=self.val_set.distribute_on_batches(),
                                               validation_steps=len(self.val_set),
                                               callbacks=[self.checkpoint, self.lr_sc])
            self.save_callbacks(history)
            # restore the checkpoint's best weights so save_model() persists the
            # best epoch, not necessarily the last one
            self.load_best_model(self.checkpoint.filepath)

        def save_model(self) -> None:
            """
            save model in local experiment directory. Model is named as <experiment_name>_my_model.h5 .
            """
            path = self.data_store.get("experiment_path", "general")
            # use the experiment name cached in __init__ instead of querying the
            # data store again (same key/scope, kept consistent with __init__)
            name = f"{self.experiment_name}_my_model.h5"
            model_name = os.path.join(path, name)
            logging.debug(f"save best model to {model_name}")
            self.model.save(model_name)
            self.data_store.put("best_model", self.model, "general")

        def load_best_model(self, name: str) -> None:
            """
            Load model weights for model with name. Skip if no weights are available.
            :param name: name of the model to load weights for
            """
            logging.debug(f"load best model: {name}")
            try:
                self.model.load_weights(name)
                logging.info('reload weights...')
            except OSError:
                # best-effort: missing checkpoint file is not fatal, keep current weights
                logging.info('no weights to reload...')

        def save_callbacks(self, history: keras.callbacks.History) -> None:
            """
            Save callbacks (history, learning rate) of training.
            * history.history -> history.json
            * lr_sc.lr -> history_lr.json
            :param history: history object of training
            """
            logging.debug("saving callbacks")
            path = self.data_store.get("experiment_path", "general")
            with open(os.path.join(path, "history.json"), "w") as f:
                json.dump(history.history, f)
            with open(os.path.join(path, "history_lr.json"), "w") as f:
                json.dump(self.lr_sc.lr, f)