diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 6f3c1ceff7292ff9096ea3edba652ef19b8aa771..05bd61d941914c0e6bd332ce886ce2f62e293d8e 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -71,32 +71,29 @@ class PreProcessing(RunEnvironment): self.report_pre_processing() self.prepare_competitors() - def apply_oversampling(self): - #if Abfrage for oversampling=True/False - bins = 10 - rates_cap = 20 + def apply_oversampling(self, bins=10, rates_cap=20): + #if request for oversampling=True/False data = self.data_store.get('data_collection', 'train') histogram = np.array(bins) #get min and max of the whole data - min = 0 - max = 0 + total_min = 0 + total_max = 0 for station in data: - min = np.minimum(np.amin(station.get_Y(as_numpy=True)), min) - max = np.maximum(np.amax(station.get_Y(as_numpy=True)), max) + total_min = np.minimum(np.amin(station.get_Y(as_numpy=True)), total_min) + total_max = np.maximum(np.amax(station.get_Y(as_numpy=True)), total_max) for station in data: - # erstelle Histogramm mit numpy für jede Station - hist, _ = np.histogram(station.get_Y(as_numpy=True), bins=bins, range=(min,max)) - #histograms.append(hist) + # Create histogram for each station + hist, _ = np.histogram(station.get_Y(as_numpy=True), bins=bins, range=(total_min,total_max)) + # Add up histograms histogram = histogram + hist - # Addiere alle Histogramme zusammen - #histogram = histograms[0]+histograms[1]+histograms[2]+histograms[3] - #teile durch gesamtanzahl - histogram = 1/np.sum(histogram) * histogram - #mult mit 1/häufigste Klasse + # Scale down to most frequent class=1 histogram = 1/np.amax(histogram) * histogram - #Oversampling 1/Kl + # Get Oversampling rates (with and without cap) oversampling_rates = 1 / histogram oversampling_rates_capped = np.minimum(oversampling_rates, rates_cap) + # Add to datastore + self.data_store.set('oversampling_rates', oversampling_rates, 'training') + self.data_store.set('oversampling_rates_capped', oversampling_rates_capped, 'training') def report_pre_processing(self): """Log some metrics on data and create latex report."""