Skip to content
Snippets Groups Projects
Commit 85a74b58 authored by leufen1's avatar leufen1
Browse files

histogram can handle branched inputs

parent 4eee713d
Branches
Tags 0.7.0
5 merge requests!319add all changes of dev into release v1.4.0 branch,!318Resolve "release v1.4.0",!317enabled window_lead_time=1,!295Resolve "data handler FIR filter",!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #72588 passed
...@@ -464,16 +464,18 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover ...@@ -464,16 +464,18 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover
self.variables_dim = variables_dim self.variables_dim = variables_dim
self.time_dim = time_dim self.time_dim = time_dim
self.window_dim = window_dim self.window_dim = window_dim
self.inputs, self.targets = self._get_inputs_targets(generators, self.variables_dim) self.inputs, self.targets, number_of_branches = self._get_inputs_targets(generators, self.variables_dim)
self.bins = {} self.bins = {}
self.interval_width = {} self.interval_width = {}
self.bin_edges = {} self.bin_edges = {}
# input plots # input plots
self._calculate_hist(generators, self.inputs, input_data=True) for branch_pos in range(number_of_branches):
self._calculate_hist(generators, self.inputs, input_data=True, branch_pos=branch_pos)
add_name = "input" if number_of_branches == 1 else f"input_branch_{branch_pos}"
for subset in generators.keys(): for subset in generators.keys():
self._plot(add_name="input", subset=subset) self._plot(add_name=add_name, subset=subset)
self._plot_combined(add_name="input") self._plot_combined(add_name=add_name)
# target plots # target plots
self._calculate_hist(generators, self.targets, input_data=False) self._calculate_hist(generators, self.targets, input_data=False)
...@@ -487,16 +489,17 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover ...@@ -487,16 +489,17 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover
gen = gens[k][0] gen = gens[k][0]
inputs = to_list(gen.get_X(as_numpy=False)[0].coords[dim].values.tolist()) inputs = to_list(gen.get_X(as_numpy=False)[0].coords[dim].values.tolist())
targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist()) targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist())
return inputs, targets n_branches = len(gen.get_X(as_numpy=False))
return inputs, targets, n_branches
def _calculate_hist(self, generators, variables, input_data=True): def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0):
n_bins = 100 n_bins = 100
for set_type, generator in generators.items(): for set_type, generator in generators.items():
tmp_bins = {} tmp_bins = {}
tmp_edges = {} tmp_edges = {}
end = {} end = {}
start = {} start = {}
f = lambda x: x.get_X(as_numpy=False)[0] if input_data is True else x.get_Y(as_numpy=False) f = lambda x: x.get_X(as_numpy=False)[branch_pos] if input_data is True else x.get_Y(as_numpy=False)
for gen in generator: for gen in generator:
w = min(abs(f(gen).coords[self.window_dim].values)) w = min(abs(f(gen).coords[self.window_dim].values))
data = f(gen).sel({self.window_dim: w}) data = f(gen).sel({self.window_dim: w})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment