From a1c804ce308897e383d70df2d236203ece281597 Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlichfz-juelich.de>
Date: Wed, 18 Aug 2021 08:06:26 +0200
Subject: [PATCH] add IntelliO3_ts_architecture_freeze for two phase, add
 external_weights parameter to load existing model with the given path

---
 mlair/configuration/defaults.py       |  1 +
 mlair/model_modules/model_class.py    | 20 ++++++++++-
 mlair/run_modules/experiment_setup.py |  6 ++--
 mlair/run_modules/training.py         |  6 +++-
 run_two_phase.py                      | 51 +++++++++++++++++++++++++++
 run_with_oversampling.py              | 17 ++++-----
 6 files changed, 89 insertions(+), 12 deletions(-)
 create mode 100644 run_two_phase.py

diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index 47aaf088..dd59201d 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -63,6 +63,7 @@ DEFAULT_OVERSAMPLING_BINS = 10
 DEFAULT_OVERSAMPLING_RATES_CAP = 100
 DEFAULT_OVERSAMPLING_METHOD = None
 
+DEFAULT_EXTERNAL_WEIGHTS = None
 
 
 def get_defaults():
diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py
index bf97864f..a5c6a8e3 100644
--- a/mlair/model_modules/model_class.py
+++ b/mlair/model_modules/model_class.py
@@ -381,6 +381,7 @@ class IntelliO3_ts_architecture(AbstractModelClass):
                                 SymmetricPadding2D=SymmetricPadding2D,
                                 LearningRateDecay=LearningRateDecay)
 
