From e81da4988f9c3cc5d234145ada05d4c57f50ad9f Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlichfz-juelich.de>
Date: Mon, 28 Jun 2021 10:12:24 +0200
Subject: [PATCH] Final fixes

---
 mlair/configuration/defaults.py            |  2 +-
 mlair/data_handler/default_data_handler.py | 33 +++++++++++-----------
 mlair/run_modules/pre_processing.py        | 33 ++++++++++++++++++++++
 3 files changed, 51 insertions(+), 17 deletions(-)

diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index c6e61782..f2538e98 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -56,7 +56,7 @@ DEFAULT_DATA_ORIGIN = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA
 DEFAULT_USE_MULTIPROCESSING = True
 DEFAULT_USE_MULTIPROCESSING_ON_DEBUG = False
 DEFAULT_OVERSAMPLING_BINS = 10
-DEFAULT_OVERSAMPLING_RATES_CAP = 20
+DEFAULT_OVERSAMPLING_RATES_CAP = 100
 
 
 def get_defaults():
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 7db868ee..784e1946 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -190,22 +190,23 @@ class DefaultDataHandler(AbstractDataHandler):
                                         dim=other_dims[0]).all(dim=other_dims[0])
             extreme_idx = extreme_idx[extreme_idx]
             sel = extreme_idx.coords[self.time_dim].values
-            for i in range(np.ceil(rate).astype(int)):
-                if rate-i < 1:
-                    rest = int(len(sel)*(rate-i))
-                    sel = np.random.choice(sel, rest, replace=False)
-                extremes_X = list(map(lambda x: x.sel(**{self.time_dim: sel}), X))
-                self._add_timedelta(extremes_X, self.time_dim, (i,timedelta[1]))
-                self._add_timedelta(extremes_X, self.time_dim, (i_bin, timedelta2[1]))
-                extremes_Y = Y.sel(**{self.time_dim: sel})
-                extremes_Y.coords[self.time_dim] = extremes_Y.coords[self.time_dim].values + i*np.timedelta64(*timedelta) + i_bin*np.timedelta64(*timedelta2)
-                if (self._X_extreme is None) or (self._Y_extreme is None):
-                    self._X_extreme = extremes_X
-                    self._Y_extreme = extremes_Y
-                else:
-                    self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=self.time_dim), self._X_extreme, extremes_X))
-                    self._Y_extreme = xr.concat([self._Y_extreme, extremes_Y], dim=self.time_dim)
-        self._store(fresh_store=True)
+            if len(extreme_idx)>0:
+                for i in range(np.ceil(rate).astype(int)):
+                    if rate-i < 1:
+                        rest = int(len(sel)*(rate-i))+1
+                        sel = np.random.choice(sel, rest, replace=False)
+                    extremes_X = list(map(lambda x: x.sel(**{self.time_dim: sel}), X))
+                    self._add_timedelta(extremes_X, self.time_dim, (i, timedelta[1]))
+                    self._add_timedelta(extremes_X, self.time_dim, (i_bin, timedelta2[1]))
+                    extremes_Y = Y.sel(**{self.time_dim: sel})
+                    extremes_Y.coords[self.time_dim] = extremes_Y.coords[self.time_dim].values + i*np.timedelta64(*timedelta) + i_bin*np.timedelta64(*timedelta2)
+                    if (self._X_extreme is None) or (self._Y_extreme is None):
+                        self._X_extreme = extremes_X
+                        self._Y_extreme = extremes_Y
+                    else:
+                        self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=self.time_dim), self._X_extreme, extremes_X))
+                        self._Y_extreme = xr.concat([self._Y_extreme, extremes_Y], dim=self.time_dim)
+        #self._store(fresh_store=True)
 
     def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
                           timedelta: Tuple[int, str] = (1, 'm'), dim=DEFAULT_TIME_DIM):
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 4e41e847..9ef5c3f1 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -14,6 +14,8 @@ import requests
 import psutil
 
 import pandas as pd
+import xarray as xr
+from matplotlib import pyplot as plt
 
 from mlair.data_handler import DataCollection, AbstractDataHandler
 from mlair.helpers import TimeTracking, to_list, tables
@@ -98,8 +100,39 @@ class PreProcessing(RunEnvironment):
         self.data_store.set('oversampling_rates', oversampling_rates, 'train')
         self.data_store.set('oversampling_rates_capped', oversampling_rates_capped, 'train')
         self.data_store.set('oversampling_bin_edges', bin_edges)
+        Y = None
+        Y_extreme = None
         for station in data:
             station.apply_oversampling(bin_edges, oversampling_rates_capped)
+            if Y is None:
+                Y = station._Y
+                Y_extreme = station._Y_extreme
+            else:
+                Y = xr.concat([Y, station._Y], dim="Stations")
+                Y_extreme = xr.concat([Y_extreme, station._Y_extreme], dim="Stations")
+
+        fig, ax = plt.subplots(nrows=2, ncols=2)
+        fig.suptitle(f"Window Size=1, Bins={bins}, rates_cap={rates_cap}")
+        Y_hist = Y.plot.hist(bins=bin_edges, histtype="step", label="Before", ax=ax[0,0])[0]
+        Y_extreme_hist = Y_extreme.plot.hist(bins=bin_edges, histtype="step", label="After", ax=ax[0,0])[0]
+        ax[0,0].set_title(f"Histogram before-after oversampling")
+        ax[0,0].legend()
+        Y_hist_dens = Y.plot.hist(bins=bin_edges, density=True, histtype="step", label="Before", ax=ax[0,1])[0]
+        Y_extreme_hist_dens = Y_extreme.plot.hist(bins=bin_edges, density=True, histtype="step", label="After", ax=ax[0,1])[0]
+        ax[0,1].set_title(f"Density-Histogram before-after oversampling")
+        ax[0,1].legend()
+        real_oversampling = Y_extreme_hist/Y_hist
+        ax[1,0].plot(range(len(real_oversampling)), oversampling_rates_capped, label="Desired oversampling_rates")
+        ax[1,0].plot(range(len(real_oversampling)), real_oversampling, label="Actual Oversampling Rates")
+        ax[1,0].set_title(f"Oversampling rates")
+        ax[1,0].legend()
+        ax[1,1].plot(range(len(real_oversampling)), real_oversampling / oversampling_rates_capped,
+                 label="Actual/Desired Rate")
+        ax[1,1].set_title(f"Deviation from desired Oversampling rate")
+        ax[1,1].legend()
+        plt.show()
+        #data[1]._Y.where(data[1]._Y > bin_edges[9], drop=True)
+        #data[1]._Y_extreme.where(data[1]._Y_extreme > bin_edges[9], drop=True)
 
     def report_pre_processing(self):
         """Log some metrics on data and create latex report."""
-- 
GitLab