diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index 0a3b28cea19d2bbfd97fef470c352248c169fced..29693b3f1c5832a460925819da9cc7aeef66a54e 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -23,7 +23,7 @@ from mlair.plotting.abstract_plot_class import AbstractPlotClass @TimeTrackingWrapper class PlotOversampling(AbstractPlotClass): - def __init__(self, data, bin_edges, oversampling_rates, plot_folder: str = ".", + def __init__(self, data, bin_edges, bin_edges_retransformed, oversampling_rates, plot_folder: str = ".", plot_names=["oversampling_histogram", "oversampling_density_histogram", "oversampling_rates", "oversampling_rates_deviation"]): @@ -31,10 +31,10 @@ class PlotOversampling(AbstractPlotClass): Y_hist, Y_extreme_hist, Y_hist_dens, Y_extreme_hist_dens = self._calculate_hist(data, bin_edges) real_oversampling = Y_extreme_hist / Y_hist - self._plot_oversampling_histogram(Y_hist, Y_extreme_hist, bin_edges) + self._plot_oversampling_histogram(Y_hist, Y_extreme_hist, bin_edges_retransformed) self._save() self.plot_name = plot_names[1] - self._plot_oversampling_histogram(Y_hist_dens, Y_extreme_hist_dens, bin_edges) + self._plot_oversampling_histogram(Y_hist_dens, Y_extreme_hist_dens, bin_edges_retransformed) self._save() self.plot_name = plot_names[2] self._plot_oversampling_rates(oversampling_rates, real_oversampling) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index aef5c8e079c849ae1e0ec8ec14e0a6df46f772ad..a19ddbe1dd09fa6da833d4c448f113c1bd02b44c 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -356,8 +356,9 @@ class PostProcessing(RunEnvironment): if (self.data_store.get('oversampling_method')=='bin_oversampling') and ( "PlotOversampling" in plot_list): bin_edges = self.data_store.get('oversampling_bin_edges') + bin_edges_retransformed = self.data_store.get('oversampling_bin_edges_retransformed') oversampling_rates = self.data_store.get('oversampling_rates_capped', 'train') - PlotOversampling(self.train_data, bin_edges, oversampling_rates, plot_folder=self.plot_path) + PlotOversampling(self.train_data, bin_edges, bin_edges_retransformed, oversampling_rates, plot_folder=self.plot_path) except Exception as e: logging.error(f"Could not create plot OversamplingPlots due to the following error: {e}") diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 3354e78c0c9ee85dad71f15a7a0171248913c0b7..ef6f32552c5bc20107755d1f0fa5eff0f4171441 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -99,10 +99,15 @@ class PreProcessing(RunEnvironment): # Get Oversampling rates (with and without cap) oversampling_rates = 1 / histogram oversampling_rates_capped = np.minimum(oversampling_rates, rates_cap) + # Get transformer variables + o3_mean = self.data_store.get("transformation")[0]["o3"]["mean"].values + o3_std = self.data_store.get("transformation")[0]["o3"]["std"].values + bin_edges_retransformed = np.floor(bin_edges*o3_std+o3_mean) # Add to datastore self.data_store.set('oversampling_rates', oversampling_rates, 'train') self.data_store.set('oversampling_rates_capped', oversampling_rates_capped, 'train') self.data_store.set('oversampling_bin_edges', bin_edges) + self.data_store.set('oversampling_bin_edges_retransformed', bin_edges_retransformed) #Y = None #Y_extreme = None for station in data: