diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 22a8ca4150c46ad1ce9ba0d005fa643e27906f53..8ad3e1e7ff583bd511d6311f2ab9de886f440fc9 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -215,32 +215,29 @@ class DefaultDataHandler(AbstractDataHandler):
                 raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
                                 f"{i} is type {type(i)}")
 
+        extremes_X, extremes_Y = None, None
         for extr_val in sorted(extreme_values):
             # check if some extreme values are already extracted
-            if (self._X_extreme is None) or (self._Y_extreme is None):
-                X = self._X
-                Y = self._Y
+            if (extremes_X is None) or (extremes_Y is None):
+                X, Y = self._X, self._Y
+                extremes_X, extremes_Y = X, Y
             else:  # one extr value iteration is done already: self.extremes_label is NOT None...
-                X = self._X_extreme
-                Y = self._Y_extreme
+                X, Y = self._X_extreme, self._Y_extreme
 
             # extract extremes based on occurrence in labels
             other_dims = remove_items(list(Y.dims), dim)
             if extremes_on_right_tail_only:
-                extreme_idx = (Y > extr_val).any(dim=other_dims)
+                extreme_idx = (extremes_Y > extr_val).any(dim=other_dims)
             else:
-                extreme_idx = xr.concat([(Y < -extr_val).any(dim=other_dims[0]),
-                                           (Y > extr_val).any(dim=other_dims[0])],
+                extreme_idx = xr.concat([(extremes_Y < -extr_val).any(dim=other_dims[0]),
+                                           (extremes_Y > extr_val).any(dim=other_dims[0])],
                                           dim=other_dims[0]).any(dim=other_dims[0])
 
             sel = extreme_idx[extreme_idx].coords[dim].values
-            extremes_X = list(map(lambda x: x.sel(**{dim: sel}), X))
+            extremes_X = list(map(lambda x: x.sel(**{dim: sel}), extremes_X))
             self._add_timedelta(extremes_X, dim, timedelta)
-            # extremes_X = list(map(lambda x: x.coords[dim].values + np.timedelta64(*timedelta), extremes_X))
-
-            extremes_Y = Y.sel(**{dim: extreme_idx})
-            #extremes_Y.coords[dim].values += np.timedelta64(*timedelta)
-            self._add_timedelta(extremes_Y, dim, timedelta)
+            extremes_Y = extremes_Y.sel(**{dim: extreme_idx})
+            self._add_timedelta([extremes_Y], dim, timedelta)
 
             self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
             self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index 335463454a3d3937cd93c739e63ab540f08ffd92..3bec759076a862d89be6c7495ef9abdebd0d4123 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -459,7 +459,7 @@ class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
     """
 
     def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", plot_name="histogram",
-                 variables_dim="variables", time_dim="datetime", window_dim="window"):
+                 variables_dim="variables", time_dim="datetime", window_dim="window", upsampling=False):
         super().__init__(plot_folder, plot_name)
         self.variables_dim = variables_dim
         self.time_dim = time_dim
@@ -468,6 +468,8 @@ class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
         self.bins = {}
         self.interval_width = {}
         self.bin_edges = {}
+        if upsampling is True:
+            self._handle_upsampling(generators)
 
         # input plots
         for branch_pos in range(number_of_branches):
@@ -483,6 +485,11 @@ class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
             self._plot(add_name="target", subset=subset)
         self._plot_combined(add_name="target")
 
+    @staticmethod
+    def _handle_upsampling(generators):
+        if "train" in generators:
+            generators.update({"train_upsampled": generators["train"]})
+
     @staticmethod
     def _get_inputs_targets(gens, dim):
         k = list(gens.keys())[0]
@@ -495,11 +502,15 @@ class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
     def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0):
         n_bins = 100
         for set_type, generator in generators.items():
+            upsampling = "upsampled" in set_type
             tmp_bins = {}
             tmp_edges = {}
             end = {}
             start = {}
-            f = lambda x: x.get_X(as_numpy=False)[branch_pos] if input_data is True else x.get_Y(as_numpy=False)
+            if input_data is True:
+                f = lambda x: x.get_X(as_numpy=False, upsampling=upsampling)[branch_pos]
+            else:
+                f = lambda x: x.get_Y(as_numpy=False, upsampling=upsampling)
             for gen in generator:
                 w = min(abs(f(gen).coords[self.window_dim].values))
                 data = f(gen).sel({self.window_dim: w})
@@ -536,6 +547,7 @@ class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
         bin_edges = self.bin_edges[subset]
         interval_width = self.interval_width[subset]
         colors = self.get_dataset_colors()
+        colors.update({"train_upsampled": colors.get("train_val", "#000000")})
         for var in bins.keys():
             fig, ax = plt.subplots()
             hist_var = bins[var]
@@ -555,6 +567,7 @@ class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
         pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
         variables = self.bins[list(self.bins.keys())[0]].keys()
         colors = self.get_dataset_colors()
+        colors.update({"train_upsampled": colors.get("train_val", "#000000")})
         for var in variables:
             fig, ax = plt.subplots()
             for subset in self.bins.keys():
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index f3909fde29b466af1bf64124ab1d57873ae70d18..edab0a3e8ad3ee6f7eeb50318f617d3b661255bd 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -432,8 +432,10 @@ class PostProcessing(RunEnvironment):
 
         try:
             if "PlotDataHistogram" in plot_list:
+                upsampling = self.data_store.get_default("upsampling", scope="train", default=False)
                 gens = {"train": self.train_data, "val": self.val_data, "test": self.test_data}
-                PlotDataHistogram(gens, plot_folder=self.plot_path, time_dim=time_dim, variables_dim=target_dim)
+                PlotDataHistogram(gens, plot_folder=self.plot_path, time_dim=time_dim, variables_dim=target_dim,
+                                  upsampling=upsampling)
         except Exception as e:
             logging.error(f"Could not create plot PlotDataHistogram due to the following error: {e}")