From 85a74b588a564f0c75a5f37a56cbe0d5c04a013c Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 8 Jul 2021 13:57:59 +0200
Subject: [PATCH] histogram can handle branched inputs

---
 mlair/plotting/data_insight_plotting.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index c0e014fb..513f64f2 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -464,16 +464,18 @@ class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
         self.variables_dim = variables_dim
         self.time_dim = time_dim
         self.window_dim = window_dim
-        self.inputs, self.targets = self._get_inputs_targets(generators, self.variables_dim)
+        self.inputs, self.targets, number_of_branches = self._get_inputs_targets(generators, self.variables_dim)
         self.bins = {}
         self.interval_width = {}
         self.bin_edges = {}
 
         # input plots
-        self._calculate_hist(generators, self.inputs, input_data=True)
-        for subset in generators.keys():
-            self._plot(add_name="input", subset=subset)
-        self._plot_combined(add_name="input")
+        for branch_pos in range(number_of_branches):
+            self._calculate_hist(generators, self.inputs, input_data=True, branch_pos=branch_pos)
+            add_name = "input" if number_of_branches == 1 else f"input_branch_{branch_pos}"
+            for subset in generators.keys():
+                self._plot(add_name=add_name, subset=subset)
+            self._plot_combined(add_name=add_name)
 
         # target plots
         self._calculate_hist(generators, self.targets, input_data=False)
@@ -487,16 +489,17 @@ class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
         gen = gens[k][0]
         inputs = to_list(gen.get_X(as_numpy=False)[0].coords[dim].values.tolist())
         targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist())
-        return inputs, targets
+        n_branches = len(gen.get_X(as_numpy=False))
+        return inputs, targets, n_branches
 
-    def _calculate_hist(self, generators, variables, input_data=True):
+    def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0):
         n_bins = 100
         for set_type, generator in generators.items():
             tmp_bins = {}
             tmp_edges = {}
             end = {}
             start = {}
-            f = lambda x: x.get_X(as_numpy=False)[0] if input_data is True else x.get_Y(as_numpy=False)
+            f = lambda x: x.get_X(as_numpy=False)[branch_pos] if input_data is True else x.get_Y(as_numpy=False)
             for gen in generator:
                 w = min(abs(f(gen).coords[self.window_dim].values))
                 data = f(gen).sel({self.window_dim: w})
-- 
GitLab