From 8233f14e45331b946a945ae9b59a710d44012213 Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlichfz-juelich.de>
Date: Mon, 28 Jun 2021 12:14:38 +0200
Subject: [PATCH] Merged with IntelliO3_ts_architecture, added
 oversampling_method as a parameter to enable bin_oversampling. Added run
 scripts for IntelliO3_ts_architecture with and without oversampling.

---
 mlair/configuration/defaults.py       |  1 +
 mlair/run_modules/experiment_setup.py |  6 ++--
 mlair/run_modules/pre_processing.py   |  7 +++--
 run_with_oversampling.py              | 43 +++++++++++++++++++++++++++
 run_without_oversampling.py           | 43 +++++++++++++++++++++++++++
 5 files changed, 96 insertions(+), 4 deletions(-)
 create mode 100644 run_with_oversampling.py
 create mode 100644 run_without_oversampling.py

diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index f2538e98..7b7584ad 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -57,6 +57,7 @@ DEFAULT_USE_MULTIPROCESSING = True
 DEFAULT_USE_MULTIPROCESSING_ON_DEBUG = False
 DEFAULT_OVERSAMPLING_BINS = 10
 DEFAULT_OVERSAMPLING_RATES_CAP = 100
+DEFAULT_OVERSAMPLING_METHOD = None
 
 
 def get_defaults():
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index b249491a..edf1cdf5 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -19,7 +19,8 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
     DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
     DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM, \
-    DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_OVERSAMPLING_BINS, DEFAULT_OVERSAMPLING_RATES_CAP
+    DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_OVERSAMPLING_BINS, \
+    DEFAULT_OVERSAMPLING_RATES_CAP, DEFAULT_OVERSAMPLING_METHOD
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
@@ -219,7 +220,7 @@ class ExperimentSetup(RunEnvironment):
                  hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
                  data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
                  use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
-                 oversampling_bins=None, oversampling_rates_cap=None, **kwargs):
+                 oversampling_bins=None, oversampling_rates_cap=None, oversampling_method = None, **kwargs):
 
         # create run framework
         super().__init__()
@@ -367,6 +368,7 @@ class ExperimentSetup(RunEnvironment):
         # set params for oversampling
         self._set_param("oversampling_bins", oversampling_bins, default=DEFAULT_OVERSAMPLING_BINS)
         self._set_param("oversampling_rates_cap", oversampling_rates_cap, default=DEFAULT_OVERSAMPLING_RATES_CAP)
+        self._set_param("oversampling_method", oversampling_method, default=DEFAULT_OVERSAMPLING_METHOD)
 
         # set remaining kwargs
         if len(kwargs) > 0:
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 9ef5c3f1..e265bd24 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -69,12 +69,15 @@ class PreProcessing(RunEnvironment):
             raise ValueError("Couldn't find any valid data according to given parameters. Abort experiment run.")
         self.data_store.set("stations", valid_stations)
         self.split_train_val_test()
-        self.apply_oversampling()
+        if self.data_store.get('oversampling_method')=='bin_oversampling':
+            logging.debug("Apply Oversampling")
+            self.apply_oversampling()
+        else:
+            logging.debug("No Oversampling")
         self.report_pre_processing()
         self.prepare_competitors()
 
     def apply_oversampling(self):
-        #if request for oversampling=True/False
         data = self.data_store.get('data_collection', 'train')
         bins = self.data_store.get('oversampling_bins')
         rates_cap = self.data_store.get('oversampling_rates_cap')
diff --git a/run_with_oversampling.py b/run_with_oversampling.py
new file mode 100644
index 00000000..cbab9b4e
--- /dev/null
+++ b/run_with_oversampling.py
@@ -0,0 +1,43 @@
+__author__ = "Lukas Leufen"
+__date__ = '2020-06-29'
+
+import argparse
+from mlair.workflows import DefaultWorkflow
+from mlair.helpers import remove_items
+from mlair.configuration.defaults import DEFAULT_PLOT_LIST
+from mlair.model_modules.model_class import IntelliO3_ts_architecture
+import os
+
+
+def load_stations():
+    """Return the station list parsed from the supplement JSON file, or None if that file is missing."""
+    import json
+    station_file = 'supplement/station_list_north_german_plain_rural.json'
+    try:
+        with open(station_file, 'r') as handle:
+            return json.load(handle)
+    except FileNotFoundError:
+        return None
+
+
+def main(parser_args):
+    """Run a DefaultWorkflow for IntelliO3_ts_architecture with oversampling_method set to 'bin_oversampling'."""
+    # NOTE(review): `plots` is built but never handed to DefaultWorkflow (no plot_list argument) - confirm intent.
+    plots = remove_items(DEFAULT_PLOT_LIST, "PlotConditionalQuantiles")
+    comp_dir = os.path.join(os.getcwd(), "data", "comp_test")
+    experiment = DefaultWorkflow(
+        stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"],
+        train_model=False, create_new_model=True, network="UBA",
+        model=IntelliO3_ts_architecture, oversampling_method="bin_oversampling", evaluate_bootstraps=False,
+        competitors=["test_model", "test_model2"], competitor_path=comp_dir,
+        window_lead_time=1, oversampling_bins=10, oversampling_rates_cap=100,
+        **vars(parser_args))
+    experiment.run()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default="testrun",
+                        help="set experiment date as string")
+    args = parser.parse_args()
+    main(args)
diff --git a/run_without_oversampling.py b/run_without_oversampling.py
new file mode 100644
index 00000000..3c69b450
--- /dev/null
+++ b/run_without_oversampling.py
@@ -0,0 +1,43 @@
+__author__ = "Lukas Leufen"
+__date__ = '2020-06-29'
+
+import argparse
+from mlair.workflows import DefaultWorkflow
+from mlair.helpers import remove_items
+from mlair.configuration.defaults import DEFAULT_PLOT_LIST
+from mlair.model_modules.model_class import IntelliO3_ts_architecture
+import os
+
+
+def load_stations():
+    """Return the station list parsed from the supplement JSON file, or None if that file is missing."""
+    import json
+    station_file = 'supplement/station_list_north_german_plain_rural.json'
+    try:
+        with open(station_file, 'r') as handle:
+            return json.load(handle)
+    except FileNotFoundError:
+        return None
+
+
+def main(parser_args):
+    """Run a DefaultWorkflow for IntelliO3_ts_architecture without passing oversampling_method."""
+    # NOTE(review): `plots` is built but never handed to DefaultWorkflow (no plot_list argument) - confirm intent.
+    plots = remove_items(DEFAULT_PLOT_LIST, "PlotConditionalQuantiles")
+    comp_dir = os.path.join(os.getcwd(), "data", "comp_test")
+    experiment = DefaultWorkflow(
+        stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"],
+        train_model=False, create_new_model=True, network="UBA",
+        model=IntelliO3_ts_architecture, evaluate_bootstraps=False,
+        competitors=["test_model", "test_model2"], competitor_path=comp_dir,
+        window_lead_time=1, oversampling_bins=10, oversampling_rates_cap=100,
+        **vars(parser_args))
+    experiment.run()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default="testrun",
+                        help="set experiment date as string")
+    args = parser.parse_args()
+    main(args)
-- 
GitLab