From 85eca0975ad58c70be2b76314e60601f10b11c3b Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlich@fz-juelich.de>
Date: Tue, 24 Aug 2021 12:28:35 +0200
Subject: [PATCH] Fixed finetuning, allowed correct learning_rate

---
 mlair/model_modules/model_class.py        | 15 +++++---
 mlair/plotting/postprocessing_plotting.py |  3 +-
 mlair/run_modules/model_setup.py          | 13 +++++--
 mlair/run_modules/training.py             |  6 +--
 run_with_finetuning.py                    | 45 +++++++++++++++++++++++
 run_without_finetuning.py                 | 45 +++++++++++++++++++++++
 6 files changed, 111 insertions(+), 16 deletions(-)
 create mode 100644 run_with_finetuning.py
 create mode 100644 run_without_finetuning.py

diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py
index a5c6a8e3..83434268 100644
--- a/mlair/model_modules/model_class.py
+++ b/mlair/model_modules/model_class.py
@@ -364,17 +364,18 @@ class IntelliO3_ts_architecture(AbstractModelClass):
         assert len(output_shape) == 1
         super().__init__(input_shape[0], output_shape[0])
 
-        from mlair.model_modules.keras_extensions import LearningRateDecay
-
         # settings
         self.dropout_rate = .35
         self.regularizer = keras.regularizers.l2(0.01)
         self.initial_lr = 1e-4
-        self.lr_decay = LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
         self.activation = keras.layers.ELU
         self.padding = "SymPad2D"
+        self.apply_to_model()
 
         # apply to model
