From fa64b1369981c21bab425bc94b370c68f8b6d75d Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 10 Dec 2019 15:14:58 +0100
Subject: [PATCH] worked on docs and tests, updated some other tests

---
 .gitignore                                 |   1 +
 run.py                                     |   2 +-
 src/modules/model_setup.py                 |  39 +++++--
 src/modules/training.py                    |  78 ++++++++++++--
 test/test_helpers.py                       |   2 +-
 test/test_modules/test_experiment_setup.py |   5 +-
 test/test_modules/test_model_setup.py      |   2 +-
 test/test_modules/test_training.py         | 117 +++++++++++++++++++++
 8 files changed, 223 insertions(+), 23 deletions(-)

diff --git a/.gitignore b/.gitignore
index 11b7c159..cec17a77 100644
--- a/.gitignore
+++ b/.gitignore
@@ -58,3 +58,4 @@ htmlcov/
 /test/test_modules/data/
 report.html
 /TestExperiment/
+/testrun_network/
diff --git a/run.py b/run.py
index 0f88f37b..e45b2dd6 100644
--- a/run.py
+++ b/run.py
@@ -17,7 +17,7 @@ def main(parser_args):
 
     with RunEnvironment():
         ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
-                        station_type='background')
+                        station_type='background', trainable=True)
         PreProcessing()
 
         ModelSetup()
diff --git a/src/modules/model_setup.py b/src/modules/model_setup.py
index d75ecb36..f6c25aff 100644
--- a/src/modules/model_setup.py
+++ b/src/modules/model_setup.py
@@ -26,8 +26,8 @@ class ModelSetup(RunEnvironment):
         self.model = None
         path = self.data_store.get("experiment_path", "general")
         exp_name = self.data_store.get("experiment_name", "general")
-        self.model_name = os.path.join(path, f"{exp_name}_model-best.h5")
         self.scope = "general.model"
+        self.checkpoint_name = os.path.join(path, f"{exp_name}_model-best.h5")
         self._run()
 
     def _run(self):
@@ -58,20 +58,20 @@ class ModelSetup(RunEnvironment):
         self.data_store.put("model", self.model, self.scope)
 
     def _set_checkpoint(self):
