diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index f159e6fca1390b12891a511f30eeb1fbbb0672e9..fec4ef6ec5e5fb89c3fedd46fc1fd6fd845ff1ea 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -446,51 +446,90 @@ class PlotAvailabilityHistogram(AbstractPlotClass):  # pragma: no cover
 
 class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
 
-    def __init__(self, generator: Dict[str, DataCollection], plot_folder: str = ".", plot_name="histogram",
+    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", plot_name="histogram",
                  variables_dim="variables", time_dim="datetime", window_dim="window"):
         super().__init__(plot_folder, plot_name)
         self.variables_dim = variables_dim
         self.time_dim = time_dim
         self.window_dim = window_dim
-        self.inputs = to_list(generator[0].get_X(as_numpy=False)[0].coords[self.variables_dim].values.tolist())
-        self.targets = to_list(generator[0].get_Y(as_numpy=False).coords[self.variables_dim].values.tolist())
-
-        # normalized versions
-        self._calculate_hist(generator, self.inputs, input_data=True)
-        self._plot(add_name="input")
-        self._calculate_hist(generator, self.targets, input_data=False)
-        self._plot(add_name="target")
-
-    def _calculate_hist(self, generator, variables, input_data=True):
-        bins = {}
-        n_bins = 100
-        interval_width = None
-        bin_edges = None
-        f = lambda x: x.get_X(as_numpy=False)[0] if input_data is True else x.get_Y(as_numpy=False)
-        for gen in generator:
-            w = min(abs(f(gen).coords[self.window_dim].values))
-            data = f(gen).sel({self.window_dim: w})
-            res, interval_width, bin_edges = f_proc_hist(data, variables, n_bins, self.variables_dim)
-            for var in variables:
-                n_var = bins.get(var, np.zeros(n_bins))
-                n_var += res[var]
-                bins[var] = n_var
-        self.bins = bins
-        self.interval_width = interval_width
-        self.bin_edges = bin_edges
-
-    def _plot(self, add_name):
-        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}_{add_name}.pdf")
+        self.inputs, self.targets = self._get_inputs_targets(generators, self.variables_dim)
+        self.bins = {}
+
+        # input plots
+        self._calculate_hist(generators, self.inputs, input_data=True)
+        for subset in generators.keys():
+            self._plot(add_name="input", subset=subset)
+        self._plot_combined(add_name="input")
+
+        # target plots
+        self._calculate_hist(generators, self.targets, input_data=False)
+        for subset in generators.keys():
+            self._plot(add_name="target", subset=subset)
+        self._plot_combined(add_name="target")
+
+    @staticmethod
+    def _get_inputs_targets(gens, dim):
+        k = list(gens.keys())[0]
+        gen = gens[k][0]
+        inputs = to_list(gen.get_X(as_numpy=False)[0].coords[dim].values.tolist())
+        targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist())
+        return inputs, targets
+
+    def _calculate_hist(self, generators, variables, input_data=True):
+        for set_type, generator in generators.items():
+            bins = {}
+            n_bins = 100
+            interval_width = None
+            bin_edges = None
+            f = lambda x: x.get_X(as_numpy=False)[0] if input_data is True else x.get_Y(as_numpy=False)
+            for gen in generator:
+                w = min(abs(f(gen).coords[self.window_dim].values))
+                data = f(gen).sel({self.window_dim: w})
+                res, interval_width, bin_edges = f_proc_hist(data, variables, n_bins, self.variables_dim)
+                for var in variables:
+                    n_var = bins.get(var, np.zeros(n_bins))
+                    n_var += res[var]
+                    bins[var] = n_var
+            self.bins[set_type] = bins
+            self.interval_width = interval_width
+            self.bin_edges = bin_edges
+
+    def _plot(self, add_name, subset):
+        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}_{subset}_{add_name}.pdf")
         pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
-        for var in self.bins.keys():
+        bins = self.bins[subset]
+        colors = self.get_dataset_colors()
+        for var in bins.keys():
             fig, ax = plt.subplots()
-            hist_var = self.bins[var]
+            hist_var = bins[var]
             n_var = sum(hist_var)
             weights = hist_var / (self.interval_width * n_var)
-            ax.hist(self.bin_edges[:-1], self.bin_edges, weights=weights)
+            ax.hist(self.bin_edges[:-1], self.bin_edges, weights=weights, color=colors[subset])
+            ax.set_ylabel("probability density")
+            ax.set_xlabel(f"values ({subset})")
+            ax.set_title(f"Histogram ({var}, n={int(n_var)})")
+            pdf_pages.savefig()
+        # close all open figures / plots
+        pdf_pages.close()
+        plt.close('all')
+
+    def _plot_combined(self, add_name):
+        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}_{add_name}.pdf")
+        pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
+        variables = self.bins[list(self.bins.keys())[0]].keys()
+        colors = self.get_dataset_colors()
+        for var in variables:
+            fig, ax = plt.subplots()
+            for subset in self.bins.keys():
+                hist_var = self.bins[subset][var]
+                n_var = sum(hist_var)
+                weights = hist_var / (self.interval_width * n_var)
+                ax.plot(self.bin_edges[:-1] + 0.5 * self.interval_width, weights, label=f"{subset}",
+                        c=colors[subset])
             ax.set_ylabel("probability density")
             ax.set_xlabel(f"{var}")
-            ax.set_title(f"Histogram (n={int(n_var)})")
+            ax.legend(loc="upper right")
+            ax.set_title(f"Histogram")
             pdf_pages.savefig()
         # close all open figures / plots
         pdf_pages.close()
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index fafcff5e13930930f298c99750990642c22cded8..89a6f205d03892c57c55a66399a43c9ba2987b42 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -400,8 +400,8 @@ class PostProcessing(RunEnvironment):
 
         try:
             if "PlotDataHistogram" in plot_list:
-                PlotDataHistogram(self.train_data, plot_folder=self.plot_path, time_dim=time_dim,
-                                  variables_dim=target_dim)
+                gens = {"train": self.train_data, "val": self.val_data, "test": self.test_data}
+                PlotDataHistogram(gens, plot_folder=self.plot_path, time_dim=time_dim, variables_dim=target_dim)
         except Exception as e:
             logging.error(f"Could not create plot PlotDataHistogram due to the following error: {e}")