diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index 335463454a3d3937cd93c739e63ab540f08ffd92..3bec759076a862d89be6c7495ef9abdebd0d4123 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -459,7 +459,7 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover
     """
 
     def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", plot_name="histogram",
-                 variables_dim="variables", time_dim="datetime", window_dim="window"):
+                 variables_dim="variables", time_dim="datetime", window_dim="window", upsampling=False):
         super().__init__(plot_folder, plot_name)
         self.variables_dim = variables_dim
         self.time_dim = time_dim
@@ -468,6 +468,8 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover
         self.bins = {}
         self.interval_width = {}
         self.bin_edges = {}
+        if upsampling is True:
+            self._handle_upsampling(generators)
 
         # input plots
         for branch_pos in range(number_of_branches):
@@ -483,6 +485,11 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover
             self._plot(add_name="target", subset=subset)
         self._plot_combined(add_name="target")
 
+    @staticmethod
+    def _handle_upsampling(generators):
+        if "train" in generators:
+            generators.update({"train_upsampled": generators["train"]})
+
     @staticmethod
     def _get_inputs_targets(gens, dim):
         k = list(gens.keys())[0]
@@ -495,11 +502,15 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover
     def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0):
         n_bins = 100
         for set_type, generator in generators.items():
+            upsampling = "upsampled" in set_type
             tmp_bins = {}
             tmp_edges = {}
             end = {}
             start = {}
-            f = lambda x: x.get_X(as_numpy=False)[branch_pos] if input_data is True else x.get_Y(as_numpy=False)
+            if input_data is True:
+                f = lambda x: x.get_X(as_numpy=False, upsampling=upsampling)[branch_pos]
+            else:
+                f = lambda x: x.get_Y(as_numpy=False, upsampling=upsampling)
             for gen in generator:
                 w = min(abs(f(gen).coords[self.window_dim].values))
                 data = f(gen).sel({self.window_dim: w})
@@ -536,6 +547,7 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover
         bin_edges = self.bin_edges[subset]
         interval_width = self.interval_width[subset]
         colors = self.get_dataset_colors()
+        colors.update({"train_upsampled": colors.get("train_val", "#000000")})
         for var in bins.keys():
             fig, ax = plt.subplots()
             hist_var = bins[var]
@@ -555,6 +567,7 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover
         pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
         variables = self.bins[list(self.bins.keys())[0]].keys()
         colors = self.get_dataset_colors()
+        colors.update({"train_upsampled": colors.get("train_val", "#000000")})
         for var in variables:
             fig, ax = plt.subplots()
             for subset in self.bins.keys():
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index f3909fde29b466af1bf64124ab1d57873ae70d18..edab0a3e8ad3ee6f7eeb50318f617d3b661255bd 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -432,8 +432,10 @@ class PostProcessing(RunEnvironment):
 
         try:
             if "PlotDataHistogram" in plot_list:
+                upsampling = self.data_store.get_default("upsampling", scope="train", default=False)
                 gens = {"train": self.train_data, "val": self.val_data, "test": self.test_data}
-                PlotDataHistogram(gens, plot_folder=self.plot_path, time_dim=time_dim, variables_dim=target_dim)
+                PlotDataHistogram(gens, plot_folder=self.plot_path, time_dim=time_dim, variables_dim=target_dim,
+                                  upsampling=upsampling)
         except Exception as e:
             logging.error(f"Could not create plot PlotDataHistogram due to the following error: {e}")
 