From a5ef033e55f2434964df693b18bbab97a994d56b Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 30 Jan 2020 18:34:13 +0100
Subject: [PATCH] intermediate save commit

---
 run.py                           |  6 +--
 src/helpers.py                   |  5 ++
 src/model_modules/model_class.py |  4 +-
 src/run_modules/model_setup.py   | 81 +++++++++++++++++++++++++++++---
 src/run_modules/training.py      | 44 +++++++++++++----
 5 files changed, 120 insertions(+), 20 deletions(-)

diff --git a/run.py b/run.py
index 71244fb9..c06bf648 100644
--- a/run.py
+++ b/run.py
@@ -24,14 +24,14 @@ def main(parser_args):
 
         Training()
 
-        PostProcessing()
+        # PostProcessing()
 
 
 if __name__ == "__main__":
 
     formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
-    # logging.basicConfig(format=formatter, level=logging.INFO)
-    logging.basicConfig(format=formatter, level=logging.DEBUG)
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    # logging.basicConfig(format=formatter, level=logging.DEBUG)
 
     parser = argparse.ArgumentParser()
     parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
diff --git a/src/helpers.py b/src/helpers.py
index 172a8dd3..f119f140 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -57,6 +57,8 @@ class LearningRateDecay(keras.callbacks.History):
         self.base_lr = self.check_param(base_lr, 'base_lr')
         self.drop = self.check_param(drop, 'drop')
         self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
+        self.epoch = []
+        self.history = {}
 
     @staticmethod
     def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
@@ -80,6 +82,9 @@ class LearningRateDecay(keras.callbacks.History):
             raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
                              f"{name}={value}")
 
+    def on_train_begin(self, logs=None):
+        # keep the epoch/history state set in __init__; the parent History
+        # callback would reset both here at the start of every fit call
+        pass
+
     def on_epoch_begin(self, epoch: int, logs=None):
         """
         Lower learning rate every epochs_drop epochs by factor drop.
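Side note on the on_train_begin override above: keras' History callback (which LearningRateDecay inherits from) resets self.epoch and self.history at the start of every fit call. Overriding it with pass keeps the state set in __init__ or restored from a pickle, so a resumed training run continues the record instead of starting a fresh one. A minimal sketch (class name hypothetical, not part of this patch), driving the callback hooks by hand:

    import keras


    class PersistentHistory(keras.callbacks.History):  # hypothetical name, mirrors the patch
        def __init__(self):
            super().__init__()
            self.epoch = []       # initialise here instead of in on_train_begin ...
            self.history = {}

        def on_train_begin(self, logs=None):
            pass                  # ... so a second fit call does not wipe the record


    cb = PersistentHistory()
    cb.on_train_begin()
    cb.on_epoch_end(0, {"loss": 1.0})
    cb.on_epoch_end(1, {"loss": 0.5})

    cb.on_train_begin()           # e.g. training restarted after unpickling the callback
    cb.on_epoch_end(2, {"loss": 0.4})

    print(cb.epoch)               # [0, 1, 2] -> the record survived the restart
    print(cb.history["loss"])     # [1.0, 0.5, 0.4]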
diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 6b0fe236..5e9931d7 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -110,7 +110,7 @@ class MyLittleModel(AbstractModelClass):
         self.initial_lr = 1e-2
         self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
         self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
-        self.epochs = 2
+        self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
 
@@ -190,7 +190,7 @@ class MyBranchedModel(AbstractModelClass):
         self.initial_lr = 1e-2
         self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
         self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
-        self.epochs = 2
+        self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
 
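Raising epochs from 2 to 20 also means the step decay configured above actually fires: with epochs_drop=10 the first drop happens at epoch 10, which a 2-epoch run never reaches. A rough check (not part of this patch), assuming the usual step-decay formula lr = base_lr * drop ** floor(epoch / epochs_drop) suggested by the LearningRateDecay docstring:

    import math

    base_lr, drop, epochs_drop = 1e-2, 0.94, 10
    schedule = [base_lr * drop ** math.floor(epoch / epochs_drop) for epoch in range(20)]
    print(schedule[0], schedule[9], schedule[10])   # 0.01, 0.01, 0.0094 -> one drop at epoch 10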
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index a7722018..a4d89f65 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -4,19 +4,20 @@ __date__ = '2019-12-02'
 
 import keras
 from keras import losses
-from keras.callbacks import ModelCheckpoint
+from keras.callbacks import ModelCheckpoint, History
 from keras.regularizers import l2
 from keras.optimizers import SGD
 import tensorflow as tf
 import logging
 import os
+import pickle
 
 from src.run_modules.run_environment import RunEnvironment
 from src.helpers import l_p_loss, LearningRateDecay
 from src.model_modules.inception_model import InceptionModelBase
 from src.model_modules.flatten import flatten_tail
-from src.model_modules.model_class import MyBranchedModel as MyModel
-# from src.model_modules.model_class import MyLittleModel as MyModel
+# from src.model_modules.model_class import MyBranchedModel as MyModel
+from src.model_modules.model_class import MyLittleModel as MyModel
 
 
 class ModelSetup(RunEnvironment):
@@ -30,13 +31,11 @@ class ModelSetup(RunEnvironment):
         exp_name = self.data_store.get("experiment_name", "general")
         self.scope = "general.model"
         self.checkpoint_name = os.path.join(path, f"{exp_name}_model-best.h5")
+        self.callbacks_name = os.path.join(path, f"{exp_name}_model-best-callbacks-%s.pickle")
         self._run()
 
     def _run(self):
 
-        # create checkpoint
-        self._set_checkpoint()
-
         # set channels depending on inputs
         self._set_channels()
 
@@ -50,6 +49,9 @@ class ModelSetup(RunEnvironment):
         if self.data_store.get("trainable", self.scope) is False:
             self.load_weights()
 
+        # create checkpoint
+        self._set_checkpoint()
+
         # compile model
         self.compile_model()
 
@@ -64,7 +66,17 @@ class ModelSetup(RunEnvironment):
         self.data_store.set("model", self.model, self.scope)
 
     def _set_checkpoint(self):
-        checkpoint = ModelCheckpoint(self.checkpoint_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
+        lr = self.data_store.get("lr_decay", scope="general.model")
+        # checkpoint = ModelCheckpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
+        # checkpoint = ModelCheckpointAdvanced(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
+        #                                      save_best_only=True, mode='auto', callbacks_to_save=lr,
+        #                                      callbacks_filepath=self.callbacks_name)
+        hist = HistoryAdvanced()
+        self.data_store.set("hist", hist, scope="general.model")
+        callbacks = [{"callback": lr, "path": self.callbacks_name % "lr"},
+                     {"callback": hist, "path": self.callbacks_name % "hist"}]
+        checkpoint = ModelCheckpointAdvanced2(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
+                                              save_best_only=True, mode='auto', callbacks=callbacks)
         self.data_store.set("checkpoint", checkpoint, self.scope)
 
     def load_weights(self):
@@ -92,6 +104,61 @@ class ModelSetup(RunEnvironment):
             keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
 
 
+class HistoryAdvanced(History):
+    """History callback that can be initialised with the state of a previous
+    training run and that keeps this state when training is (re)started."""
+
+    def __init__(self, old_epoch=None, old_history=None):
+        self.epoch = old_epoch or []
+        self.history = old_history or {}
+        super().__init__()
+
+    def on_train_begin(self, logs=None):
+        # do not reset self.epoch / self.history like the parent class does
+        pass
+
+
+class ModelCheckpointAdvanced(ModelCheckpoint):
+    """ModelCheckpoint that additionally pickles a single callback object whenever
+    the parent class writes a model checkpoint (single-callback variant;
+    _set_checkpoint currently uses ModelCheckpointAdvanced2 instead)."""
+
+    def __init__(self, *args, **kwargs):
+        self.callbacks_to_save = kwargs.pop("callbacks_to_save")
+        self.callbacks_filepath = kwargs.pop("callbacks_filepath")
+        super().__init__(*args, **kwargs)
+
+    def on_epoch_end(self, epoch, logs=None):
+        super().on_epoch_end(epoch, logs)
+
+        file_path = self.callbacks_filepath
+        if self.epochs_since_last_save == 0 and epoch != 0:
+            if self.save_best_only:
+                current = logs.get(self.monitor)
+                # the parent sets self.best to the current value when it saves a new
+                # best model, so the callback is pickled in lock-step with the weights
+                if current == self.best:
+                    with open(file_path, "wb") as f:
+                        pickle.dump(self.callbacks_to_save, f)
+            else:
+                with open(file_path, "wb") as f:
+                    pickle.dump(self.callbacks_to_save, f)
+
+class ModelCheckpointAdvanced2(ModelCheckpoint):
+    """ModelCheckpoint that additionally pickles an arbitrary number of callback
+    objects whenever the parent class writes a model checkpoint. `callbacks` is a
+    list of dicts with the keys "callback" (the object) and "path" (pickle file)."""
+
+    def __init__(self, *args, **kwargs):
+        self.callbacks = kwargs.pop("callbacks")
+        super().__init__(*args, **kwargs)
+
+    def on_epoch_end(self, epoch, logs=None):
+        super().on_epoch_end(epoch, logs)
+
+        for callback in self.callbacks:
+            file_path = callback["path"]
+            if self.epochs_since_last_save == 0 and epoch != 0:
+                if self.save_best_only:
+                    current = logs.get(self.monitor)
+                    # pickle in lock-step with the weights: the parent sets self.best
+                    # to the current value when it saves a new best model
+                    if current == self.best:
+                        with open(file_path, "wb") as f:
+                            pickle.dump(callback["callback"], f)
+                else:
+                    with open(file_path, "wb") as f:
+                        pickle.dump(callback["callback"], f)
+
+
 def my_loss():
     loss = l_p_loss(4)
     keras_loss = losses.mean_squared_error
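The new checkpoint classes rely on the callback objects being picklable, and Training later restores them with pickle.load. A quick round-trip sketch (not part of this patch; the paths are placeholders, the project root is assumed to be on PYTHONPATH, and the callbacks have not yet been attached to a model):

    import pickle

    from src.helpers import LearningRateDecay
    from src.run_modules.model_setup import HistoryAdvanced

    lr = LearningRateDecay(base_lr=1e-2, drop=.94, epochs_drop=10)
    hist = HistoryAdvanced(old_epoch=[0, 1], old_history={"val_loss": [1.0, 0.5]})

    for obj, path in [(lr, "/tmp/lr.pickle"), (hist, "/tmp/hist.pickle")]:
        with open(path, "wb") as f:
            pickle.dump(obj, f)

    with open("/tmp/hist.pickle", "rb") as f:
        restored = pickle.load(f)
    print(restored.epoch)   # [0, 1] -> training could resume with initial_epoch=2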
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 96936ce1..e9c7487b 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -5,11 +5,13 @@ import logging
 import os
 import json
 import keras
+import pickle
 
 from src.run_modules.run_environment import RunEnvironment
 from src.data_handling.data_distributor import Distributor
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from src.helpers import LearningRateDecay
+from src.run_modules.model_setup import ModelCheckpointAdvanced2
 
 
 class Training(RunEnvironment):
@@ -22,8 +24,10 @@ class Training(RunEnvironment):
         self.test_set = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
-        self.checkpoint = self.data_store.get("checkpoint", "general.model")
+        self.checkpoint: ModelCheckpointAdvanced2 = self.data_store.get("checkpoint", "general.model")
+        # self.callbacks = self.data_store.get("callbacks", "general.model")
         self.lr_sc = self.data_store.get("lr_decay", "general.model")
+        self.hist = self.data_store.get("hist", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
         self._run()
 
@@ -76,13 +80,36 @@ class Training(RunEnvironment):
         model from training is saved for class variable model.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
-        history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
-                                           steps_per_epoch=len(self.train_set),
-                                           epochs=self.epochs,
-                                           verbose=2,
-                                           validation_data=self.val_set.distribute_on_batches(),
-                                           validation_steps=len(self.val_set),
-                                           callbacks=[self.checkpoint, self.lr_sc])
+        if not os.path.exists(self.checkpoint.filepath):
+            history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+                                               steps_per_epoch=len(self.train_set),
+                                               epochs=self.epochs,
+                                               verbose=2,
+                                               validation_data=self.val_set.distribute_on_batches(),
+                                               validation_steps=len(self.val_set),
+                                               # callbacks=self.callbacks)
+                                               callbacks=[self.checkpoint, self.lr_sc, self.hist])
+        else:
+            lr_filepath = self.checkpoint.callbacks[0]["path"]  # TODO: stopped here. why does training start 1 epoch too early or doesn't it?
+            hist_filepath = self.checkpoint.callbacks[1]["path"]
+            with open(lr_filepath, "rb") as f:
+                self.lr_sc = pickle.load(f)
+            with open(hist_filepath, "rb") as f:
+                self.hist = pickle.load(f)
+            self.model = keras.models.load_model(self.checkpoint.filepath)
+            initial_epoch = max(self.hist.epoch) + 1
+            callbacks = [{"callback": self.lr_sc, "path": lr_filepath},
+                         {"callback": self.hist, "path": hist_filepath}]
+            self.checkpoint.callbacks = callbacks
+            history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+                                               steps_per_epoch=len(self.train_set),
+                                               epochs=self.epochs,
+                                               verbose=2,
+                                               validation_data=self.val_set.distribute_on_batches(),
+                                               validation_steps=len(self.val_set),
+                                               callbacks=[self.checkpoint, self.lr_sc, self.hist],
+                                               initial_epoch=initial_epoch)
+            # report the accumulated history (old + resumed epochs) instead of the
+            # History object returned by fit_generator, which only covers this run
+            history = self.hist
         self.save_callbacks(history)
         self.load_best_model(self.checkpoint.filepath)
         self.create_monitoring_plots(history, self.lr_sc)
@@ -123,6 +150,7 @@ class Training(RunEnvironment):
             json.dump(history.history, f)
         with open(os.path.join(path, "history_lr.json"), "w") as f:
             json.dump(self.lr_sc.lr, f)
+            # json.dump(self.callbacks["learning_rate"].lr, f)
 
     def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
         """
-- 
GitLab