diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index c6015ed78956fca539a951d6b0072ca8924a7a9c..e8c6044280799ded080ab4bff3627aeb9ffde2db 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -57,8 +57,8 @@ class Distributor(keras.utils.Sequence):
                 if self.upsampling:
                     try:
                         s = self.generator.get_data_generator(k)
-                        x_total = np.concatenate([x_total, np.copy(s.extremes_history.copy())], axis=0)
-                        y_total = np.concatenate([y_total, np.copy(s.extremes_labels.copy())], axis=0)
+                        x_total = np.concatenate([x_total, np.copy(s.get_extremes_history())], axis=0)
+                        y_total = np.concatenate([y_total, np.copy(s.get_extremes_label())], axis=0)
                     except AttributeError:  # no extremes history / labels available, copy will fail
                         pass
                 # get number of mini batches
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index 4751bd35d0852a7044260e5b08a1e203234e1e1b..490d661195aa017113f705da7b2e1e896e55fdc1 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -61,7 +61,7 @@ class DataPrep(object):
         self.label = None
         self.observation = None
         self.extremes_history = None
-        self.extremes_labels = None
+        self.extremes_label = None
         self.kwargs = kwargs
         self.data = None
         self.meta = None
@@ -424,6 +424,12 @@ class DataPrep(object):
     def get_transposed_label(self):
         return self.label.squeeze("Stations").transpose("datetime", "window").copy()
 
