diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index c0e014fb18d1cd98b945d04e6407dd1cd2b281ca..513f64f2c174d94cb7230b141387c9a850d678cb 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -464,16 +464,18 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover self.variables_dim = variables_dim self.time_dim = time_dim self.window_dim = window_dim - self.inputs, self.targets = self._get_inputs_targets(generators, self.variables_dim) + self.inputs, self.targets, number_of_branches = self._get_inputs_targets(generators, self.variables_dim) self.bins = {} self.interval_width = {} self.bin_edges = {} # input plots - self._calculate_hist(generators, self.inputs, input_data=True) - for subset in generators.keys(): - self._plot(add_name="input", subset=subset) - self._plot_combined(add_name="input") + for branch_pos in range(number_of_branches): + self._calculate_hist(generators, self.inputs, input_data=True, branch_pos=branch_pos) + add_name = "input" if number_of_branches == 1 else f"input_branch_{branch_pos}" + for subset in generators.keys(): + self._plot(add_name=add_name, subset=subset) + self._plot_combined(add_name=add_name) # target plots self._calculate_hist(generators, self.targets, input_data=False) @@ -487,16 +489,17 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover gen = gens[k][0] inputs = to_list(gen.get_X(as_numpy=False)[0].coords[dim].values.tolist()) targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist()) - return inputs, targets + n_branches = len(gen.get_X(as_numpy=False)) + return inputs, targets, n_branches - def _calculate_hist(self, generators, variables, input_data=True): + def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0): n_bins = 100 for set_type, generator in generators.items(): tmp_bins = {} tmp_edges = {} end = {} start = {} - f = lambda x: x.get_X(as_numpy=False)[0] if input_data is True else x.get_Y(as_numpy=False) + f = lambda x: x.get_X(as_numpy=False)[branch_pos] if input_data is True else x.get_Y(as_numpy=False) for gen in generator: w = min(abs(f(gen).coords[self.window_dim].values)) data = f(gen).sel({self.window_dim: w})