From a31e2e7e1aad5f5123714329b434dff765664abb Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlichfz-juelich.de>
Date: Fri, 23 Jul 2021 10:20:51 +0200
Subject: [PATCH] ..

---
 mlair/data_handler/default_data_handler.py |  2 +-
 mlair/plotting/data_insight_plotting.py    | 12 ++++++++-
 mlair/plotting/postprocessing_plotting.py  |  4 +++
 mlair/run_modules/experiment_setup.py      | 20 +++++++++------
 mlair/run_modules/post_processing.py       | 10 +++++---
 mlair/run_modules/pre_processing.py        | 30 ++++------------------
 6 files changed, 39 insertions(+), 39 deletions(-)

diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index fc5a4d96..acc3caa0 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -206,7 +206,7 @@ class DefaultDataHandler(AbstractDataHandler):
                     else:
                         self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=self.time_dim), self._X_extreme, extremes_X))
                         self._Y_extreme = xr.concat([self._Y_extreme, extremes_Y], dim=self.time_dim)
-        #self._store(fresh_store=True)
+
 
     def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
                           timedelta: Tuple[int, str] = (1, 'm'), dim=DEFAULT_TIME_DIM):
diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index aff3b4c7..ccea0b84 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -22,11 +22,14 @@ from mlair.plotting.abstract_plot_class import AbstractPlotClass
 @TimeTrackingWrapper
 class PlotOversampling(AbstractPlotClass):
 
-    def __init__(self, Y, Y_extreme, bin_edges, oversampling_rates, plot_folder: str = ".",
+    def __init__(self, data, bin_edges, oversampling_rates, plot_folder: str = ".",
                  plot_names=["oversampling_histogram", "oversampling_density_histogram", "oversampling_rates",
                             "oversampling_rates_deviation"]):
 
         super().__init__(plot_folder, plot_names[0])
+
+        Y_hist, Y_extreme_hist = self._calculate_hist(data, bin_edges)
+
         Y_hist, Y_extreme_hist = self._plot_oversampling_histogram(Y, Y_extreme, bin_edges)
         real_oversampling = Y_extreme_hist / Y_hist
         self._save()
@@ -40,6 +43,13 @@ class PlotOversampling(AbstractPlotClass):
         self._plot_oversampling_rates_deviation(oversampling_rates, real_oversampling)
         self._save()
 
+    def _calculate_histogram(self, data, bin_edges):
+        Y_hist = np.zeros(len(bin_edges),1)
+        Y_extreme_hist = np.zeros(len(bin_edges), 1)
+        for station in data:
+            Y = station.get_Y(as_numpy=True, upsampling=False)
+            Y_extreme = station.get_Y(as_numpy=True, upsampling=True)
+
     def _plot_oversampling_histogram(self, Y, Y_extreme, bin_edges):
         fig, ax = plt.subplots(1, 1)
         Y_hist = Y.plot.hist(bins=bin_edges, histtype="step", label="Before", ax=ax)[0]
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 3bef0c30..67233780 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -41,6 +41,10 @@ class PlotOversamplingContingency(AbstractPlotClass):
         f = []
         max_label = 0
         min_label = 0
+        for station in station_names:
+            file = os.path.join(file_path, file_name % station)
+            forecast = xr.open_dataarray(file)
+            competitors = extract_method(station)
         for threshold in range(min_label, max_label):
             true_above = 0
             false_above = 0
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index edf1cdf5..cefd4505 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -238,15 +238,23 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("bootstrap_path", bootstrap_path)
         self._set_param("train_model", train_model, default=DEFAULT_TRAIN_MODEL)
         self._set_param("fraction_of_training", fraction_of_train, default=DEFAULT_FRACTION_OF_TRAINING)
+        self._set_param("batch_size", batch_size, default=DEFAULT_BATCH_SIZE)
+        self._set_param("epochs", epochs, default=DEFAULT_EPOCHS)
+
+        # set params for oversampling
+        self._set_param("oversampling_bins", oversampling_bins, default=DEFAULT_OVERSAMPLING_BINS)
+        self._set_param("oversampling_rates_cap", oversampling_rates_cap, default=DEFAULT_OVERSAMPLING_RATES_CAP)
+        self._set_param("oversampling_method", oversampling_method, default=DEFAULT_OVERSAMPLING_METHOD)
         self._set_param("extreme_values", extreme_values, default=DEFAULT_EXTREME_VALUES, scope="train")
         self._set_param("extremes_on_right_tail_only", extremes_on_right_tail_only,
                         default=DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY, scope="train")
-        self._set_param("upsampling", extreme_values is not None, scope="train")
-        upsampling = self.data_store.get("upsampling", "train")
+        upsampling = (extreme_values is not None) or (oversampling_method is not None)
+        self._set_param("upsampling", upsampling, scope="train")
         permute_data = DEFAULT_PERMUTE_DATA if permute_data_on_training is None else permute_data_on_training
         self._set_param("permute_data", permute_data or upsampling, scope="train")
-        self._set_param("batch_size", batch_size, default=DEFAULT_BATCH_SIZE)
-        self._set_param("epochs", epochs, default=DEFAULT_EPOCHS)
+        if (extreme_values is not None) and (oversampling_method is not None):
+            logging.info("Parameters extreme_values and oversampling_method are set. In this case only "
+                         "oversampling_method is used.")
 
         # set experiment name
         sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING)  # always related to output sampling
