From 1ebf5239ac01c63295bad58fdaec150dd9d39b09 Mon Sep 17 00:00:00 2001
From: "v.gramlich1" <v.gramlichfz-juelich.de>
Date: Tue, 22 Jun 2021 11:08:27 +0200
Subject: [PATCH] Todo: l.209

---
 mlair/data_handler/default_data_handler.py | 37 +++++++++++++++-------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 0c6d2ddc..d0ee9d0a 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -166,13 +166,15 @@ class DefaultDataHandler(AbstractDataHandler):
     def apply_transformation(self, data, base="target", dim=0, inverse=False):
         return self.id_class.apply_transformation(data, dim=dim, base=base, inverse=inverse)
 
-    def apply_oversampling(self, bin_edges, oversampling_rates):
+    def apply_oversampling(self, bin_edges, oversampling_rates, timedelta: Tuple[int, str] = (1, 's'), timedelta2: Tuple[int, str] = (1, 'ms')):
         self._load()
         if (self._X is None) or (self._Y is None):
             logging.debug(f"{str(self.id_class)} has no data for X or Y, skip multiply extremes")
             return
         Y = self._Y
         X = self._X
+        complete_extremes_X_list = []
+        complete_extremes_Y_list = []
         for i_bin in range(len(bin_edges)-1):
             bin_start = bin_edges[i_bin]
             if i_bin == len(bin_edges) - 1:
@@ -186,16 +188,27 @@ class DefaultDataHandler(AbstractDataHandler):
             extreme_idx = xr.concat([(Y >= bin_start).any(dim=other_dims[0]),
                                          (Y < bin_end).any(dim=other_dims[0])],
                                         dim=other_dims[0]).all(dim=other_dims[0])
-
-            extremes_X = list(map(lambda x: x.sel(**{self.time_dim: extreme_idx}), X))
-            self._add_timedelta(extremes_X, dim, timedelta)
-            # extremes_X = list(map(lambda x: x.coords[dim].values + np.timedelta64(*timedelta), extremes_X))
-
-            extremes_Y = Y.sel(**{dim: extreme_idx})
-            extremes_Y.coords[dim].values += np.timedelta64(*timedelta)
-
-            self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
-            self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
+            extremes_X_list =[]
+            extremes_Y_list = []
+            for i in range(np.ceil(rate).astype(int)):
+                sel = extreme_idx.coords[self.time_dim].values
+                if rate-i < 1:
+                    rest = int(len(sel)*(rate-i))
+                    sel = np.random.choice(sel, rest, replace=False)
+                extremes_X = list(map(lambda x: x.sel(**{self.time_dim: sel}), X))
+                self._add_timedelta(extremes_X, self.time_dim, (i,timedelta[1]))
+                self._add_timedelta(extremes_X, self.time_dim, (i_bin, timedelta2[1]))
+                extremes_Y = Y.sel(**{self.time_dim: sel})
+                extremes_Y.coords[self.time_dim] = extremes_Y.coords[self.time_dim].values + i*np.timedelta64(*timedelta) + i_bin*np.timedelta64(*timedelta2)
+                extremes_X_list.append(extremes_X)
+                extremes_Y_list.append(extremes_Y)
+
+            complete_extremes_X_list.append(extremes_X_list)
+            complete_extremes_Y_list.append(extremes_Y_list)
+
+        #Convert complete_extremes_X_list (list of lists of xarrays) into xarray and give it to self._X_extreme
+        #self._X_extreme = [[xr.concat(X_list, dim=self.time_dim) for X_list in complete_X_list] for complete_X_list in complete_extremes_X_list]
+        #self._Y_extreme = [[xr.concat(Y_list, dim=self.time_dim) for Y_list in complete_Y_list] for complete_Y_list in complete_extremes_Y_list]
 
 
     def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
@@ -264,7 +277,7 @@ class DefaultDataHandler(AbstractDataHandler):
     @staticmethod
     def _add_timedelta(data, dim, timedelta):
         for d in data:
-            d.coords[dim].values += np.timedelta64(*timedelta)
+            d.coords[dim] = d.coords[dim].values + np.timedelta64(*timedelta)
 
     @classmethod
     def transformation(cls, set_stations, **kwargs):
-- 
GitLab