Skip to content
Snippets Groups Projects
Select Git revision
8 results Searching

start_jupyter.sh

Blame
  • keras_extensions.py 3.96 KiB
    __author__ = 'Lukas Leufen, Felix Kleinert'
    __date__ = '2020-01-31'
    
    import logging
    import math
    import pickle
    from typing import Union
    
    import numpy as np
    from keras import backend as K
    from keras.callbacks import History, ModelCheckpoint
    
    
    class HistoryAdvanced(History):
        """
        History callback that can be seeded with the epochs and metrics of a previous training run.

        In contrast to the plain keras History, the stored epoch list and history dict are not cleared when training
        begins, so a resumed training keeps appending to the values passed in as `old_epoch` and `old_history`.
        """

        def __init__(self, old_epoch=None, old_history=None):
            """
            :param old_epoch: list of already finished epochs to continue from (a fresh list is used if None/empty)
            :param old_history: dict of already recorded metrics to continue from (a fresh dict is used if None/empty)
            """
            # Initialise the parent FIRST: in some keras versions (e.g. tf.keras) History.__init__ itself assigns
            # `self.history = {}`, which would wipe the restored history if it ran after the assignments below.
            super().__init__()
            self.epoch = old_epoch or []
            self.history = old_history or {}

        def on_train_begin(self, logs=None):
            # Deliberately a no-op: the parent implementation resets the stored epochs/history and would discard the
            # values restored in __init__.
            pass
    
    
    class LearningRateDecay(History):
        """
        Step-wise decay of the learning rate during model training.

        Training starts with `base_lr`; once every `epochs_drop` epochs the rate is multiplied by `drop`, which has to lie
        in (0, 1]. A drop of 1 keeps the learning rate constant. Every rate that is set is appended to `self.lr['lr']`.
        """

        def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
            super().__init__()
            self.lr = dict(lr=[])
            # Parameters are validated in this order; epochs_drop only needs to be strictly positive.
            self.base_lr = self.check_param(base_lr, 'base_lr')
            self.drop = self.check_param(drop, 'drop')
            self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
            # on_train_begin is a no-op below, so the History containers are created here instead.
            self.epoch = []
            self.history = {}

        @staticmethod
        def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
            """
            Validate that value lies in the half-open interval (lower, upper].

            A None endpoint is replaced by -inf (lower) or +inf (upper). The comparison is still performed in that case,
            so e.g. NaN is always rejected even if both endpoints are None.
            :param value: value to check
            :param name: name of the variable shown in the error message
            :param lower: left endpoint (excluded), or None for -inf
            :param upper: right endpoint (included), or None for +inf
            :return: value, unchanged
            :raises ValueError: if value is outside the interval
            """
            low_bound = -np.inf if lower is None else lower
            high_bound = np.inf if upper is None else upper
            if not low_bound < value <= high_bound:
                bracket = ')' if high_bound == np.inf else ']'
                raise ValueError(f"{name} is out of allowed range ({low_bound}, {high_bound}{bracket}: {name}={value}")
            return value

        def on_train_begin(self, logs=None):
            pass

        def on_epoch_begin(self, epoch: int, logs=None):
            """
            Set the optimizer learning rate for the epoch that is about to start.

            The new rate is base_lr * drop ** floor(epoch / epochs_drop).
            :param epoch: index of the upcoming epoch
            :param logs: not used
            :return: learning rate as read back from the optimizer
            """
            drop_steps = math.floor(epoch / self.epochs_drop)
            new_lr = self.base_lr * math.pow(self.drop, drop_steps)
            K.set_value(self.model.optimizer.lr, new_lr)
            self.lr['lr'].append(new_lr)
            logging.info(f"Set learning rate to {new_lr}")
            return K.get_value(self.model.optimizer.lr)
    
    
    class ModelCheckpointAdvanced(ModelCheckpoint):
        """
        Model checkpoint that additionally pickles a set of tracked callbacks whenever the model itself is saved.

        IMPORTANT: Always add the model checkpoint advanced as last callback to properly update all tracked callbacks, e.g.
        fit_generator(callbacks=[..., <last_here>])

        Tracked callbacks are given via the keyword argument `callbacks`: a list of dicts, each holding the callback object
        under the key "callback" and the target pickle file under the key "path".
        """

        def __init__(self, *args, **kwargs):
            """
            :param args: positional arguments forwarded to keras ModelCheckpoint
            :param kwargs: keyword arguments forwarded to keras ModelCheckpoint, plus `callbacks` (list of dicts with keys
                "callback" and "path"), which is consumed here. If omitted, no additional callbacks are tracked.
            """
            tracked = kwargs.pop("callbacks", None) or []
            super().__init__(*args, **kwargs)
            self.callbacks = tracked

        def update_best(self, hist):
            """
            Set the best monitored value to the last entry recorded for `self.monitor` in hist.history.

            Presumably used when resuming a training from a restored history.
            :param hist: history callback whose `history` dict contains the monitored quantity
            """
            self.best = hist.history.get(self.monitor)[-1]

        def update_callbacks(self, callbacks):
            """Replace the list of tracked callbacks (same format as the `callbacks` constructor argument)."""
            self.callbacks = callbacks

        def on_epoch_end(self, epoch, logs=None):
            """
            Save the model as usual (see keras ModelCheckpoint) and, in epochs where it was saved, pickle all tracked
            callbacks too.

            With save_best_only, callbacks are only written if the monitored value of this epoch equals the current best
            value, i.e. if this epoch produced the best model.
            :param epoch: index of the finished epoch
            :param logs: metric values of the finished epoch
            """
            super().on_epoch_end(epoch, logs)
            logs = logs or {}

            # epochs_since_last_save == 0 is used as "the parent saved in this epoch" (depends on its period setting).
            # NOTE(review): callbacks are never saved after the very first epoch (epoch == 0) although the model may be;
            # kept from the original implementation - confirm this is intended.
            if self.epochs_since_last_save != 0 or epoch == 0:
                return
            if self.save_best_only and logs.get(self.monitor) != self.best:
                return

            for callback in self.callbacks:
                file_path = callback["path"]
                if self.verbose > 0:
                    print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                self._pickle_callback(callback["callback"], file_path)

        @staticmethod
        def _pickle_callback(callback, file_path):
            """
            Pickle callback to file_path without its attached keras model.

            During training keras attaches the model to each callback (`callback.model`). Such a model generally cannot be
            pickled and would bloat the file anyway, so it is detached for the dump and always restored afterwards.
            """
            has_model = hasattr(callback, "model")
            model = getattr(callback, "model", None)
            if has_model:
                callback.model = None
            try:
                with open(file_path, "wb") as f:
                    pickle.dump(callback, f)
            finally:
                if has_model:
                    callback.model = model