diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index 2b817f531a39ffdf0fb41d57e604a10734dba2fd..c6e61782102152420af57e5927f29094256e48d4 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -55,9 +55,8 @@ DEFAULT_DATA_ORIGIN = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA
                        "pm10": "", "so2": ""}
 DEFAULT_USE_MULTIPROCESSING = True
 DEFAULT_USE_MULTIPROCESSING_ON_DEBUG = False
-
-DEFAULT_BINS = 10
-DEFAULT_RATES_CAP = 20
+DEFAULT_OVERSAMPLING_BINS = 10
+DEFAULT_OVERSAMPLING_RATES_CAP = 20
 
 
 def get_defaults():
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 11461ad77c3e910a897a9a1be48aef7cef45480a..0c6d2ddcfc7208ff4b294b0e110fd63e79667ac1 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -166,6 +166,38 @@ class DefaultDataHandler(AbstractDataHandler):
     def apply_transformation(self, data, base="target", dim=0, inverse=False):
         return self.id_class.apply_transformation(data, dim=dim, base=base, inverse=inverse)
 
+    def apply_oversampling(self, bin_edges, oversampling_rates):
+        self._load()
+        if (self._X is None) or (self._Y is None):
+            logging.debug(f"{str(self.id_class)} has no data for X or Y, skip oversampling")
+            return
+        Y = self._Y
+        X = self._X
+        for i_bin in range(len(bin_edges)-1):
+            bin_start = bin_edges[i_bin]
+            if i_bin == len(bin_edges) - 2:
+                bin_end = bin_edges[i_bin+1]+1
+            else:
+                bin_end = bin_edges[i_bin + 1]
+            rate = oversampling_rates[i_bin]
+
+            # extract extremes based on occurrence in labels
+            other_dims = remove_items(list(Y.dims), self.time_dim)
+            # select time steps whose label value falls inside [bin_start, bin_end)
+            extreme_idx = ((Y >= bin_start) &
+                           (Y < bin_end)).any(dim=other_dims[0])
+
+            extremes_X = list(map(lambda x: x.sel(**{self.time_dim: extreme_idx}), X))
+            timedelta = (1, 'm')  # same offset as multiply_extremes keeps the duplicated time index unique
+            self._add_timedelta(extremes_X, self.time_dim, timedelta)
+
+            extremes_Y = Y.sel(**{self.time_dim: extreme_idx})
+            extremes_Y.coords[self.time_dim].values += np.timedelta64(*timedelta)
+
+            Y = self._Y_extreme = xr.concat([Y, extremes_Y], dim=self.time_dim)
+            X = self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=self.time_dim), X, extremes_X))
+
+
     def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
                           timedelta: Tuple[int, str] = (1, 'm'), dim=DEFAULT_TIME_DIM):
         """
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index e28eb76db0232f72fe4d6548e29a966bd4915951..b249491a843d7304ee87c8c3e5f0e11874d3989b 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -19,7 +19,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
     DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
     DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM, \
-    DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG
+    DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_OVERSAMPLING_BINS, DEFAULT_OVERSAMPLING_RATES_CAP
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
@@ -183,6 +183,9 @@ class ExperimentSetup(RunEnvironment):
     :param use_multiprocessing: Enable parallel preprocessing (postprocessing not implemented yet) by setting this
         parameter to `True` (default). If set to `False` the computation is performed in an serial approach.
         Multiprocessing is disabled when running in debug mode and cannot be switched on.
+    :param oversampling_bins: Sets the number of classes in which the training data is split. The training samples are then
+        oversampled according to the frequency of the different classes.
+    :param oversampling_rates_cap: Sets the maximum oversampling rate that is applied to a class.
 
     """
 
@@ -216,7 +219,7 @@ class ExperimentSetup(RunEnvironment):
                  hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
                  data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
                  use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
-                 bins=None, rates_cap=None, **kwargs):
+                 oversampling_bins=None, oversampling_rates_cap=None, **kwargs):
 
         # create run framework
         super().__init__()
@@ -362,8 +365,8 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("model_class", model, VanillaModel)
 
         # set params for oversampling
-        self._set_param("bins", bins, default=DEFAULT_BINS)
-        self._set_param("rates_cap", rates_cap, default=DEFAULT_RATES_CAP)
+        self._set_param("oversampling_bins", oversampling_bins, default=DEFAULT_OVERSAMPLING_BINS)
+        self._set_param("oversampling_rates_cap", oversampling_rates_cap, default=DEFAULT_OVERSAMPLING_RATES_CAP)
 
         # set remaining kwargs
         if len(kwargs) > 0:
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 92cdac47939cee7b4f66b2c302f6c0adcb205422..4e41e84707b1d785bc893b7b7f292ab9a99fac1a 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -74,8 +74,8 @@ class PreProcessing(RunEnvironment):
     def apply_oversampling(self):
         #if request for oversampling=True/False
         data = self.data_store.get('data_collection', 'train')
-        bins = self.data_store.get_default('bins')
-        rates_cap = self.data_store.get_default('rates_cap')
+        bins = self.data_store.get('oversampling_bins')
+        rates_cap = self.data_store.get('oversampling_rates_cap')
-        histogram = np.array(bins)
+        histogram = np.zeros(bins)
         #get min and max of the whole data
         total_min = 0
@@ -83,9 +83,10 @@ class PreProcessing(RunEnvironment):
         for station in data:
             total_min = np.minimum(np.amin(station.get_Y(as_numpy=True)), total_min)
             total_max = np.maximum(np.amax(station.get_Y(as_numpy=True)), total_max)
+        bin_edges = []
         for station in data:
             # Create histogram for each station
-            hist, _ = np.histogram(station.get_Y(as_numpy=True), bins=bins, range=(total_min,total_max))
+            hist, bin_edges = np.histogram(station.get_Y(as_numpy=True), bins=bins, range=(total_min,total_max))
             # Add up histograms
             histogram = histogram + hist
         # Scale down to most frequent class=1
@@ -94,8 +95,11 @@ class PreProcessing(RunEnvironment):
         oversampling_rates = 1 / histogram
         oversampling_rates_capped = np.minimum(oversampling_rates, rates_cap)
         # Add to datastore
-        self.data_store.set('oversampling_rates', oversampling_rates, 'training')
-        self.data_store.set('oversampling_rates_capped', oversampling_rates_capped, 'training')
+        self.data_store.set('oversampling_rates', oversampling_rates, 'train')
+        self.data_store.set('oversampling_rates_capped', oversampling_rates_capped, 'train')
+        self.data_store.set('oversampling_bin_edges', bin_edges)
+        for station in data:
+            station.apply_oversampling(bin_edges, oversampling_rates_capped)
 
     def report_pre_processing(self):
         """Log some metrics on data and create latex report."""