From 45199b6c8b1bc9198007cb7d699d686f8e07ff5e Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlichfz-juelich.de>
Date: Thu, 17 Jun 2021 15:28:10 +0200
Subject: [PATCH] apply_oversampling calculates the desired oversampling_rates

---
 mlair/run_modules/pre_processing.py | 39 +++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 9d44ce0b..6f3c1cef 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -8,6 +8,8 @@ import os
 import traceback
 from typing import Tuple
 import multiprocessing
+
+import numpy as np
 import requests
 import psutil
 
@@ -65,9 +67,46 @@ class PreProcessing(RunEnvironment):
             raise ValueError("Couldn't find any valid data according to given parameters. Abort experiment run.")
         self.data_store.set("stations", valid_stations)
         self.split_train_val_test()
+        self.apply_oversampling()
         self.report_pre_processing()
         self.prepare_competitors()
 
+    def apply_oversampling(self, bins=10, rates_cap=20):
+        """
+        Calculate capped oversampling rates from the target distribution of the train subset.
+
+        The targets of all stations in the train data collection are counted in `bins` equally wide bins between
+        the smallest and the largest target value. The rate of a bin is the count of the most frequent bin divided
+        by the count of this bin, limited to `rates_cap`. Bins without any sample get `rates_cap`.
+
+        :param bins: number of equally wide histogram bins
+        :param rates_cap: upper limit of the oversampling rate of a single bin
+
+        :return: array of length `bins` with the capped rates, all rates are 1 if the train subset has no station
+        """
+        # TODO: skip the calculation if oversampling is not requested
+        data = self.data_store.get('data_collection', 'train')
+        # first pass: value range of all stations, start with an empty interval so that the range is not forced to
+        # include 0
+        lower, upper = np.inf, -np.inf
+        for station in data:
+            targets = station.get_Y(as_numpy=True)
+            lower = np.minimum(np.amin(targets), lower)
+            upper = np.maximum(np.amax(targets), upper)
+        if lower > upper:
+            # interval is still empty, i.e. there was no station to iterate over and nothing has to be balanced
+            return np.ones(bins)
+        # second pass: add up the histograms of all stations on common bin edges
+        histogram = np.zeros(bins)
+        for station in data:
+            hist, _ = np.histogram(station.get_Y(as_numpy=True), bins=bins, range=(lower, upper))
+            histogram += hist
+        # empty bins keep the cap, this avoids a division by zero
+        oversampling_rates = np.full(bins, float(rates_cap))
+        np.divide(np.amax(histogram), histogram, out=oversampling_rates, where=histogram > 0)
+        oversampling_rates_capped = np.minimum(oversampling_rates, rates_cap)
+        return oversampling_rates_capped
+
     def report_pre_processing(self):
         """Log some metrics on data and create latex report."""
         logging.debug(20 * '##')
-- 
GitLab