-        checkpoint = ModelCheckpoint(self.model_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
+        checkpoint = ModelCheckpoint(self.checkpoint_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
         self.data_store.put("checkpoint", checkpoint, self.scope)
 
     def load_weights(self):
         try:
-            logging.debug('reload weights...')
-            self.model.load_weights(self.model_name)
+            self.model.load_weights(self.checkpoint_name)
+            logging.info('reload weights...')
         except OSError:
-            logging.debug('no weights to reload...')
+            logging.info('no weights to reload...')
 
     def build_model(self):
         args_list = ["activation", "window_history_size", "channels", "regularizer", "dropout_rate", "window_lead_time"]
         args = self.data_store.create_args_dict(args_list, self.scope)
-        self.model = my_model(**args)
+        self.model = my_little_model(**args)
 
     def plot_model(self):  # pragma: no cover
         with tf.device("/cpu:0"):
@@ -102,7 +102,7 @@ class ModelSetup(RunEnvironment):
         self.data_store.put("lr_decay", LearningRateDecay(base_lr=initial_lr, drop=.94, epochs_drop=10), self.scope)
 
         # learning settings
-        self.data_store.put("epochs", 10, self.scope)
+        self.data_store.put("epochs", 2, self.scope)
         self.data_store.put("batch_size", int(256), self.scope)
 
         # activation
@@ -110,7 +110,7 @@ class ModelSetup(RunEnvironment):
         self.data_store.put("activation", activation, self.scope)
 
         # set los
-        loss_all = my_loss()
+        loss_all = my_little_loss()
         self.data_store.put("loss", loss_all, self.scope)
 
 
@@ -121,6 +121,29 @@ def my_loss():
     return loss_all
 
 
+def my_little_loss():
+    return losses.mean_squared_error
+
+
+def my_little_model(activation, window_history_size, channels, regularizer, dropout_rate, window_lead_time):
+
+    X_input = keras.layers.Input(
+        shape=(window_history_size + 1, 1, channels))  # add 1 to window_size to include current time step t0
+    X_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(X_input)
+    X_in = activation(name='{}_conv_act'.format("major"))(X_in)
+    X_in = keras.layers.Flatten(name='{}'.format("major"))(X_in)
+    X_in = keras.layers.Dropout(dropout_rate, name='{}_Dropout_1'.format("major"))(X_in)
+    X_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(X_in)
+    X_in = activation()(X_in)
+    X_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(X_in)
+    X_in = activation()(X_in)
+    X_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(X_in)
+    X_in = activation()(X_in)
+    X_in = keras.layers.Dense(window_lead_time, name='{}_Dense'.format("major"))(X_in)
+    out_main = activation()(X_in)
+    return keras.Model(inputs=X_input, outputs=[out_main])
+
+
 def my_model(activation, window_history_size, channels, regularizer, dropout_rate, window_lead_time):
 
     conv_settings_dict1 = {
diff --git a/src/modules/training.py b/src/modules/training.py
index 71f76584..87dcf35e 100644
--- a/src/modules/training.py
+++ b/src/modules/training.py
@@ -5,6 +5,7 @@ __date__ = '2019-12-05'
 import logging
 import os
 import json
+import keras
 
 from src.modules.run_environment import RunEnvironment
 from src.data_handling.data_distributor import Distributor
@@ -25,27 +26,54 @@ class Training(RunEnvironment):
         self.experiment_name = self.data_store.get("experiment_name", "general")
         self._run()
 
-    def _run(self):
+    def _run(self) -> None:
+        """
+        Perform training
+        1) set_generators():
+            set generators for training, validation and testing and distribute according to batch size
+        2) make_predict_function():
+            create predict function before distribution on multiple nodes (detailed information in method description)
+        3) train():
+            train model and save callbacks
+        4) save_model():
+            save best model from training as final model
+        """
         self.set_generators()
         self.make_predict_function()
         self.train()
+        self.save_model()
 
-    def make_predict_function(self):
-        # create the predict function before distributing. This is necessary, because tf will compile the predict
-        # function just in the moment it is used the first time. This can cause problems, if the model is distributed
-        # on different workers. To prevent this, the function is pre-compiled. See discussion @
-        # https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252
+    def make_predict_function(self) -> None:
+        """
+        Create the predict function. Must be called before distributing. This is necessary because tf compiles the
+        predict function only at the moment it is used for the first time. This can cause problems if the model is
+        distributed on different workers. To prevent this, the function is pre-compiled. See the discussion at
+        https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252
+        """
         self.model._make_predict_function()
 
-    def _set_gen(self, mode):
+    def _set_gen(self, mode: str) -> None:
+        """
+        Set and distribute the generator for the given mode according to the batch size.
+        :param mode: name of the subset, should be one of ["train", "val", "test"]
+        """
         gen = self.data_store.get("generator", f"general.{mode}")
         setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size))
 
-    def set_generators(self):
+    def set_generators(self) -> None:
+        """
+        Set all generators for the training, validation, and testing subsets. The called sub-method automatically
+        distributes the data according to the batch size. The subsets can be accessed as the instance attributes
+        train_set, val_set, and test_set.
+        """
         for mode in ["train", "val", "test"]:
             self._set_gen(mode)
 
-    def train(self):
+    def train(self) -> None:
+        """
+        Perform training using keras fit_generator(). Callbacks are stored locally in the experiment directory. The
+        weights of the best model from training (checkpoint file) are then reloaded into the attribute model.
+        """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
         history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                            steps_per_epoch=len(self.train_set),
@@ -55,8 +83,38 @@ class Training(RunEnvironment):
                                            validation_steps=len(self.val_set),
                                            callbacks=[self.checkpoint, self.lr_sc])
         self.save_callbacks(history)
+        self.load_best_model(self.checkpoint.filepath)
 
