From 85a74b588a564f0c75a5f37a56cbe0d5c04a013c Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Thu, 8 Jul 2021 13:57:59 +0200 Subject: [PATCH] histogram can handle branched inputs --- mlair/plotting/data_insight_plotting.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index c0e014fb..513f64f2 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -464,16 +464,18 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover self.variables_dim = variables_dim self.time_dim = time_dim self.window_dim = window_dim - self.inputs, self.targets = self._get_inputs_targets(generators, self.variables_dim) + self.inputs, self.targets, number_of_branches = self._get_inputs_targets(generators, self.variables_dim) self.bins = {} self.interval_width = {} self.bin_edges = {} # input plots - self._calculate_hist(generators, self.inputs, input_data=True) - for subset in generators.keys(): - self._plot(add_name="input", subset=subset) - self._plot_combined(add_name="input") + for branch_pos in range(number_of_branches): + self._calculate_hist(generators, self.inputs, input_data=True, branch_pos=branch_pos) + add_name = "input" if number_of_branches == 1 else f"input_branch_{branch_pos}" + for subset in generators.keys(): + self._plot(add_name=add_name, subset=subset) + self._plot_combined(add_name=add_name) # target plots self._calculate_hist(generators, self.targets, input_data=False) @@ -487,16 +489,17 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover gen = gens[k][0] inputs = to_list(gen.get_X(as_numpy=False)[0].coords[dim].values.tolist()) targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist()) - return inputs, targets + n_branches = len(gen.get_X(as_numpy=False)) + return inputs, targets, n_branches - def _calculate_hist(self, generators, variables, input_data=True): + def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0): n_bins = 100 for set_type, generator in generators.items(): tmp_bins = {} tmp_edges = {} end = {} start = {} - f = lambda x: x.get_X(as_numpy=False)[0] if input_data is True else x.get_Y(as_numpy=False) + f = lambda x: x.get_X(as_numpy=False)[branch_pos] if input_data is True else x.get_Y(as_numpy=False) for gen in generator: w = min(abs(f(gen).coords[self.window_dim].values)) data = f(gen).sel({self.window_dim: w}) -- GitLab