diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index fc4f7f0910c086bb5bb7f7802e5332404f7b5359..31b58a56375ea26a857ee132c2170680bab4e55a 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -48,7 +48,8 @@ DEFAULT_CREATE_NEW_BOOTSTRAPS = False
 DEFAULT_NUMBER_OF_BOOTSTRAPS = 20
 DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
                      "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles",
-                     "PlotAvailability", "PlotAvailabilityHistogram", "PlotDataHistogram", "PlotOversampling"]
+                     "PlotAvailability", "PlotAvailabilityHistogram", "PlotDataHistogram", "PlotOversampling",
+                     "PlotOversamplingContingency"]
 DEFAULT_SAMPLING = "daily"
 DEFAULT_DATA_ORIGIN = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA", "press": "REA", "relhum": "REA",
                        "temp": "REA", "totprecip": "REA", "u": "REA", "v": "REA", "no": "", "no2": "", "o3": "",
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 491aa52e0a9fe0010f77cde315d1f9b7ddb76dfb..3bef0c305036f290ceadbffa2f27a8085c174142 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -28,6 +28,73 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
 # matplotlib.use("TkAgg")
 # import matplotlib.pyplot as plt
 
+
+@TimeTrackingWrapper
+class PlotOversamplingContingency(AbstractPlotClass):
+    """
+    Plot contingency scores of predictions against labels over a sweep of integer thresholds.
+
+    For every threshold both series are split into "at or above" and "below" the threshold. The resulting 2x2
+    contingency table yields the threat score, the hit rate and the false alarm rate. Each score is stored as
+    an individual plot and all three scores are stored together in a fourth plot.
+
+    :param predictions: iterable of predicted values (assumed to yield scalars aligned with labels, TODO confirm)
+    :param labels: iterable of observed values; thresholds run from floor(min(labels)) to floor(max(labels))
+    :param plot_folder: path to save the plots (default: current directory)
+    :param plot_names: names of the four plots, ordered threat score, hit rate, false alarm rate, all scores
+    """
+
+    def __init__(self, predictions, labels, plot_folder: str = ".",
+                 plot_names=("oversampling_threat_score", "oversampling_hit_rate", "oversampling_false_alarm_rate",
+                             "oversampling_all_scores")):
+        import math  # local import: only needed to derive the threshold range
+
+        super().__init__(plot_folder, plot_names[0])
+        # None is treated like "no samples", the plots are then created without any curve
+        no_data = predictions is None or labels is None
+        pairs = [] if no_data else list(zip(predictions, labels))
+        if pairs:
+            observed = [label for _, label in pairs]
+            # the upper bound is inclusive: the largest label still counts as "at or above" that threshold
+            thresholds = range(math.floor(min(observed)), math.floor(max(observed)) + 1)
+        else:
+            thresholds = range(0)
+        ts, h, f = self._contingency_scores(pairs, thresholds)
+        # each score gets its own file, so it is stored before the next curve is drawn
+        for name, score in zip(plot_names, (ts, h, f)):
+            self.plot_name = name
+            plt.plot(thresholds, score)
+            self._save()
+        self.plot_name = plot_names[3]
+        for score in (ts, h, f):
+            plt.plot(thresholds, score)
+        self._save()
+
+    @staticmethod
+    def _ratio(numerator, denominator):
+        """Return numerator / denominator, or NaN if the denominator is zero (the score is undefined then)."""
+        return numerator / denominator if denominator > 0 else float("nan")
+
+    def _contingency_scores(self, pairs, thresholds):
+        """Return the lists of threat score, hit rate and false alarm rate with one entry per threshold."""
+        ts, h, f = [], [], []
+        for threshold in thresholds:
+            true_above = false_above = false_below = true_below = 0
+            for prediction, label in pairs:
+                if prediction >= threshold:
+                    if label >= threshold:
+                        true_above += 1
+                    else:
+                        false_above += 1
+                elif label >= threshold:
+                    false_below += 1
+                else:
+                    true_below += 1
+            ts.append(self._ratio(true_above, true_above + false_above + false_below))
+            h.append(self._ratio(true_above, true_above + false_below))
+            f.append(self._ratio(false_above, false_above + true_below))
+        return ts, h, f
+
 
 @TimeTrackingWrapper
 class PlotMonthlySummary(AbstractPlotClass):
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 74e786a6b994c4082c70949621b8584e8bf0018e..6acfcfbb9a16a6a5af37afb6b3c56807f8f2ff5c 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -20,7 +20,8 @@ from mlair.helpers import TimeTracking, statistics, extract_value, remove_items,
 from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
 from mlair.model_modules import AbstractModelClass
 from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \
-    PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotConditionalQuantiles, PlotSeparationOfScales
+    PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotConditionalQuantiles, \
+    PlotSeparationOfScales, PlotOversamplingContingency
 from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \
     PlotPeriodogram, PlotDataHistogram, PlotOversampling
 from mlair.run_modules.run_environment import RunEnvironment
@@ -305,6 +306,17 @@ class PostProcessing(RunEnvironment):
         target_dim = self.data_store.get("target_dim")
         iter_dim = self.data_store.get("iter_dim")
 
+        try:
+            if (self.data_store.get('oversampling_method') == 'bin_oversampling') and (
+                    "PlotOversamplingContingency" in plot_list):
+                # TODO: pass the real forecasts and observations once they are available here; with None the
+                #  plot class has no samples to score and only writes empty plots.
+                predictions = None
+                labels = None
+                PlotOversamplingContingency(predictions, labels, plot_folder=self.plot_path)
+        except Exception as e:
+            logging.error(f"Could not create plot PlotOversamplingContingency due to the following error: {e}")
+
         try:
             if (self.data_store.get('oversampling_method')=='bin_oversampling') and (
                     "PlotOversampling" in plot_list):
diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py
index 0f098dc2e57daaad6b475175f057f01a690551f0..922de3599dc7dc40717e0aeb8c7b8158ad21da38 100644
--- a/test/test_configuration/test_defaults.py
+++ b/test/test_configuration/test_defaults.py
@@ -68,4 +68,4 @@ class TestAllDefaults:
         assert DEFAULT_PLOT_LIST == ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore",
                                      "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore",
                                      "PlotConditionalQuantiles", "PlotAvailability", "PlotAvailabilityHistogram",
-                                     "PlotDataHistogram","PlotOversampling"]
+                                     "PlotDataHistogram", "PlotOversampling", "PlotOversamplingContingency"]