-    def save_callbacks(self, history):
+    def save_model(self) -> None:
+        """
+        Save the model in the local experiment directory. The model is named <experiment_name>_my_model.h5.
+        """
+        path = self.data_store.get("experiment_path", "general")
+        name = f"{self.data_store.get('experiment_name', 'general')}_my_model.h5"
+        model_name = os.path.join(path, name)
+        logging.debug(f"save best model to {model_name}")
+        self.model.save(model_name)
+
+    def load_best_model(self, name: str) -> None:
+        """
+        Load the model weights from the file with the given name. Skip if no weights are available.
+        :param name: name (path) of the file to load the weights from
+        """
+        logging.debug(f"load best model: {name}")
+        try:
+            self.model.load_weights(name)
+            logging.info('reload weights...')
+        except OSError:
+            logging.info('no weights to reload...')
+
+    def save_callbacks(self, history: keras.callbacks.History) -> None:
+        """
+        Save callbacks (history, learning rate) of training.
+        * history.history -> history.json
+        * lr_sc.lr -> history_lr.json
+        :param history: history object of training
+        """
+        logging.debug("saving callbacks")
         path = self.data_store.get("experiment_path", "general")
         with open(os.path.join(path, "history.json"), "w") as f:
             json.dump(history.history, f)
diff --git a/test/test_helpers.py b/test/test_helpers.py
index 181c5f29..ce5d28a6 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -184,7 +184,7 @@ class TestSetExperimentName:
         assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "TestExperiment"))
         exp_name, exp_path = set_experiment_name(experiment_date="2019-11-14", experiment_path="./test2")
         assert exp_name == "2019-11-14_network"
-        assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "test2"))
+        assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "test2", exp_name))
 
     def test_set_experiment_from_sys(self):
         exp_name, _ = set_experiment_name(experiment_date="2019-11-14")
diff --git a/test/test_modules/test_experiment_setup.py b/test/test_modules/test_experiment_setup.py
index e1ec57f9..bfff606e 100644
--- a/test/test_modules/test_experiment_setup.py
+++ b/test/test_modules/test_experiment_setup.py
@@ -113,8 +113,9 @@ class TestExperimentSetup:
         assert data_store.get("trainable", "general") is True
         assert data_store.get("fraction_of_training", "general") == 0.5
         # set experiment name
