From 7a96f1991a7702987d6bc9a6b3fa2ac48454f100 Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlichfz-juelich.de>
Date: Fri, 23 Jul 2021 19:42:05 +0200
Subject: [PATCH] PlotOversampling fixed

---
 mlair/plotting/data_insight_plotting.py | 33 ++++++++++++++-----------
 1 file changed, 19 insertions(+), 14 deletions(-)

diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index ccea0b84..a2007f3f 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -28,13 +28,12 @@ class PlotOversampling(AbstractPlotClass):
 
         super().__init__(plot_folder, plot_names[0])
 
-        Y_hist, Y_extreme_hist = self._calculate_hist(data, bin_edges)
-
-        Y_hist, Y_extreme_hist = self._plot_oversampling_histogram(Y, Y_extreme, bin_edges)
+        Y_hist, Y_extreme_hist, Y_hist_dens, Y_extreme_hist_dens = self._calculate_hist(data, bin_edges)
         real_oversampling = Y_extreme_hist / Y_hist
+        self._plot_oversampling_histogram(Y_hist, Y_extreme_hist, bin_edges)
         self._save()
         self.plot_name = plot_names[1]
-        self._plot_oversampling_density_histogram(Y, Y_extreme, bin_edges)
+        self._plot_oversampling_density_histogram(Y_hist_dens, Y_extreme_hist_dens, bin_edges)
         self._save()
         self.plot_name = plot_names[2]
         self._plot_oversampling_rates(oversampling_rates, real_oversampling)
@@ -43,25 +42,31 @@ class PlotOversampling(AbstractPlotClass):
         self._plot_oversampling_rates_deviation(oversampling_rates, real_oversampling)
         self._save()
 
-    def _calculate_histogram(self, data, bin_edges):
-        Y_hist = np.zeros(len(bin_edges),1)
-        Y_extreme_hist = np.zeros(len(bin_edges), 1)
+    def _calculate_hist(self, data, bin_edges):
+        Y_hist = np.zeros(len(bin_edges)-1)
+        Y_extreme_hist = np.zeros(len(bin_edges)-1)
         for station in data:
             Y = station.get_Y(as_numpy=True, upsampling=False)
             Y_extreme = station.get_Y(as_numpy=True, upsampling=True)
+            Y_hist = Y_hist + np.histogram(Y, bins=bin_edges)[0]
+            Y_extreme_hist = Y_extreme_hist + np.histogram(Y_extreme, bins=bin_edges)[0]
+        Y_hist_dens = Y_hist/np.sum(Y_hist)
+        Y_extreme_hist_dens = Y_extreme_hist / np.sum(Y_extreme_hist)
+        return Y_hist, Y_extreme_hist, Y_hist_dens, Y_extreme_hist_dens
 
-    def _plot_oversampling_histogram(self, Y, Y_extreme, bin_edges):
+    def _plot_oversampling_histogram(self, Y_hist, Y_extreme_hist, bin_edges):
         fig, ax = plt.subplots(1, 1)
-        Y_hist = Y.plot.hist(bins=bin_edges, histtype="step", label="Before", ax=ax)[0]
-        Y_extreme_hist = Y_extreme.plot.hist(bins=bin_edges, histtype="step", label="After", ax=ax)[0]
+        ax.step(bin_edges, np.append(0,Y_hist), label="Before oversampling")
+        ax.step(bin_edges, np.append(0,Y_extreme_hist), label="After oversampling")
         ax.set_title(f"Histogram before-after oversampling")
         ax.legend()
-        return Y_hist, Y_extreme_hist
 
-    def _plot_oversampling_density_histogram(self, Y, Y_extreme, bin_edges):
+    def _plot_oversampling_density_histogram(self, Y_hist_dens, Y_extreme_hist_dens, bin_edges):
         fig, ax = plt.subplots(1, 1)
-        Y.plot.hist(bins=bin_edges, density=True, histtype="step", label="Before", ax=ax)[0]
-        Y_extreme.plot.hist(bins=bin_edges,  density=True, histtype="step", label="After", ax=ax)[0]
+        ax.step(bin_edges, np.append(0,Y_hist_dens), label="Before oversampling")
+        ax.step(bin_edges, np.append(0,Y_extreme_hist_dens), label="After oversampling")
+        #ax.stairs(Y_hist_dens, bin_edges, label="Before oversampling")
+        #ax.stairs(Y_extreme_hist_dens, bin_edges, label="After oversampling")
         ax.set_title(f"Density Histogram before-after oversampling")
         ax.legend()
 
-- 
GitLab