From 4a662aa58afa711559783a1438d7a3af6e2d9e3d Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlichfz-juelich.de>
Date: Tue, 17 Aug 2021 11:59:55 +0200
Subject: [PATCH] added contingency_cell plots

---
 mlair/plotting/postprocessing_plotting.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index c6ec6655..668e3794 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -35,7 +35,7 @@ class PlotOversamplingContingency(AbstractPlotClass):
     def __init__(self, station_names, file_path, comp_path, file_name, plot_folder: str = ".", model_name: str = "nn",
                  obs_name: str = "obs", comp_names: str = "IntelliO3",
                  plot_names=["oversampling_threat_score", "oversampling_hit_rate", "oversampling_false_alarm_rate",
-                             "oversampling_bias", "oversampling_all_scores"]):
+                             "oversampling_bias", "oversampling_all_scores", "contingency_table"]):
 
         super().__init__(plot_folder, plot_names[0])
         self._stations = station_names
@@ -64,6 +64,24 @@ class PlotOversamplingContingency(AbstractPlotClass):
         self._save()
         self._plot(score_array, "all_scores")
         self._save()
+        self._plot_contingency(contingency_array, self._model_name)
+        self._save()
+        for comp in self._comp_names:
+            self._plot_contingency(contingency_array, comp)
+            self._save()
+
+    def _plot_contingency(self, contingency_array, type):
+        plt.plot(range(self._min_threshold, self._max_threshold),
+                 contingency_array.loc[dict(contingency_cell="ta", type=type)], label="a")
+        plt.plot(range(self._min_threshold, self._max_threshold),
+                 contingency_array.loc[dict(contingency_cell="fa", type=type)], label="b")
+        plt.plot(range(self._min_threshold, self._max_threshold),
+                 contingency_array.loc[dict(contingency_cell="fb", type=type)], label="c")
+        plt.plot(range(self._min_threshold, self._max_threshold),
+                 contingency_array.loc[dict(contingency_cell="tb", type=type)], label="d")
+        plt.title(f"contingency table {type}")
+        plt.legend()
+        self.plot_name = f"contingency_table_{type}"
 
     def _plot(self, data, score):
         if score == "all_scores":
-- 
GitLab