-        assert data_store.get("experiment_name", "general") == "TODAY_network/"
-        path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
+        assert data_store.get("experiment_name", "general") == "TODAY_network"
+        path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder",
+                                            "TODAY_network"))
         assert data_store.get("experiment_path", "general") == path
         # setup for data
         assert data_store.get("var_all_dict", "general") == {'o3': 'dma8eu', 'relhum': 'average_values',
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
index 5c8223ee..85cb24e3 100644
--- a/test/test_modules/test_model_setup.py
+++ b/test/test_modules/test_model_setup.py
@@ -33,7 +33,7 @@ class TestModelSetup:
 
     def test_set_checkpoint(self, setup):
         assert "general.modeltest" not in setup.data_store.search_name("checkpoint")
-        setup.model_name = "TestName"
+        setup.checkpoint_name = "TestName"
         setup._set_checkpoint()
         assert "general.modeltest" in setup.data_store.search_name("checkpoint")
 
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index 590bb380..3598fe0b 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -1,7 +1,18 @@
 import keras
+import pytest
+from keras.callbacks import ModelCheckpoint, History
+import mock
+import os
+import json
+import shutil
+import logging
 
 from src.inception_model import InceptionModelBase
 from src.flatten import flatten_tail
+from src.modules.training import Training
+from src.modules.run_environment import RunEnvironment
+from src.data_handling.data_distributor import Distributor
+from src.helpers import LearningRateDecay, PyTestRegex
 
 
 def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
@@ -19,3 +30,109 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
     X_in = keras.layers.Dropout(dropout_rate)(X_in)
     out.append(flatten_tail(X_in, 'Main', activation=activation))
     return keras.Model(inputs=X_input, outputs=out)
+
+
+class TestTraining:
+
+    @pytest.fixture
+    def init_without_run(self, path):
+        obj = object.__new__(Training)
+        super(Training, obj).__init__()
+        obj.model = my_test_model(keras.layers.PReLU, 5, 3, 0.1, False)
+        obj.train_set = None
+        obj.val_set = None
+        obj.test_set = None
+        obj.batch_size = 256
+        obj.epochs = 2
+        obj.checkpoint = ModelCheckpoint("model_checkpoint", monitor='val_loss', save_best_only=True, mode='auto')
+        obj.lr_sc = LearningRateDecay()
+        obj.experiment_name = "TestExperiment"
+        obj.data_store.put("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train")
+        obj.data_store.put("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val")
+        obj.data_store.put("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test")
+        os.makedirs(path)
+        obj.data_store.put("experiment_path", path, "general")
+        obj.data_store.put("experiment_name", "TestExperiment", "general")
+        yield obj
+        if os.path.exists(path):
+            shutil.rmtree(path)
+        RunEnvironment().__del__()
+
+    @pytest.fixture
+    def history(self):
+        h = History()
+        h.epoch = [0, 1]
+        h.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
+                     'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
+                     'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
+                     'loss': [0.6795708956961347, 0.45963566494176616],
+                     'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
+                     'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+        return h
+
+    @pytest.fixture
+    def path(self):
+        return os.path.join(os.path.dirname(__file__), "TestExperiment")
+
+    def test_init(self):
+        pass
+
+    def test_run(self):
+        pass
+
+    def test_make_predict_function(self, init_without_run):
+        assert hasattr(init_without_run.model, "predict_function") is False
+        init_without_run.make_predict_function()
+        assert hasattr(init_without_run.model, "predict_function")
+
+    def test_set_gen(self, init_without_run):
+        assert init_without_run.train_set is None
+        init_without_run._set_gen("train")
+        assert isinstance(init_without_run.train_set, Distributor)
+        assert init_without_run.train_set.generator.return_value == "mock_train_gen"
+
+    def test_set_generators(self, init_without_run):
+        sets = ["train", "val", "test"]
+        assert all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
+        init_without_run.set_generators()
+        assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
+        assert all([getattr(init_without_run, f"{obj}_set").generator.return_value == f"mock_{obj}_gen" for obj in sets])
+
+    def test_train(self, init_without_run):
+        pass
+
+    def test_save_model(self, init_without_run, path, caplog):
+        caplog.set_level(logging.DEBUG)
+        model_name = "TestExperiment_my_model.h5"
+        assert model_name not in os.listdir(path)
+        init_without_run.save_model()
+        assert caplog.record_tuples[0] == ("root", 10, PyTestRegex(f"save best model to {os.path.join(path, model_name)}"))
+        assert model_name in os.listdir(path)
+
+    def test_load_best_model_no_weights(self, init_without_run, caplog):
+        caplog.set_level(logging.DEBUG)
+        init_without_run.load_best_model("notExisting")
+        assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
+        assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
+
+    def test_save_callbacks_history_created(self, init_without_run, history, path):
+        init_without_run.save_callbacks(history)
+        assert "history.json" in os.listdir(path)
+
+    def test_save_callbacks_lr_created(self, init_without_run, history, path):
+        init_without_run.save_callbacks(history)
+        assert "history_lr.json" in os.listdir(path)
+
+    def test_save_callbacks_inspect_history(self, init_without_run, history, path):
+        init_without_run.save_callbacks(history)
+        with open(os.path.join(path, "history.json")) as jfile:
+            hist = json.load(jfile)
+            assert hist == history.history
+
+    def test_save_callbacks_inspect_lr(self, init_without_run, history, path):
+        init_without_run.save_callbacks(history)
+        with open(os.path.join(path, "history_lr.json")) as jfile:
+            lr = json.load(jfile)
+            assert lr == init_without_run.lr_sc.lr
+
+
-- 
GitLab