From 527ddd38953581e49092975e68cde1faf18abcbd Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlich@fz-juelich.de>
Date: Wed, 4 Aug 2021 09:01:50 +0200
Subject: [PATCH] Made PlotOversamplingContingency get min and max_threshold

---
 mlair/plotting/postprocessing_plotting.py | 30 ++++++++++++++---------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index a0d54c18..ed114830 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -48,11 +48,10 @@ class PlotOversamplingContingency(AbstractPlotClass):
         self._all_names = [self._model_name]
         self._all_names.extend(self._comp_names)
         self._plot_names = plot_names
-        contingency_array, borders = self._calculate_contingencies()
+        self._min_threshold, self._max_threshold = self._min_max_threshold()
+        contingency_array = self._calculate_contingencies()
         self._scores = ["ts", "h", "f"]
         score_array = self._calculate_all_scores(contingency_array)
-        self._min_label = borders[0]
-        self._max_label = borders[1]
         self._plot_counter = 0
 
         self._plot(score_array, "ts")
@@ -67,10 +66,10 @@ class PlotOversamplingContingency(AbstractPlotClass):
     def _plot(self, data, score):
         if score == "all_scores":
             for score_name in data.scores.values.tolist():
-                plt.plot(range(self._min_label, self._max_label), data.loc[dict(type="nn", scores=score_name)], label=score_name)
+                plt.plot(range(self._min_threshold, self._max_threshold), data.loc[dict(type="nn", scores=score_name)], label=score_name)
         else:
             for type in data.type.values.tolist():
-                plt.plot(range(self._min_label, self._max_label), data.loc[dict(type=type, scores=score)], label=type)
+                plt.plot(range(self._min_threshold, self._max_threshold), data.loc[dict(type=type, scores=score)], label=type)
         plt.legend()
         self.plot_name = self._plot_names[self._plot_counter]
         self._plot_counter = self._plot_counter + 1
@@ -116,11 +115,20 @@ class PlotOversamplingContingency(AbstractPlotClass):
                 continue
         return xr.concat(competing_predictions, "type") if len(competing_predictions) > 0 else None
 
+    def _min_max_threshold(self):  # returns (min, max) integer bounds of the selected values over all stations
+        min_threshold = 0  # seeded with 0, so the returned minimum is never above 0
+        max_threshold = 0  # seeded with 0, so the returned maximum is never below 0
+        for station in self._stations:
+            file = os.path.join(self._file_path, self._file_name % station)  # self._file_name is a %-pattern filled with the station
+            forecast_file = xr.open_dataarray(file)  # NOTE(review): not closed explicitly in this method
+            obs = forecast_file.sel(type=self._obs_name)  # entry labelled self._obs_name along the "type" coordinate
+            obs = obs.fillna(0)  # NaN replaced by 0 before min/max (np.min/np.max would otherwise propagate NaN)
+            min_threshold = np.minimum(min_threshold, int(np.min(obs.values.flatten())))  # int() truncates toward zero
+            max_threshold = np.maximum(max_threshold, int(np.max(obs.values.flatten())))  # np.maximum yields a numpy integer scalar
+        return min_threshold, max_threshold  # consumed as range()/np.arange bounds, so max_threshold itself is excluded
+
     def _calculate_contingencies(self):
-        min_label = 0
-        max_label = 100
-        borders = [min_label, max_label]
-        thresholds = np.arange(min_label, max_label)
+        thresholds = np.arange(self._min_threshold, self._max_threshold)
         contingency_cell = ["ta", "fa", "fb", "tb"]
         contingency_array = xr.DataArray(dims=["thresholds", "contingency_cell", "type"],
                                          coords=[thresholds, contingency_cell, self._all_names])
@@ -132,14 +140,14 @@ class PlotOversamplingContingency(AbstractPlotClass):
             predictions = [forecast_file.sel(type=self._model_name)]
             competitors = [self._load_competitors(station, [comp]).sel(type=comp) for comp in self._comp_names]
             predictions.extend(competitors)
-            for threshold in range(min_label, max_label):
+            for threshold in range(self._min_threshold, self._max_threshold):
                 for pred in predictions:
                     ta, fa, fb, tb = self._single_contingency(obs, pred, threshold)
                     contingency_array.loc[dict(thresholds=threshold, contingency_cell="ta", type=pred.type.values)] = ta
                     contingency_array.loc[dict(thresholds=threshold, contingency_cell="fa", type=pred.type.values)] = fa
                     contingency_array.loc[dict(thresholds=threshold, contingency_cell="fb", type=pred.type.values)] = fb
                     contingency_array.loc[dict(thresholds=threshold, contingency_cell="tb", type=pred.type.values)] = tb
-        return contingency_array, borders
+        return contingency_array
 
     def _single_contingency(self, obs, pred, threshold):
         ta = 0
-- 
GitLab