diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 4ae5dc03f08555c80e5a9fcaa979fdb33e0ef115..7db868ee751298f81ab8a3a20f22c04b0549f7ea 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -169,14 +169,12 @@ class DefaultDataHandler(AbstractDataHandler):
     def apply_oversampling(self, bin_edges, oversampling_rates, timedelta: Tuple[int, str] = (1, 's'), timedelta2: Tuple[int, str] = (1, 'ms')):
         self._load()
         self._X_extreme = None
-        self._X_extreme = None
+        self._Y_extreme = None
         if (self._X is None) or (self._Y is None):
             logging.debug(f"{str(self.id_class)} has no data for X or Y, skip multiply extremes")
             return
         Y = self._Y
         X = self._X
-        #complete_extremes_X_list = []
-        #complete_extremes_Y_list = []
         for i_bin in range(len(bin_edges)-1):
             bin_start = bin_edges[i_bin]
             if i_bin == len(bin_edges) - 1:
@@ -190,10 +188,9 @@ class DefaultDataHandler(AbstractDataHandler):
             extreme_idx = xr.concat([(Y >= bin_start).any(dim=other_dims[0]),
                                          (Y < bin_end).any(dim=other_dims[0])],
                                         dim=other_dims[0]).all(dim=other_dims[0])
-            #extremes_X_list =[]
-            #extremes_Y_list = []
+            extreme_idx = extreme_idx[extreme_idx]
+            sel = extreme_idx.coords[self.time_dim].values
             for i in range(np.ceil(rate).astype(int)):
-                sel = extreme_idx.coords[self.time_dim].values
                 if rate-i < 1:
                     rest = int(len(sel)*(rate-i))
                     sel = np.random.choice(sel, rest, replace=False)
@@ -208,15 +205,7 @@ class DefaultDataHandler(AbstractDataHandler):
                 else:
                     self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=self.time_dim), self._X_extreme, extremes_X))
                     self._Y_extreme = xr.concat([self._Y_extreme, extremes_Y], dim=self.time_dim)
-                #extremes_X_list.append(extremes_X)
-                #extremes_Y_list.append(extremes_Y)
-
-            #complete_extremes_X_list.append(extremes_X_list)
-            #complete_extremes_Y_list.append(extremes_Y_list)
-        #Convert complete_extremes_X_list (list of lists of xarrays) into xarray and give it to self._X_extreme
-        #self._X_extreme = [[xr.concat(X_list, dim=self.time_dim) for X_list in complete_X_list] for complete_X_list in complete_extremes_X_list]
-        #self._Y_extreme = [[xr.concat(Y_list, dim=self.time_dim) for Y_list in complete_Y_list] for complete_Y_list in complete_extremes_Y_list]
-
+        self._store(fresh_store=True)
 
     def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
                           timedelta: Tuple[int, str] = (1, 'm'), dim=DEFAULT_TIME_DIM):