+    def apply_to_model(self):
+        from mlair.model_modules.keras_extensions import LearningRateDecay
+        self.lr_decay = LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
         self.set_model()
         self.set_compile_options()
         self.set_custom_objects(loss=self.compile_options["loss"][0],
@@ -467,9 +468,12 @@ class IntelliO3_ts_architecture(AbstractModelClass):
 class IntelliO3_ts_architecture_freeze(IntelliO3_ts_architecture):
     def __init__(self, input_shape: list, output_shape: list):
         super().__init__(input_shape, output_shape)
+
         self.freeze_layers()
         self.initial_lr = 1e-5
-    '''
+        self.apply_to_model()
+        # self.lr_decay = None
+
     def freeze_layers(self):
         for layer in self.model.layers:
             if not isinstance(layer, keras.layers.core.Dense):
@@ -479,4 +483,5 @@ class IntelliO3_ts_architecture_freeze(IntelliO3_ts_architecture):
     def freeze_layers(self):
         for layer in self.model.layers:
             if layer.name not in ["minor_1_out_Dense", "Main_out_Dense"]:
-                layer.trainable = False
\ No newline at end of file
+                layer.trainable = False
+    '''
\ No newline at end of file
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 6eba642b..2c8fd6e4 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -30,7 +30,6 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
 
 @TimeTrackingWrapper
 class PlotContingency(AbstractPlotClass):
-    #Todo: Get min and max_label
 
     def __init__(self, station_names, file_path, comp_path, file_name, plot_folder: str = ".", model_name: str = "nn",
                  obs_name: str = "obs", comp_names: str = "IntelliO3",
@@ -90,7 +89,7 @@ class PlotContingency(AbstractPlotClass):
         else:
             for type in data.type.values.tolist():
                 plt.plot(range(self._min_threshold, self._max_threshold), data.loc[dict(type=type, scores=score)], label=type)
-        plt.title(self._plot_names[self._plot_counter][13:])
+        plt.title(self._plot_names[self._plot_counter])
         plt.legend()
         self.plot_name = self._plot_names[self._plot_counter]
         self._plot_counter = self._plot_counter + 1
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 83f4a2bd..28a4e99a 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -83,7 +83,10 @@ class ModelSetup(RunEnvironment):
         self.plot_model()
 
         # load weights if no training shall be performed
-        if not self._train_model and not self._create_new_model:
+        external_weights = self.data_store.get("external_weights")
+        if external_weights is not None:
+            self.load_weights(external_weights)
+        elif not self._train_model and not self._create_new_model:
             self.load_weights()
 
         # create checkpoint
@@ -131,11 +134,13 @@ class ModelSetup(RunEnvironment):
                                           save_best_only=True, mode='auto')
         self.data_store.set("callbacks", callbacks, self.scope)
 
-    def load_weights(self):
+    def load_weights(self, external_weight=None):
         """Try to load weights from existing model or skip if not possible."""
+        if external_weight is None:
+            external_weight = self.model_name
         try:
-            self.model.load_weights(self.model_name)
-            logging.info(f"reload weights from model {self.model_name} ...")
+            self.model.load_weights(external_weight)
+            logging.info(f"reload weights from model {external_weight} ...")
         except OSError:
             logging.info('no weights to reload...')
 
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 9cafbbf6..00e8eae1 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -171,11 +171,7 @@ class Training(RunEnvironment):
         except IndexError:
             epo_timing = None
         self.save_callbacks_as_json(history, lr, epo_timing)
-        external_weights = self.data_store.get("external_weights")
-        if external_weights is not None:
-            self.load_best_model(external_weights)
-        else:
-            self.load_best_model(checkpoint.filepath)
+        self.load_best_model(checkpoint.filepath)
         self.create_monitoring_plots(history, lr)
 
     def save_model(self) -> None:
diff --git a/run_with_finetuning.py b/run_with_finetuning.py
new file mode 100644
index 00000000..65611793
--- /dev/null
+++ b/run_with_finetuning.py
@@ -0,0 +1,45 @@
+__author__ = "Lukas Leufen"
+__date__ = '2020-06-29'
+
+import argparse
+from mlair.workflows import DefaultWorkflow
+from mlair.helpers import remove_items
+from mlair.configuration.defaults import DEFAULT_PLOT_LIST
+from mlair.model_modules.model_class import IntelliO3_ts_architecture, IntelliO3_ts_architecture_freeze
+import os
+
+
+def load_stations():
+    import json
+    try:
+        filename = 'supplement/station_list_north_german_plain_rural.json'
+        with open(filename, 'r') as jfile:
+            stations = json.load(jfile)
+    except FileNotFoundError:
+        stations = None
+    return stations
+
+
+def main(parser_args):
+    plots = remove_items(DEFAULT_PLOT_LIST, "PlotConditionalQuantiles")
+    workflow = DefaultWorkflow(  # stations=load_stations(),
+        stations=["DEBW087", "DEBW013", "DEBW107",  "DEBW076"],
+        #stations=["DEBW013", "DEBW087"],
+        epochs=1, external_weights="/home/vincentgramlich/mlair/data/weights/testrun_network_daily_model-best.h5",
+        train_model=True, create_new_model=True, network="UBA",
+        model=IntelliO3_ts_architecture_freeze,
+        window_lead_time=1,
+        #oversampling_method="bin_oversampling", oversampling_bins=10, oversampling_rates_cap=100, window_lead_time=2,
+        evaluate_bootstraps=False,  plot_list=["PlotContingency"],
+        competitors=["withoutfinetuning"],
+        competitor_path=os.path.join(os.getcwd(), "data", "competitors", "o3"),
+        **parser_args.__dict__, start_script=__file__)
+    workflow.run()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default="testrun",
+                        help="set experiment date as string")
+    args = parser.parse_args()
+    main(args)
diff --git a/run_without_finetuning.py b/run_without_finetuning.py
new file mode 100644
index 00000000..61644a33
--- /dev/null
+++ b/run_without_finetuning.py
@@ -0,0 +1,45 @@
+__author__ = "Lukas Leufen"
+__date__ = '2020-06-29'
+
+import argparse
+from mlair.workflows import DefaultWorkflow
+from mlair.helpers import remove_items
+from mlair.configuration.defaults import DEFAULT_PLOT_LIST
+from mlair.model_modules.model_class import IntelliO3_ts_architecture, IntelliO3_ts_architecture_freeze
+import os
+
+
+def load_stations():
+    import json
+    try:
+        filename = 'supplement/station_list_north_german_plain_rural.json'
+        with open(filename, 'r') as jfile:
+            stations = json.load(jfile)
+    except FileNotFoundError:
+        stations = None
+    return stations
+
+
+def main(parser_args):
+    plots = remove_items(DEFAULT_PLOT_LIST, "PlotConditionalQuantiles")
+    workflow = DefaultWorkflow(  # stations=load_stations(),
+        stations=["DEBW087","DEBW013", "DEBW107",  "DEBW076"],
+        #stations=["DEBW013", "DEBW087"],
+        epochs=5, #external_weights="/home/vincentgramlich/mlair/data/weights/test_weight.h5",
+        train_model=True, create_new_model=True, network="UBA",
+        model=IntelliO3_ts_architecture,
+        window_lead_time=1,
+        #oversampling_method="bin_oversampling", oversampling_bins=10, oversampling_rates_cap=100, window_lead_time=2,
+        evaluate_bootstraps=False,  plot_list=plots,
+        #competitors=["testcompetitor", "testcompetitor2"],
+        competitor_path=os.path.join(os.getcwd(), "data", "competitors"),
+        **parser_args.__dict__, start_script=__file__)
+    workflow.run()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default="testrun",
+                        help="set experiment date as string")
+    args = parser.parse_args()
+    main(args)
\ No newline at end of file
-- 
GitLab