From 5de64e132b3c6de853ae69a731e1c8be9338aacd Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 31 Jan 2020 12:05:20 +0100
Subject: [PATCH] Create new module keras_extensions to collect all callback
 extensions in one place
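
Move the self-written keras callback extensions into the new module
src/model_modules/keras_extensions.py: LearningRateDecay (previously in
src/helpers.py) as well as HistoryAdvanced and the checkpoint callback
(previously in src/run_modules/model_setup.py). The two checkpoint
variants ModelCheckpointAdvanced and ModelCheckpointAdvanced2 are merged
into a single ModelCheckpointAdvanced that receives the callbacks to
track as a list of {"callback": <callback>, "path": <pickle path>} dicts
and provides update_callbacks() and update_best() to restore its state
when a training run is resumed. Imports and tests are adjusted
accordingly; the LearningRateDecay tests move to
test/test_model_modules/test_keras_extensions.py.

ModelCheckpointAdvanced has to be the last entry in the callbacks list
passed to fit_generator, otherwise the tracked callbacks are pickled
before their state for the current epoch is set. A minimal usage sketch
(model, train_gen and the file names are placeholders, not part of this
change):

    from src.model_modules.keras_extensions import (HistoryAdvanced, LearningRateDecay,
                                                    ModelCheckpointAdvanced)

    lr = LearningRateDecay(base_lr=1e-2, drop=0.94, epochs_drop=10)
    hist = HistoryAdvanced()
    # every callback that should be restored on resume gets its own pickle path
    callbacks = [{"callback": lr, "path": "lr.pickle"},
                 {"callback": hist, "path": "hist.pickle"}]
    checkpoint = ModelCheckpointAdvanced(filepath="model_best.h5", verbose=1, monitor="val_loss",
                                         save_best_only=True, mode="auto", callbacks=callbacks)
    # the checkpoint comes last so it pickles the already updated lr and hist
    model.fit_generator(generator=train_gen, epochs=20,
                        callbacks=[lr, hist, checkpoint])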

---
 src/helpers.py                                |  58 +---------
 src/model_modules/keras_extensions.py         | 112 ++++++++++++++++++
 src/model_modules/model_class.py              |   6 +-
 src/plotting/training_monitoring.py           |   3 +-
 src/run_modules/model_setup.py                |  66 +----------
 src/run_modules/training.py                   |  26 ++---
 test/test_helpers.py                          |  40 +------
 .../test_keras_extensions.py                  |  43 +++++++
 test/test_modules/test_training.py            |   3 +-
 .../test_plotting/test_training_monitoring.py |   2 +-
 10 files changed, 183 insertions(+), 176 deletions(-)
 create mode 100644 src/model_modules/keras_extensions.py
 create mode 100644 test/test_model_modules/test_keras_extensions.py

diff --git a/src/helpers.py b/src/helpers.py
index f119f140..a4ce625c 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -11,12 +11,10 @@ import time
 import socket
 import datetime as dt
 
-import keras
 import keras.backend as K
-import numpy as np
 import xarray as xr
 
-from typing import Union, Dict, Callable
+from typing import Dict, Callable
 
 
 def to_list(arg):
@@ -45,60 +43,6 @@ def l_p_loss(power: int):
     return loss
 
 
-class LearningRateDecay(keras.callbacks.History):
-    """
-    Decay learning rate during model training. Start with a base learning rate and lower this rate after every
-    n(=epochs_drop) epochs by drop value (0, 1], drop value = 1 means no decay in learning rate.
-    """
-
-    def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
-        super().__init__()
-        self.lr = {'lr': []}
-        self.base_lr = self.check_param(base_lr, 'base_lr')
-        self.drop = self.check_param(drop, 'drop')
-        self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
-        self.epoch = []
-        self.history = {}
-
-    @staticmethod
-    def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
-        """
-        Check if given value is in interval. The left (lower) endpoint is open, right (upper) endpoint is closed. To
-        only one side of the interval, set the other endpoint to None. If both ends are set to None, just return the
-        value without any check.
-        :param value: value to check
-        :param name: name of the variable to display in error message
-        :param lower: left (lower) endpoint of interval, opened
-        :param upper: right (upper) endpoint of interval, closed
-        :return: unchanged value or raise ValueError
-        """
-        if lower is None:
-            lower = -np.inf
-        if upper is None:
-            upper = np.inf
-        if lower < value <= upper:
-            return value
-        else:
-            raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
-                             f"{name}={value}")
-
-    def on_train_begin(self, logs=None):
-        pass
-
-    def on_epoch_begin(self, epoch: int, logs=None):
-        """
-        Lower learning rate every epochs_drop epochs by factor drop.
-        :param epoch: current epoch
-        :param logs: ?
-        :return: update keras learning rate
-        """
-        current_lr = self.base_lr * math.pow(self.drop, math.floor(epoch / self.epochs_drop))
-        K.set_value(self.model.optimizer.lr, current_lr)
-        self.lr['lr'].append(current_lr)
-        logging.info(f"Set learning rate to {current_lr}")
-        return K.get_value(self.model.optimizer.lr)
-
-
 class TimeTracking(object):
     """
     Track time to measure execution time. Time tracking automatically starts on initialisation and ends by calling stop
diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py
new file mode 100644
index 00000000..f84f33d1
--- /dev/null
+++ b/src/model_modules/keras_extensions.py
@@ -0,0 +1,112 @@
+__author__ = 'Lukas Leufen, Felix Kleinert'
+__date__ = '2020-01-31'
+
+import logging
+import math
+import pickle
+from typing import Union
+
+import numpy as np
+from keras import backend as K
+from keras.callbacks import History, ModelCheckpoint
+
+
+class HistoryAdvanced(History):
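+    """
+    Extended History callback that can be initialised with the epoch list and history dict of a previous training
+    run. on_train_begin is overridden and does not reset them, so a resumed training continues the given history.
+    """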
+
+    def __init__(self, old_epoch=None, old_history=None):
+        self.epoch = old_epoch or []
+        self.history = old_history or {}
+        super().__init__()
+
+    def on_train_begin(self, logs=None):
+        pass
+
+
+class LearningRateDecay(History):
+    """
+    Decay the learning rate during model training. Starting from a base learning rate, the rate is lowered after
+    every n (= epochs_drop) epochs by the factor drop from (0, 1]; drop = 1 means no decay of the learning rate.
+    """
+
+    def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
+        super().__init__()
+        self.lr = {'lr': []}
+        self.base_lr = self.check_param(base_lr, 'base_lr')
+        self.drop = self.check_param(drop, 'drop')
+        self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
+        self.epoch = []
+        self.history = {}
+
+    @staticmethod
+    def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
+        """
+        Check if the given value lies in the interval. The left (lower) endpoint is open, the right (upper) endpoint
+        is closed. To check against only one side of the interval, set the other endpoint to None. If both endpoints
+        are set to None, the value is returned without any check.
+        :param value: value to check
+        :param name: name of the variable to display in error message
+        :param lower: left (lower) endpoint of interval, opened
+        :param upper: right (upper) endpoint of interval, closed
+        :return: unchanged value or raise ValueError
+        """
+        if lower is None:
+            lower = -np.inf
+        if upper is None:
+            upper = np.inf
+        if lower < value <= upper:
+            return value
+        else:
+            raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
+                             f"{name}={value}")
+
+    def on_train_begin(self, logs=None):
+        pass
+
+    def on_epoch_begin(self, epoch: int, logs=None):
+        """
+        Lower learning rate every epochs_drop epochs by factor drop.
+        :param epoch: current epoch
+        :param logs: dict of logs (not used here, part of the keras callback signature)
+        :return: the learning rate that has been set on the keras optimizer
+        """
+        current_lr = self.base_lr * math.pow(self.drop, math.floor(epoch / self.epochs_drop))
+        K.set_value(self.model.optimizer.lr, current_lr)
+        self.lr['lr'].append(current_lr)
+        logging.info(f"Set learning rate to {current_lr}")
+        return K.get_value(self.model.optimizer.lr)
+
+
+class ModelCheckpointAdvanced(ModelCheckpoint):
+    """
+    IMPORTANT: Always add ModelCheckpointAdvanced as the last callback so that all tracked callbacks are already
+    updated when the checkpoint is written, e.g. fit_generator(callbacks=[..., <last_here>])
+    """
+    def __init__(self, *args, **kwargs):
+        self.callbacks = kwargs.pop("callbacks")
+        super().__init__(*args, **kwargs)
+
+    def update_best(self, hist):
+        self.best = hist.history.get(self.monitor)[-1]
+
+    def update_callbacks(self, callbacks):
+        self.callbacks = callbacks
+
+    def on_epoch_end(self, epoch, logs=None):
+        super().on_epoch_end(epoch, logs)
+
+        for callback in self.callbacks:
+            file_path = callback["path"]
+            if self.epochs_since_last_save == 0 and epoch != 0:
+                if self.save_best_only:
+                    current = logs.get(self.monitor)
+                    if current == self.best:
+                        print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
+                        with open(file_path, "wb") as f:
+                            pickle.dump(callback["callback"], f)
+                else:
+                    with open(file_path, "wb") as f:
+                        pickle.dump(callback["callback"], f)
\ No newline at end of file
diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 5e9931d7..02f43dc1 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -1,3 +1,5 @@
+from src.model_modules.keras_extensions import LearningRateDecay
+
 __author__ = "Lukas Leufen"
 __date__ = '2019-12-12'
 
@@ -109,7 +111,7 @@ class MyLittleModel(AbstractModelClass):
         self.regularizer = keras.regularizers.l2(0.1)
         self.initial_lr = 1e-2
         self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
-        self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
+        self.lr_decay = LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
         self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
@@ -189,7 +191,7 @@ class MyBranchedModel(AbstractModelClass):
         self.regularizer = keras.regularizers.l2(0.1)
         self.initial_lr = 1e-2
         self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
-        self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
+        self.lr_decay = LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
         self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py
index 87e4071a..339ba637 100644
--- a/src/plotting/training_monitoring.py
+++ b/src/plotting/training_monitoring.py
@@ -9,8 +9,7 @@ import pandas as pd
 import matplotlib
 import matplotlib.pyplot as plt
 
-from src.helpers import LearningRateDecay
-
+from src.model_modules.keras_extensions import LearningRateDecay
 
 matplotlib.use('Agg')
 history_object = Union[Dict, keras.callbacks.History]
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index a4d89f65..2512947b 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -4,16 +4,13 @@ __date__ = '2019-12-02'
 
 import keras
 from keras import losses
-from keras.callbacks import ModelCheckpoint, History
-from keras.regularizers import l2
-from keras.optimizers import SGD
 import tensorflow as tf
 import logging
 import os
-import pickle
 
 from src.run_modules.run_environment import RunEnvironment
-from src.helpers import l_p_loss, LearningRateDecay
+from src.helpers import l_p_loss
+from src.model_modules.keras_extensions import HistoryAdvanced, ModelCheckpointAdvanced
 from src.model_modules.inception_model import InceptionModelBase
 from src.model_modules.flatten import flatten_tail
 # from src.model_modules.model_class import MyBranchedModel as MyModel
@@ -75,8 +72,8 @@ class ModelSetup(RunEnvironment):
         self.data_store.set("hist", hist, scope="general.model")
         callbacks = [{"callback": lr, "path": self.callbacks_name % "lr"},
                      {"callback": hist, "path": self.callbacks_name % "hist"}]
-        checkpoint = ModelCheckpointAdvanced2(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
-                                              save_best_only=True, mode='auto', callbacks=callbacks)
+        checkpoint = ModelCheckpointAdvanced(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
+                                             save_best_only=True, mode='auto', callbacks=callbacks)
         self.data_store.set("checkpoint", checkpoint, self.scope)
 
     def load_weights(self):
@@ -104,61 +101,6 @@ class ModelSetup(RunEnvironment):
             keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
 
 
-class HistoryAdvanced(History):
-
-    def __init__(self, old_epoch=None, old_history=None):
-        self.epoch = old_epoch or []
-        self.history = old_history or {}
-        super().__init__()
-
-    def on_train_begin(self, logs=None):
-        pass
-
-
-class ModelCheckpointAdvanced(ModelCheckpoint):
-
-    def __init__(self, *args, **kwargs):
-        self.callbacks_to_save = kwargs.pop("callbacks_to_save")
-        self.callbacks_filepath = kwargs.pop("callbacks_filepath")
-        super().__init__(*args, **kwargs)
-
-    def on_epoch_end(self, epoch, logs=None):
-        super().on_epoch_end(epoch, logs)
-
-        file_path = self.callbacks_filepath
-        if self.epochs_since_last_save == 0 and epoch != 0:
-            if self.save_best_only:
-                current = logs.get(self.monitor)
-                if current == self.best:
-                    with open(file_path, "wb") as f:
-                        pickle.dump(self.callbacks_to_save, f)
-            else:
-                with open(file_path, "wb") as f:
-                    pickle.dump(self.callbacks_to_save, f)
-
-
-class ModelCheckpointAdvanced2(ModelCheckpoint):
-
-    def __init__(self, *args, **kwargs):
-        self.callbacks = kwargs.pop("callbacks")
-        super().__init__(*args, **kwargs)
-
-    def on_epoch_end(self, epoch, logs=None):
-        super().on_epoch_end(epoch, logs)
-
-        for callback in self.callbacks:
-            file_path = callback["path"]
-            if self.epochs_since_last_save == 0 and epoch != 0:
-                if self.save_best_only:
-                    current = logs.get(self.monitor)
-                    if current == self.best:
-                        with open(file_path, "wb") as f:
-                            pickle.dump(callback["callback"], f)
-                else:
-                    with open(file_path, "wb") as f:
-                        pickle.dump(callback["callback"], f)
-
-
 def my_loss():
     loss = l_p_loss(4)
     keras_loss = losses.mean_squared_error
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index e9c7487b..7dfa06ff 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -10,8 +10,7 @@ import pickle
 from src.run_modules.run_environment import RunEnvironment
 from src.data_handling.data_distributor import Distributor
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
-from src.helpers import LearningRateDecay
-from src.run_modules.model_setup import ModelCheckpointAdvanced2
+from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced
 
 
 class Training(RunEnvironment):
@@ -24,7 +23,7 @@ class Training(RunEnvironment):
         self.test_set = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
-        self.checkpoint: ModelCheckpointAdvanced2 = self.data_store.get("checkpoint", "general.model")
+        self.checkpoint: ModelCheckpointAdvanced = self.data_store.get("checkpoint", "general.model")
         # self.callbacks = self.data_store.get("callbacks", "general.model")
         self.lr_sc = self.data_store.get("lr_decay", "general.model")
         self.hist = self.data_store.get("hist", "general.model")
@@ -88,7 +87,7 @@ class Training(RunEnvironment):
                                                validation_data=self.val_set.distribute_on_batches(),
                                                validation_steps=len(self.val_set),
                                                # callbacks=self.callbacks)
-                                               callbacks=[self.checkpoint, self.lr_sc, self.hist])
+                                               callbacks=[self.lr_sc, self.hist, self.checkpoint])
         else:
             lr_filepath = self.checkpoint.callbacks[0]["path"]  # TODO: stopped here. why does training start 1 epoch too early or doesn't it?
             hist_filepath = self.checkpoint.callbacks[1]["path"]
@@ -100,15 +99,16 @@ class Training(RunEnvironment):
             initial_epoch = max(hist_callbacks.epoch) + 1
             callbacks = [{"callback": self.lr_sc, "path": lr_filepath},
                          {"callback": self.hist, "path": hist_filepath}]
-            self.checkpoint.callbacks = callbacks
-            history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
-                                               steps_per_epoch=len(self.train_set),
-                                               epochs=self.epochs,
-                                               verbose=2,
-                                               validation_data=self.val_set.distribute_on_batches(),
-                                               validation_steps=len(self.val_set),
-                                               callbacks=[self.checkpoint, self.lr_sc, self.hist],
-                                               initial_epoch=initial_epoch)
+            self.checkpoint.update_callbacks(callbacks)
+            self.checkpoint.update_best(hist_callbacks)
+            self.hist = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+                                                 steps_per_epoch=len(self.train_set),
+                                                 epochs=self.epochs,
+                                                 verbose=2,
+                                                 validation_data=self.val_set.distribute_on_batches(),
+                                                 validation_steps=len(self.val_set),
+                                                 callbacks=[self.lr_sc, self.hist, self.checkpoint],
+                                                 initial_epoch=initial_epoch)
             history = self.hist
         self.save_callbacks(history)
         self.load_best_model(self.checkpoint.filepath)
diff --git a/test/test_helpers.py b/test/test_helpers.py
index e98a46fa..c909960b 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -7,6 +7,8 @@ import numpy as np
 import mock
 import platform
 
+from src.model_modules.keras_extensions import LearningRateDecay
+
 
 class TestToList:
 
@@ -44,44 +46,6 @@ class TestLoss:
         assert hist.history['loss'][0] == 2.25
 
 
-class TestLearningRateDecay:
-
-    def test_init(self):
-        lr_decay = LearningRateDecay()
-        assert lr_decay.lr == {'lr': []}
-        assert lr_decay.base_lr == 0.01
-        assert lr_decay.drop == 0.96
-        assert lr_decay.epochs_drop == 8
-
-    def test_check_param(self):
-        lr_decay = object.__new__(LearningRateDecay)
-        assert lr_decay.check_param(1, "tester") == 1
-        assert lr_decay.check_param(0.5, "tester") == 0.5
-        with pytest.raises(ValueError) as e:
-            lr_decay.check_param(0, "tester")
-        assert "tester is out of allowed range (0, 1]: tester=0" in e.value.args[0]
-        with pytest.raises(ValueError) as e:
-            lr_decay.check_param(1.5, "tester")
-        assert "tester is out of allowed range (0, 1]: tester=1.5" in e.value.args[0]
-        assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
-        with pytest.raises(ValueError) as e:
-            lr_decay.check_param(0, "tester", upper=None)
-        assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0]
-        assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
-        with pytest.raises(ValueError) as e:
-            lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
-        assert "tester is out of allowed range (-inf, 0.2]: tester=0.5" in e.value.args[0]
-        assert lr_decay.check_param(10, "tester", upper=None, lower=None)
-
-    def test_on_epoch_begin(self):
-        lr_decay = LearningRateDecay(base_lr=0.02, drop=0.95, epochs_drop=2)
-        model = keras.Sequential()
-        model.add(keras.layers.Dense(1, input_dim=1))
-        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
-        model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
-        assert lr_decay.lr['lr'] == [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95]
-
-
 class TestTimeTracking:
 
     def test_init(self):
diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py
new file mode 100644
index 00000000..7bf5cf51
--- /dev/null
+++ b/test/test_model_modules/test_keras_extensions.py
@@ -0,0 +1,43 @@
+import pytest
+from src.helpers import l_p_loss
+from src.model_modules.keras_extensions import *
+import keras
+import numpy as np
+
+
+class TestLearningRateDecay:
+
+    def test_init(self):
+        lr_decay = LearningRateDecay()
+        assert lr_decay.lr == {'lr': []}
+        assert lr_decay.base_lr == 0.01
+        assert lr_decay.drop == 0.96
+        assert lr_decay.epochs_drop == 8
+
+    def test_check_param(self):
+        lr_decay = object.__new__(LearningRateDecay)
+        assert lr_decay.check_param(1, "tester") == 1
+        assert lr_decay.check_param(0.5, "tester") == 0.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0, "tester")
+        assert "tester is out of allowed range (0, 1]: tester=0" in e.value.args[0]
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(1.5, "tester")
+        assert "tester is out of allowed range (0, 1]: tester=1.5" in e.value.args[0]
+        assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0, "tester", upper=None)
+        assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0]
+        assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
+        assert "tester is out of allowed range (-inf, 0.2]: tester=0.5" in e.value.args[0]
+        assert lr_decay.check_param(10, "tester", upper=None, lower=None)
+
+    def test_on_epoch_begin(self):
+        lr_decay = LearningRateDecay(base_lr=0.02, drop=0.95, epochs_drop=2)
+        model = keras.Sequential()
+        model.add(keras.layers.Dense(1, input_dim=1))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
+        model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
+        assert lr_decay.lr['lr'] == [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95]
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index accb32e5..4631fe5a 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -14,7 +14,8 @@ from src.run_modules.training import Training
 from src.run_modules.run_environment import RunEnvironment
 from src.data_handling.data_distributor import Distributor
 from src.data_handling.data_generator import DataGenerator
-from src.helpers import LearningRateDecay, PyTestRegex
+from src.helpers import PyTestRegex
+from src.model_modules.keras_extensions import LearningRateDecay
 
 
 def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
diff --git a/test/test_plotting/test_training_monitoring.py b/test/test_plotting/test_training_monitoring.py
index e2d17057..358a19ad 100644
--- a/test/test_plotting/test_training_monitoring.py
+++ b/test/test_plotting/test_training_monitoring.py
@@ -3,7 +3,7 @@ import pytest
 import os
 
 from src.plotting.training_monitoring import PlotModelLearningRate, PlotModelHistory
-from src.helpers import LearningRateDecay
+from src.model_modules.keras_extensions import LearningRateDecay
 
 
 @pytest.fixture
-- 
GitLab