+
     def set_model(self):
         """
         Build the model.
@@ -461,4 +462,21 @@ class IntelliO3_ts_architecture(AbstractModelClass):
                                 "loss": [keras.losses.mean_squared_error, keras.losses.mean_squared_error],
                                 "metrics": ['mse'],
                                 "loss_weights": [.01, .99]
-                                }
\ No newline at end of file
+                                }
+
+class IntelliO3_ts_architecture_freeze(IntelliO3_ts_architecture):
+    def __init__(self, input_shape: list, output_shape: list):
+        super().__init__(input_shape, output_shape)
+        self.freeze_layers()
+        self.initial_lr = 1e-5
+    '''
+    def freeze_layers(self):
+        for layer in self.model.layers:
+            if not isinstance(layer, keras.layers.core.Dense):
+                layer.trainable = False
+    '''
+
+    def freeze_layers(self):
+        for layer in self.model.layers:
+            if layer.name not in ["minor_1_out_Dense", "Main_out_Dense"]:
+                layer.trainable = False
\ No newline at end of file
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index 1bd37a63..6d8bde1a 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -24,7 +24,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_OVERSAMPLING_BINS, \
     DEFAULT_OVERSAMPLING_RATES_CAP, DEFAULT_OVERSAMPLING_METHOD, \
     DEFAULT_MAX_NUMBER_MULTIPROCESSING, \
-    DEFAULT_BOOTSTRAP_TYPE, DEFAULT_BOOTSTRAP_METHOD
+    DEFAULT_BOOTSTRAP_TYPE, DEFAULT_BOOTSTRAP_METHOD, DEFAULT_EXTERNAL_WEIGHTS
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
@@ -225,7 +225,8 @@ class ExperimentSetup(RunEnvironment):
                  data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
                  use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
                  max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None,
-                 oversampling_bins=None, oversampling_rates_cap=None, oversampling_method = None, **kwargs):
+                 oversampling_bins=None, oversampling_rates_cap=None, oversampling_method=None, external_weights=None,
+                 **kwargs):
 
         # create run framework
         super().__init__()
@@ -287,6 +288,7 @@ class ExperimentSetup(RunEnvironment):
         # set model path
         self._set_param("model_path", None, os.path.join(experiment_path, "model"))
         path_config.check_path_and_create(self.data_store.get("model_path"))
+        self._set_param("external_weights", external_weights, default=DEFAULT_EXTERNAL_WEIGHTS)
 
         # set plot path
         default_plot_path = os.path.join(experiment_path, "plots")
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 00e8eae1..9cafbbf6 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -171,7 +171,11 @@ class Training(RunEnvironment):
         except IndexError:
             epo_timing = None
         self.save_callbacks_as_json(history, lr, epo_timing)
-        self.load_best_model(checkpoint.filepath)
+        external_weights = self.data_store.get("external_weights")
+        if external_weights is not None:
+            self.load_best_model(external_weights)
+        else:
+            self.load_best_model(checkpoint.filepath)
         self.create_monitoring_plots(history, lr)
 
     def save_model(self) -> None:
diff --git a/run_two_phase.py b/run_two_phase.py
new file mode 100644
index 00000000..421c050e
--- /dev/null
+++ b/run_two_phase.py
@@ -0,0 +1,51 @@
+__author__ = "Lukas Leufen"
+__date__ = '2020-06-29'
+
+import argparse
+from mlair.workflows import DefaultWorkflowHPC
+from mlair.helpers import remove_items
+from mlair.configuration.defaults import DEFAULT_PLOT_LIST
+from mlair.model_modules.model_class import IntelliO3_ts_architecture, IntelliO3_ts_architecture_freeze
+import os
+
+
+def load_stations(external_station_list=None):
+    import json
+    if external_station_list is None:
+        external_station_list = 'supplement/station_list_north_german_plain_rural.json'
+    try:
+        filename = external_station_list
+        with open(filename, 'r') as jfile:
+            stations = json.load(jfile)
+    except FileNotFoundError:
+        stations = None
+    return stations
+
+# 1. How to load existing model
+# https://www.tensorflow.org/tutorials/images/transfer_learning
+# 3. How many epochs?
+# 4. Full data set?
+# 5. lower learning rate?
+
+
+def main(parser_args):
+    plots = remove_items(DEFAULT_PLOT_LIST, "PlotConditionalQuantiles")
+    workflow = DefaultWorkflowHPC(#stations=load_stations('supplement/German_background_stations.json'),
+        stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"],
+        epochs=50, external_weights="/home/vincentgramlich/mlair/data/weights/test_weight.h5",
+        train_model=True, create_new_model=False, network="UBA",
+        model=IntelliO3_ts_architecture_freeze,
+        evaluate_bootstraps=False,  # plot_list=["PlotCompetitiveSkillScore"],
+        # competitors=["test_model", "test_model2"],
+        competitor_path=os.path.join(os.getcwd(), "data", "comp_test"),
+        window_lead_time=1, # oversampling_bins=10, oversampling_rates_cap=100,
+        **parser_args.__dict__)
+    workflow.run()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default="testrun",
+                        help="set experiment date as string")
+    args = parser.parse_args()
+    main(args)
diff --git a/run_with_oversampling.py b/run_with_oversampling.py
index 39cf7e12..781885f0 100644
--- a/run_with_oversampling.py
+++ b/run_with_oversampling.py
@@ -24,15 +24,16 @@ def load_stations(external_station_list=None):
 
 def main(parser_args):
     plots = remove_items(DEFAULT_PLOT_LIST, "PlotConditionalQuantiles")
-    workflow = DefaultWorkflowHPC(stations=load_stations('supplement/German_background_stations.json'),
-        #stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"],
-        epochs=150,
+    workflow = DefaultWorkflowHPC(#stations=load_stations('supplement/German_background_stations.json'),
+        stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"],
+        epochs=1,
         train_model=True, create_new_model=True, network="UBA",
-        model=IntelliO3_ts_architecture, oversampling_method="bin_oversampling",
-        evaluate_bootstraps=False,  # plot_list=["PlotCompetitiveSkillScore"],
-        competitors=["IntelliO3"],
-        competitor_path="/p/project/deepacf/intelliaq/gramlich1/mlair/competitors/o3",
-        window_lead_time=1, oversampling_bins=10, oversampling_rates_cap=100,
+        #model=IntelliO3_ts_architecture,
+        oversampling_method="bin_oversampling",
+        evaluate_bootstraps=False, plot_list=["PlotOversamplingContingency"],
+        competitors=["intellitest"],
+        #competitor_path="/p/project/deepacf/intelliaq/gramlich1/mlair/competitors/o3",
+        window_lead_time=1, oversampling_bins=2, oversampling_rates_cap=2,
         **parser_args.__dict__)
     workflow.run()
 
-- 
GitLab