From f73acaa44dd0cdca0d487a2b642092d7dcc85e06 Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlichfz-juelich.de>
Date: Mon, 26 Jul 2021 14:25:34 +0200
Subject: [PATCH] Commits vor merge

---
 mlair/data_handler/default_data_handler.py |   1 +
 mlair/plotting/data_insight_plotting.py    |  18 ++-
 mlair/plotting/postprocessing_plotting.py  | 159 ++++++++++++++++-----
 mlair/run_modules/experiment_setup.py      |   2 +-
 mlair/run_modules/post_processing.py       |  12 +-
 mlair/run_modules/pre_processing.py        |  10 --
 6 files changed, 137 insertions(+), 65 deletions(-)

diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index acc3caa0..8d977e11 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -206,6 +206,7 @@ class DefaultDataHandler(AbstractDataHandler):
                     else:
                         self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=self.time_dim), self._X_extreme, extremes_X))
                         self._Y_extreme = xr.concat([self._Y_extreme, extremes_Y], dim=self.time_dim)
+        self._store(fresh_store=True)
 
 
     def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index a2007f3f..26376637 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -22,6 +22,8 @@ from mlair.plotting.abstract_plot_class import AbstractPlotClass
 @TimeTrackingWrapper
 class PlotOversampling(AbstractPlotClass):
 
+    #Todo: Build histograms correctly
+
     def __init__(self, data, bin_edges, oversampling_rates, plot_folder: str = ".",
                  plot_names=["oversampling_histogram", "oversampling_density_histogram", "oversampling_rates",
                             "oversampling_rates_deviation"]):
@@ -33,7 +35,7 @@ class PlotOversampling(AbstractPlotClass):
         self._plot_oversampling_histogram(Y_hist, Y_extreme_hist, bin_edges)
         self._save()
         self.plot_name = plot_names[1]
-        self._plot_oversampling_density_histogram(Y_hist_dens, Y_extreme_hist_dens, bin_edges)
+        self._plot_oversampling_histogram(Y_hist_dens, Y_extreme_hist_dens, bin_edges)
         self._save()
         self.plot_name = plot_names[2]
         self._plot_oversampling_rates(oversampling_rates, real_oversampling)
@@ -56,15 +58,11 @@ class PlotOversampling(AbstractPlotClass):
 
     def _plot_oversampling_histogram(self, Y_hist, Y_extreme_hist, bin_edges):
         fig, ax = plt.subplots(1, 1)
-        ax.step(bin_edges, np.append(0,Y_hist), label="Before oversampling")
-        ax.step(bin_edges, np.append(0,Y_extreme_hist), label="After oversampling")
-        ax.set_title(f"Histogram before-after oversampling")
-        ax.legend()
-
-    def _plot_oversampling_density_histogram(self, Y_hist_dens, Y_extreme_hist_dens, bin_edges):
-        fig, ax = plt.subplots(1, 1)
-        ax.step(bin_edges, np.append(0,Y_hist_dens), label="Before oversampling")
-        ax.step(bin_edges, np.append(0,Y_extreme_hist_dens), label="After oversampling")
+        ax.hist(bin_edges[:-1], bin_edges, weights=Y_hist, label="Before oversampling")
+        ax.hist(bin_edges[:-1], bin_edges, weights=Y_extreme_hist, label="After oversampling")
+        #ax.plot(bin_edges[:-1] + 0.5 * interval_width, weights, label=f"{subset}", c=colors[subset])
+        #ax.step(bin_edges, np.append(0,Y_hist), label="Before oversampling")
+        #ax.step(bin_edges, np.append(0,Y_extreme_hist), label="After oversampling")
         #ax.stairs(Y_hist_dens, bin_edges, label="Before oversampling")
         #ax.stairs(Y_extreme_hist_dens, bin_edges, label="After oversampling")
         ax.set_title(f"Density Histogram before-after oversampling")
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 67233780..b5e76e55 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -18,7 +18,7 @@ from matplotlib.backends.backend_pdf import PdfPages
 
 from mlair import helpers
 from mlair.data_handler.iterator import DataCollection
-from mlair.helpers import TimeTrackingWrapper
+from mlair.helpers import TimeTrackingWrapper, to_list
 from mlair.plotting.abstract_plot_class import AbstractPlotClass
 
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
@@ -30,53 +30,140 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
 
 @TimeTrackingWrapper
 class PlotOversamplingContingency(AbstractPlotClass):
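+    # Todo: 1. Evaluate the competitors as well (flexible number), only the model itself is scored so far
+    #       2. Get min_label and max_label from the data instead of using fixed values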
 
-    def __init__(self, predictions, labels, plot_folder: str = ".",
+    def __init__(self, station_names, file_path, comp_path, file_name, plot_folder: str = ".", model_name: str = "nn",
+                 obs_name: str = "obs", comp_names: str = "IntelliO3",
                  plot_names=["oversampling_threat_score", "oversampling_hit_rate", "oversampling_false_alarm_rate",
                              "oversampling_all_scores"]):
 
         super().__init__(plot_folder, plot_names[0])
-        ts = []
-        h = []
-        f = []
-        max_label = 0
-        min_label = 0
-        for station in station_names:
-            file = os.path.join(file_path, file_name % station)
-            forecast = xr.open_dataarray(file)
-            competitors = extract_method(station)
-        for threshold in range(min_label, max_label):
-            true_above = 0
-            false_above = 0
-            false_below = 0
-            true_below = 0
-            for prediction, label in predictions, labels:
-                if prediction >= threshold:
-                    if label >= threshold:
-                        true_above = + 1
-                    else:
-                        false_above = + 1
-                else:
-                    if label >= threshold:
-                        false_below = + 1
-                    else:
-                        true_below = + 1
-                ts.append(true_above/(true_above+false_above+false_below))
-                h.append(true_above/(true_above+false_below))
-                f.append(false_above/(false_above+true_below))
-        plt.plot(range(min_label, max_label), ts)
+        self._stations = station_names
+        self._file_path = file_path
+        self._comp_path = comp_path
+        self._file_name = file_name
+        self._model_name = model_name
+        self._obs_name = obs_name
+        self._comp_names = to_list(comp_names)
+        true_above, false_above, false_below, true_below, borders = self._calculate_contingencies()
+        ts, h, f = self._calculate_scores(true_above, false_above, false_below, true_below)
+        min_label = borders[0]
+        max_label = borders[1]
+        plt.plot(range(min_label, max_label), ts, label="threat score")
+        plt.legend()
         self._save()
         self.plot_name = plot_names[1]
-        plt.plot(range(min_label, max_label), h)
+        plt.plot(range(min_label, max_label), h, label="hit rate")
+        plt.legend()
         self._save()
         self.plot_name = plot_names[2]
-        plt.plot(range(min_label, max_label), f)
+        plt.plot(range(min_label, max_label), f, label="false alarm rate")
+        plt.legend()
+        self._save()
         self.plot_name = plot_names[3]
-        plt.plot(range(min_label, max_label), ts)
-        plt.plot(range(min_label, max_label), h)
-        plt.plot(range(min_label, max_label), f)
+        plt.plot(range(min_label, max_label), ts, label="threat score")
+        plt.plot(range(min_label, max_label), h, label="hit rate")
+        plt.plot(range(min_label, max_label), f, label="false alarm rate")
+        plt.legend()
         self._save()
 
+    def _create_competitor_forecast(self, station_name: str, competitor_name: str) -> xr.DataArray:
+        """
+        Read the forecast of one competing model at one station and relabel it with the competitor's name.
+
+        The file `forecasts_<station_name>_test.nc` is looked up in the sub-folder `competitor_name` of the
+        competitor path. From its content only the entry `self._model_name` of the `type` axis is kept and this
+        coordinate is then renamed to `competitor_name`. Errors from opening the file or selecting the entry are
+        not handled here, they are passed on to the caller.
+
+        :param station_name: name of the station to load data for
+        :param competitor_name: name of the competing model, used as folder name and as new `type` label
+        :return: the relabelled forecast of the given competitor
+        """
+        file_name = f"forecasts_{station_name}_test.nc"
+        full_path = os.path.join(self._comp_path, competitor_name, file_name)
+        forecast = xr.open_dataarray(full_path).sel(type=[self._model_name])
+        forecast.coords["type"] = [competitor_name]
+        return forecast
+
+    def _load_competitors(self, station_name: str, comp) -> xr.DataArray:
+        """
+        Collect the forecasts of all competitors listed in `comp` for a single station.
+
+        Each competitor is read via `_create_competitor_forecast`. A competitor raising `FileNotFoundError` or
+        `KeyError` is skipped with a debug message, all remaining forecasts are joined along the `type` axis.
+
+        :param station_name: station indicator to load competitors for
+        :param comp: iterable with the names of the competitors to load
+        :return: a single xarray with all competing forecasts or None if no competitor could be loaded
+        """
+        available = []
+        for name in comp:
+            try:
+                available.append(self._create_competitor_forecast(station_name, name))
+            except (FileNotFoundError, KeyError):
+                logging.debug(f"No competitor found for combination '{station_name}' and '{name}'.")
+        if not available:
+            return None
+        return xr.concat(available, "type")
+
+    def _calculate_contingencies(self):
+        """
+        Count the contingency table entries of the model forecast for each threshold, summed over all stations.
+
+        :return: arrays of true_above, false_above, false_below and true_below counts (one entry per threshold)
+            and the borders [min_label, max_label] of the threshold range
+        """
+        # Competitors are not part of the scores yet, so they are not loaded here.
+        # Todo: derive the borders from the data instead of using fixed values
+        min_label, max_label = 0, 100
+        borders = [min_label, max_label]
+        counts = np.zeros((4, max_label - min_label), dtype=int)  # rows: ta, fa, fb, tb
+        for station in self._stations:
+            file = os.path.join(self._file_path, self._file_name % station)
+            with xr.open_dataarray(file) as forecast_file:  # close the file again after counting
+                obs = forecast_file.sel(type=self._obs_name)
+                model = forecast_file.sel(type=self._model_name)
+                for i, threshold in enumerate(range(min_label, max_label)):
+                    counts[:, i] += self._single_contingency(obs, model, threshold)
+        true_above, false_above, false_below, true_below = counts
+        return true_above, false_above, false_below, true_below, borders
+
+
+    def _single_contingency(self, obs, pred, threshold):
+        """
+        Count the entries of the 2x2 contingency table for the event "value >= threshold".
+
+        NaN never counts as exceeding the threshold as a comparison with NaN is False.
+
+        :param obs: observations, any object offering a numpy array as `.values`
+        :param pred: predictions with the same number of elements as `obs`
+        :param threshold: value that defines the event
+        :return: number of true_above (hits), false_above (false alarms), false_below (misses) and true_below
+            (correct negatives) cases as integers
+        """
+        obs_above = obs.values.flatten() >= threshold
+        pred_above = pred.values.flatten() >= threshold
+        ta = np.count_nonzero(pred_above & obs_above)
+        fa = np.count_nonzero(pred_above & ~obs_above)
+        fb = np.count_nonzero(~pred_above & obs_above)
+        tb = np.count_nonzero(~pred_above & ~obs_above)
+        return ta, fa, fb, tb
+
+    def _calculate_scores(self, true_above, false_above, false_below, true_below):
+        """Return threat score, hit rate and false alarm rate per threshold, using 0 where a score is undefined."""
+        # 0/0 raises numpy's "invalid" (not "divide") warning; silence it locally instead of changing global state
+        with np.errstate(divide="ignore", invalid="ignore"):
+            ts = true_above/(true_above + false_above + false_below)
+            h = true_above/(true_above + false_below)
+            f = false_above/(false_above + true_below)
+        for score in (ts, h, f):
+            np.nan_to_num(score, copy=False)
+        return ts, h, f
+
+
+
 
 @TimeTrackingWrapper
 class PlotMonthlySummary(AbstractPlotClass):
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index cefd4505..c5687e37 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -363,7 +363,7 @@ class ExperimentSetup(RunEnvironment):
         # set competitors
         self._set_param("competitors", competitors, default=[])
         competitor_path_default = os.path.join(self.data_store.get("data_path"), "competitors",
-                                               "_".join(self.data_store.get("target_var")))
+                                               "_".join(to_list(self.data_store.get("target_var"))))
         self._set_param("competitor_path", competitor_path, default=competitor_path_default)
 
         # check variables, statistics and target variable
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 2febedb6..8a594808 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -309,11 +309,9 @@ class PostProcessing(RunEnvironment):
         try:
             if (self.data_store.get('oversampling_method')=='bin_oversampling') and (
                     "PlotOversamplingContingency" in plot_list):
-                predictions = None
-                labels = None
-                PlotOversamplingContingency(extract_method=self.load_competitors(), station_names=self.test_data.keys(),
-                                            file_path=path, file_name=r"forecasts_%s_test.nc",
-                                            plot_folder=self.plot_path)
+                PlotOversamplingContingency(station_names=self.test_data.keys(), file_path=path, comp_path=self.competitor_path,
+                                            comp_names=self.competitors,
+                                            file_name=r"forecasts_%s_test.nc", plot_folder=self.plot_path)
         except Exception as e:
             logging.error(f"Could not create plot OversamplingContingencyPlots due to the following error: {e}")
 
@@ -321,9 +319,7 @@ class PostProcessing(RunEnvironment):
             if (self.data_store.get('oversampling_method')=='bin_oversampling') and (
                     "PlotOversampling" in plot_list):
                 bin_edges = self.data_store.get('oversampling_bin_edges')
-                oversampling_rates = self.data_store.get('oversampling_rates_capped','train')
-                #Y = self.data_store.get('Oversampling_Y')
-                #Y_extreme = self.data_store.get('Oversampling_Y_extreme')
+                oversampling_rates = self.data_store.get('oversampling_rates_capped', 'train')
                 PlotOversampling(self.train_data, bin_edges, oversampling_rates, plot_folder=self.plot_path)
         except Exception as e:
             logging.error(f"Could not create plot OversamplingPlots due to the following error: {e}")
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 2d6dc3b5..c0065523 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -107,16 +107,6 @@ class PreProcessing(RunEnvironment):
         #Y_extreme = None
         for station in data:
             station.apply_oversampling(bin_edges, oversampling_rates_capped)
-            '''
-            if Y is None:
-                Y = station._Y
-                Y_extreme = station._Y_extreme
-            else:
-                Y = xr.concat([Y, station._Y], dim="Stations")
-                Y_extreme = xr.concat([Y_extreme, station._Y_extreme], dim="Stations")
-        self.data_store.set('Oversampling_Y', Y)
-        self.data_store.set('Oversampling_Y_extreme', Y_extreme)
-        '''
 
 
     def report_pre_processing(self):
-- 
GitLab