From f54a7b8729c9281bb1259abf71ef4c97ea1aaacf Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlichfz-juelich.de>
Date: Mon, 21 Jun 2021 11:39:13 +0200
Subject: [PATCH] Changes according to the threads, using histogram += hist and
 histogram /= np.amax(histogram) leads to error because of wrong shape

---
 mlair/run_modules/pre_processing.py | 31 +++++++++++++----------------
 1 file changed, 14 insertions(+), 17 deletions(-)

diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 6f3c1cef..05bd61d9 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -71,32 +71,29 @@ class PreProcessing(RunEnvironment):
         self.report_pre_processing()
         self.prepare_competitors()
 
-    def apply_oversampling(self):
-        #if Abfrage for oversampling=True/False
-        bins = 10
-        rates_cap = 20
+    def apply_oversampling(self, bins=10, rates_cap=20):
+        #if request for oversampling=True/False
         data = self.data_store.get('data_collection', 'train')
         histogram = np.array(bins)
         #get min and max of the whole data
-        min = 0
-        max = 0
+        total_min = 0
+        total_max = 0
         for station in data:
-            min = np.minimum(np.amin(station.get_Y(as_numpy=True)), min)
-            max = np.maximum(np.amax(station.get_Y(as_numpy=True)), max)
+            total_min = np.minimum(np.amin(station.get_Y(as_numpy=True)), total_min)
+            total_max = np.maximum(np.amax(station.get_Y(as_numpy=True)), total_max)
         for station in data:
-            # erstelle Histogramm mit numpy für jede Station
-            hist, _ = np.histogram(station.get_Y(as_numpy=True), bins=bins, range=(min,max))
-            #histograms.append(hist)
+            # Create histogram for each station
+            hist, _ = np.histogram(station.get_Y(as_numpy=True), bins=bins, range=(total_min,total_max))
+            # Add up histograms
             histogram = histogram + hist
-        # Addiere alle Histogramme zusammen
-        #histogram = histograms[0]+histograms[1]+histograms[2]+histograms[3]
-        #teile durch gesamtanzahl
-        histogram = 1/np.sum(histogram) * histogram
-        #mult mit 1/häufigste Klasse
+        # Scale down to most frequent class=1
         histogram = 1/np.amax(histogram) * histogram
-        #Oversampling 1/Kl
+        # Get Oversampling rates (with and without cap)
         oversampling_rates = 1 / histogram
         oversampling_rates_capped = np.minimum(oversampling_rates, rates_cap)
+        # Add to datastore
+        self.data_store.set('oversampling_rates', oversampling_rates, 'training')
+        self.data_store.set('oversampling_rates_capped', oversampling_rates_capped, 'training')
 
     def report_pre_processing(self):
         """Log some metrics on data and create latex report."""
-- 
GitLab