@@ -365,10 +373,6 @@ class ExperimentSetup(RunEnvironment):
         # set model architecture class
         self._set_param("model_class", model, VanillaModel)
 
-        # set params for oversampling
-        self._set_param("oversampling_bins", oversampling_bins, default=DEFAULT_OVERSAMPLING_BINS)
-        self._set_param("oversampling_rates_cap", oversampling_rates_cap, default=DEFAULT_OVERSAMPLING_RATES_CAP)
-        self._set_param("oversampling_method", oversampling_method, default=DEFAULT_OVERSAMPLING_METHOD)
 
         # set remaining kwargs
         if len(kwargs) > 0:
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index c5f5b2d3..2febedb6 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -311,7 +311,9 @@ class PostProcessing(RunEnvironment):
                     "PlotOversamplingContingency" in plot_list):
                 predictions = None
                 labels = None
-                PlotOversamplingContingency(predictions, labels, plot_folder=self.plot_path)
+                PlotOversamplingContingency(extract_method=self.load_competitors(), station_names=self.test_data.keys(),
+                                            file_path=path, file_name=r"forecasts_%s_test.nc",
+                                            plot_folder=self.plot_path)
         except Exception as e:
             logging.error(f"Could not create plot OversamplingContingencyPlots due to the following error: {e}")
 
@@ -320,9 +322,9 @@ class PostProcessing(RunEnvironment):
                     "PlotOversampling" in plot_list):
                 bin_edges = self.data_store.get('oversampling_bin_edges')
                 oversampling_rates = self.data_store.get('oversampling_rates_capped','train')
-                Y = self.data_store.get('Oversampling_Y')
-                Y_extreme = self.data_store.get('Oversampling_Y_extreme')
-                PlotOversampling(Y, Y_extreme, bin_edges, oversampling_rates, plot_folder=self.plot_path)
+                #Y = self.data_store.get('Oversampling_Y')
+                #Y_extreme = self.data_store.get('Oversampling_Y_extreme')
+                PlotOversampling(self.train_data, bin_edges, oversampling_rates, plot_folder=self.plot_path)
         except Exception as e:
             logging.error(f"Could not create plot OversamplingPlots due to the following error: {e}")
 
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 69f14bed..2d6dc3b5 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -91,7 +91,7 @@ class PreProcessing(RunEnvironment):
         bin_edges = []
         for station in data:
             # Create histogram for each station
-            hist, bin_edges = np.histogram(station.get_Y(as_numpy=True), bins=bins, range=(total_min,total_max))
+            hist, bin_edges = np.histogram(station.get_Y(as_numpy=True), bins=bins, range=(total_min, total_max))
             # Add up histograms
             histogram = histogram + hist
         # Scale down to most frequent class=1
@@ -103,10 +103,11 @@ class PreProcessing(RunEnvironment):
         self.data_store.set('oversampling_rates', oversampling_rates, 'train')
         self.data_store.set('oversampling_rates_capped', oversampling_rates_capped, 'train')
         self.data_store.set('oversampling_bin_edges', bin_edges)
-        Y = None
-        Y_extreme = None
+        #Y = None
+        #Y_extreme = None
         for station in data:
             station.apply_oversampling(bin_edges, oversampling_rates_capped)
+            '''
             if Y is None:
                 Y = station._Y
                 Y_extreme = station._Y_extreme
@@ -116,28 +117,7 @@ class PreProcessing(RunEnvironment):
         self.data_store.set('Oversampling_Y', Y)
         self.data_store.set('Oversampling_Y_extreme', Y_extreme)
         '''
-        if not on HPC:
-            fig, ax = plt.subplots(nrows=2, ncols=2)
-            fig.suptitle(f"Window Size=1, Bins={bins}, rates_cap={rates_cap}")
-            Y_hist = Y.plot.hist(bins=bin_edges, histtype="step", label="Before", ax=ax[0,0])[0]
-            Y_extreme_hist = Y_extreme.plot.hist(bins=bin_edges, histtype="step", label="After", ax=ax[0,0])[0]
-            ax[0,0].set_title(f"Histogram before-after oversampling")
-            ax[0,0].legend()
-            Y_hist_dens = Y.plot.hist(bins=bin_edges, density=True, histtype="step", label="Before", ax=ax[0,1])[0]
-            Y_extreme_hist_dens = Y_extreme.plot.hist(bins=bin_edges, density=True, histtype="step", label="After", ax=ax[0,1])[0]
-            ax[0,1].set_title(f"Density-Histogram before-after oversampling")
-            ax[0,1].legend()
-            real_oversampling = Y_extreme_hist/Y_hist
-            ax[1,0].plot(range(len(real_oversampling)), oversampling_rates_capped, label="Desired oversampling_rates")
-            ax[1,0].plot(range(len(real_oversampling)), real_oversampling, label="Actual Oversampling Rates")
-            ax[1,0].set_title(f"Oversampling rates")
-            ax[1,0].legend()
-            ax[1,1].plot(range(len(real_oversampling)), real_oversampling / oversampling_rates_capped,
-                     label="Actual/Desired Rate")
-            ax[1,1].set_title(f"Deviation from desired Oversampling rate")
-            ax[1,1].legend()
-            plt.show()
-            '''
+
 
     def report_pre_processing(self):
         """Log some metrics on data and create latex report."""
-- 
GitLab