diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index 3440321a8708ee905c5c991c4c8c5a1870719e63..aff3b4c73fa90ac10325b025eb39116f95af9a58 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -20,70 +20,50 @@ from mlair.helpers import TimeTrackingWrapper, to_list from mlair.plotting.abstract_plot_class import AbstractPlotClass @TimeTrackingWrapper -class PlotOversamplingHistogram(AbstractPlotClass): +class PlotOversampling(AbstractPlotClass): - def __init__(self, Y, Y_extreme, bin_edges, plot_folder: str = ".", - plot_name="oversampling_histogram"): + def __init__(self, Y, Y_extreme, bin_edges, oversampling_rates, plot_folder: str = ".", + plot_names=["oversampling_histogram", "oversampling_density_histogram", "oversampling_rates", + "oversampling_rates_deviation"]): - super().__init__(plot_folder, plot_name) - self._plot(Y, Y_extreme, bin_edges) + super().__init__(plot_folder, plot_names[0]) + Y_hist, Y_extreme_hist = self._plot_oversampling_histogram(Y, Y_extreme, bin_edges) + real_oversampling = Y_extreme_hist / Y_hist + self._save() + self.plot_name = plot_names[1] + self._plot_oversampling_density_histogram(Y, Y_extreme, bin_edges) + self._save() + self.plot_name = plot_names[2] + self._plot_oversampling_rates(oversampling_rates, real_oversampling) + self._save() + self.plot_name = plot_names[3] + self._plot_oversampling_rates_deviation(oversampling_rates, real_oversampling) self._save() - def _plot(self, Y, Y_extreme, bin_edges): + def _plot_oversampling_histogram(self, Y, Y_extreme, bin_edges): fig, ax = plt.subplots(1, 1) - Y.plot.hist(bins=bin_edges, histtype="step", label="Before", ax=ax)[0] - Y_extreme.plot.hist(bins=bin_edges, histtype="step", label="After", ax=ax)[0] + Y_hist = Y.plot.hist(bins=bin_edges, histtype="step", label="Before", ax=ax)[0] + Y_extreme_hist = Y_extreme.plot.hist(bins=bin_edges, histtype="step", label="After", ax=ax)[0] ax.set_title(f"Histogram before-after oversampling") ax.legend() + return Y_hist, Y_extreme_hist - -@TimeTrackingWrapper -class PlotOversamplingDensityHistogram(AbstractPlotClass): - - def __init__(self, Y, Y_extreme, bin_edges, plot_folder: str = ".", - plot_name="oversampling_density_histogram"): - super().__init__(plot_folder, plot_name) - self._plot(Y, Y_extreme, bin_edges) - self._save() - - def _plot(self, Y, Y_extreme, bin_edges): + def _plot_oversampling_density_histogram(self, Y, Y_extreme, bin_edges): fig, ax = plt.subplots(1, 1) Y.plot.hist(bins=bin_edges, density=True, histtype="step", label="Before", ax=ax)[0] Y_extreme.plot.hist(bins=bin_edges, density=True, histtype="step", label="After", ax=ax)[0] ax.set_title(f"Density Histogram before-after oversampling") ax.legend() - -@TimeTrackingWrapper -class PlotOversamplingRates(AbstractPlotClass): - - def __init__(self, Y, Y_extreme, bin_edges, oversampling_rates, Y_hist, Y_extreme_hist, plot_folder: str = ".", - plot_name="oversampling_rates"): - super().__init__(plot_folder, plot_name) - self._plot(Y, Y_extreme, bin_edges, oversampling_rates, Y_hist, Y_extreme_hist) - self._save() - - def _plot(self, Y, Y_extreme, bin_edges, oversampling_rates, Y_hist, Y_extreme_hist): + def _plot_oversampling_rates(self, oversampling_rates, real_oversampling): fig, ax = plt.subplots(1, 1) - real_oversampling = Y_extreme_hist[0] / Y_hist[0] ax.plot(range(len(real_oversampling)), oversampling_rates, label="Desired oversampling_rates") ax.plot(range(len(real_oversampling)), real_oversampling, label="Actual Oversampling Rates") ax.set_title(f"Oversampling rates") ax.legend() - -@TimeTrackingWrapper -class PlotOversamplingRatesDeviation(AbstractPlotClass): - - def __init__(self, Y, Y_extreme, bin_edges, oversampling_rates, Y_hist, Y_extreme_hist, plot_folder: str = ".", - plot_name="oversampling_rates_deviation"): - super().__init__(plot_folder, plot_name) - self._plot(Y, Y_extreme, bin_edges, oversampling_rates, Y_hist, Y_extreme_hist) - self._save() - - def _plot(self, Y, Y_extreme, bin_edges, oversampling_rates, Y_hist, Y_extreme_hist): + def _plot_oversampling_rates_deviation(self, oversampling_rates, real_oversampling): fig, ax = plt.subplots(1, 1) - real_oversampling = Y_extreme_hist[0] / Y_hist[0] ax.plot(range(len(real_oversampling)), real_oversampling / oversampling_rates, label="Actual/Desired Rate") ax.set_title(f"Deviation from desired oversampling rates") diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 742acf2f772abded667bff3a54b17064f5901c4b..74e786a6b994c4082c70949621b8584e8bf0018e 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -22,8 +22,7 @@ from mlair.model_modules import AbstractModelClass from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotConditionalQuantiles, PlotSeparationOfScales from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \ - PlotPeriodogram, PlotDataHistogram, PlotOversamplingHistogram, PlotOversamplingDensityHistogram, \ - PlotOversamplingRates, PlotOversamplingRatesDeviation + PlotPeriodogram, PlotDataHistogram, PlotOversampling from mlair.run_modules.run_environment import RunEnvironment @@ -313,14 +312,7 @@ class PostProcessing(RunEnvironment): oversampling_rates = self.data_store.get('oversampling_rates_capped','train') Y = self.data_store.get('Oversampling_Y') Y_extreme = self.data_store.get('Oversampling_Y_extreme') - Y_hist = Y.plot.hist(bins=bin_edges, histtype="step") - Y_extreme_hist = Y_extreme.plot.hist(bins=bin_edges, histtype="step") - PlotOversamplingHistogram(Y, Y_extreme, bin_edges, plot_folder=self.plot_path) - PlotOversamplingDensityHistogram(Y, Y_extreme, bin_edges, plot_folder=self.plot_path) - PlotOversamplingRates(Y, Y_extreme, bin_edges, oversampling_rates, Y_hist, Y_extreme_hist, - plot_folder=self.plot_path) - PlotOversamplingRatesDeviation(Y, Y_extreme, bin_edges, oversampling_rates, Y_hist, - Y_extreme_hist, plot_folder=self.plot_path) + PlotOversampling(Y, Y_extreme, bin_edges, oversampling_rates, plot_folder=self.plot_path) except Exception as e: logging.error(f"Could not create plot OversamplingPlots due to the following error: {e}")