+    def get_extremes_history(self):  # copy with dims (datetime, window, Stations, variables); AttributeError while None
+        return self.extremes_history.transpose("datetime", "window", "Stations", "variables").copy()
+
+    def get_extremes_label(self):  # copy with Stations squeezed out, dims (datetime, window); AttributeError while None
+        return self.extremes_label.squeeze("Stations").transpose("datetime", "window").copy()
+
     def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
                           timedelta: Tuple[int, str] = (1, 'm')):
         """
@@ -450,7 +456,7 @@ class DataPrep(object):
 
         for extr_val in sorted(extreme_values):
             # check if some extreme values are already extracted
-            if (self.extremes_labels is None) or (self.extremes_history is None):
+            if (self.extremes_label is None) or (self.extremes_history is None):
                 # extract extremes based on occurance in labels
                 if extremes_on_right_tail_only:
                     extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1,)
@@ -462,23 +468,21 @@ class DataPrep(object):
                 extremes_history = self.history[..., extreme_label_idx, :]
                 extremes_label.datetime.values += np.timedelta64(*timedelta)
                 extremes_history.datetime.values += np.timedelta64(*timedelta)
-                self.extremes_labels = extremes_label.squeeze('Stations').transpose('datetime', 'window')
-                self.extremes_history = extremes_history.transpose('datetime', 'window', 'Stations', 'variables')
-            else:  # one extr value iteration is done already: self.extremes_labels is NOT None...
+                self.extremes_label = extremes_label  # stored untransposed; get_extremes_label() squeezes/transposes
+                self.extremes_history = extremes_history  # stored untransposed; get_extremes_history() transposes
+            else:  # one extr value iteration is done already: self.extremes_label is NOT None...
                 if extremes_on_right_tail_only:
-                    extreme_label_idx = (self.extremes_labels > extr_val).any(axis=1).values.reshape(-1,)
+                    extreme_label_idx = (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, )
                 else:
-                    extreme_label_idx = np.concatenate(((self.extremes_labels < -extr_val).any(axis=1
-                                                                                                 ).values.reshape(-1, 1),
-                                                        (self.extremes_labels > extr_val).any(axis=1
-                                                                                                 ).values.reshape(-1, 1)
+                    extreme_label_idx = np.concatenate(((self.extremes_label < -extr_val).any(axis=0).values.reshape(-1, 1),
+                                                        (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, 1)
                                                         ), axis=1).any(axis=1)
                 # check on existing extracted extremes to minimise computational costs for comparison
-                extremes_label = self.extremes_labels[extreme_label_idx, ...]
-                extremes_history = self.extremes_history[extreme_label_idx, ...]
+                extremes_label = self.extremes_label[..., extreme_label_idx]
+                extremes_history = self.extremes_history[..., extreme_label_idx, :]
                 extremes_label.datetime.values += np.timedelta64(*timedelta)
                 extremes_history.datetime.values += np.timedelta64(*timedelta)
-                self.extremes_labels = xr.concat([self.extremes_labels, extremes_label], dim='datetime')
+                self.extremes_label = xr.concat([self.extremes_label, extremes_label], dim='datetime')
                 self.extremes_history = xr.concat([self.extremes_history, extremes_history], dim='datetime')
 
 
diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
index dd0ca99dde0ed38bfbbd392f65eeddca78bbb075..15344fd808a4aa9ee5774ad8ba647bf5ce06d015 100644
--- a/test/test_data_handling/test_data_distributor.py
+++ b/test/test_data_handling/test_data_distributor.py
@@ -98,3 +98,21 @@ class TestDistributor:
         assert np.testing.assert_equal(x, x_perm) is None
         assert np.testing.assert_equal(y, y_perm) is None
 
+    def test_distribute_on_batches_upsampling_no_extremes_given(self, generator, model):
+        """Without extreme values, upsampling must leave the number of mini batches unchanged."""
+        d = Distributor(generator, model, upsampling=True)
+        gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0]
+        # collect sizes first so an empty iteration is counted as 0 batches rather than 1
+        sizes = [e[0].shape[0] for e in d.distribute_on_batches(fit_call=False)]
+        assert all(size <= d.batch_size for size in sizes)
+        assert len(sizes) == math.ceil(gen_len / d.batch_size)
+
+    def test_distribute_on_batches_upsampling(self, generator, model):
+        """With extreme values set, the extracted extremes add to the samples spread over mini batches."""
+        generator.extreme_values = [1]
+        d = Distributor(generator, model, upsampling=True)
+        gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0]
+        extr_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_extremes_label().shape[0]
+        sizes = [e[0].shape[0] for e in d.distribute_on_batches(fit_call=False)]  # empty run counts as 0 batches
+        assert all(size <= d.batch_size for size in sizes)
+        assert len(sizes) == math.ceil((gen_len + extr_len) / d.batch_size)
diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py
index 9bf11154609afa9ada2b488455f7a341a41d21ae..939f93cc9ee01c76a282e755aca14b39c6fc4ac9 100644
--- a/test/test_data_handling/test_data_generator.py
+++ b/test/test_data_handling/test_data_generator.py
@@ -238,6 +238,20 @@ class TestDataGenerator:
         assert data._transform_method == "standardise"
         assert data.mean is not None
 
+    def test_get_data_generator_extremes(self, gen_with_transformation):
+        """Extremes are extracted on load; right-tail-only never yields more samples than both tails."""
+        gen = gen_with_transformation
+        gen.kwargs = {"statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}}
+        gen.extreme_values = [1.]
+        prep = gen.get_data_generator("DEBW107", load_local_tmp_storage=False, save_local_tmp_storage=False)
+        assert prep.extremes_label is not None and prep.extremes_history is not None
+        assert prep.extremes_label.shape[:2] == prep.label.shape[:2]
+        assert prep.extremes_label.shape[2] <= prep.label.shape[2]
+        n_both_tails = prep.extremes_label.shape[2]
+        gen.kwargs["extremes_on_right_tail_only"] = True
+        prep = gen.get_data_generator("DEBW107", load_local_tmp_storage=False, save_local_tmp_storage=False)
+        assert prep.extremes_label.shape[2] <= n_both_tails
+
     def test_save_pickle_data(self, gen):
         file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}_2010_2014_.pickle")
         if os.path.exists(file):
diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py
index f202a6efc09502c635ad548471b588e259ebc7e1..71f3a1d6a0a675a155b517901aef1f3c359b104b 100644
--- a/test/test_data_handling/test_data_preparation.py
+++ b/test/test_data_handling/test_data_preparation.py
@@ -410,7 +410,7 @@ class TestDataPrep:
         data.make_labels("variables", "o3", "datetime", 2)
         orig = data.label
         data.multiply_extremes(1)
-        upsampled = data.extremes_labels
+        upsampled = data.extremes_label
         assert (upsampled > 1).sum() == (orig > 1).sum()
         assert (upsampled < -1).sum() == (orig < -1).sum()
 
@@ -420,7 +420,7 @@ class TestDataPrep:
         data.make_labels("variables", "o3", "datetime", 2)
         orig = data.label
         data.multiply_extremes([1, 1.5, 2, 3])
-        upsampled = data.extremes_labels
+        upsampled = data.extremes_label
         def f(d, op, n):
             return op(d, n).any(dim="window").sum()
         assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 1.5), f(orig, gt, 2) * 2, f(orig, gt, 3) * 4])
@@ -438,9 +438,29 @@ class TestDataPrep:
         data.make_labels("variables", "o3", "datetime", 2)
         orig = data.label
         data.multiply_extremes([1, 2], extremes_on_right_tail_only=True)
-        upsampled = data.extremes_labels
+        upsampled = data.extremes_label
         def f(d, op, n):
             return op(d, n).any(dim="window").sum()
         assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 2)])
-        assert len(upsampled) == sum([f(orig, gt, 1), f(orig, gt, 2)])
+        assert upsampled.shape[2] == sum([f(orig, gt, 1), f(orig, gt, 2)])
         assert f(upsampled, lt, -1) == 0
+
+    def test_get_extremes_history(self, data):
+        data.transform("datetime")
+        data.make_history_window("variables", 3, "datetime")
+        data.make_labels("variables", "o3", "datetime", 2)
+        data.make_observation("variables", "o3", "datetime")
+        data.remove_nan("datetime")
+        data.multiply_extremes([1, 2], extremes_on_right_tail_only=True)
+        expected = data.extremes_history.transpose("datetime", "window", "Stations", "variables")
+        assert (data.get_extremes_history() == expected).all()  # getter only reorders dims and copies
+
+    def test_get_extremes_label(self, data):
+        data.transform("datetime")
+        data.make_history_window("variables", 3, "datetime")
+        data.make_labels("variables", "o3", "datetime", 2)
+        data.make_observation("variables", "o3", "datetime")
+        data.remove_nan("datetime")
+        data.multiply_extremes([1, 2], extremes_on_right_tail_only=True)
+        expected = data.extremes_label.squeeze("Stations").transpose("datetime", "window")
+        assert (data.get_extremes_label() == expected).all()  # getter only squeezes, reorders dims and copies