From f73acaa44dd0cdca0d487a2b642092d7dcc85e06 Mon Sep 17 00:00:00 2001 From: "v.gramlich1" <v.gramlichfz-juelich.de> Date: Mon, 26 Jul 2021 14:25:34 +0200 Subject: [PATCH] Commits vor merge --- mlair/data_handler/default_data_handler.py | 1 + mlair/plotting/data_insight_plotting.py | 18 ++- mlair/plotting/postprocessing_plotting.py | 159 ++++++++++++++++----- mlair/run_modules/experiment_setup.py | 2 +- mlair/run_modules/post_processing.py | 12 +- mlair/run_modules/pre_processing.py | 10 -- 6 files changed, 137 insertions(+), 65 deletions(-) diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index acc3caa0..8d977e11 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -206,6 +206,7 @@ class DefaultDataHandler(AbstractDataHandler): else: self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=self.time_dim), self._X_extreme, extremes_X)) self._Y_extreme = xr.concat([self._Y_extreme, extremes_Y], dim=self.time_dim) + self._store(fresh_store=True) def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False, diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index a2007f3f..26376637 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -22,6 +22,8 @@ from mlair.plotting.abstract_plot_class import AbstractPlotClass @TimeTrackingWrapper class PlotOversampling(AbstractPlotClass): + #Todo: Build histograms correctly + def __init__(self, data, bin_edges, oversampling_rates, plot_folder: str = ".", plot_names=["oversampling_histogram", "oversampling_density_histogram", "oversampling_rates", "oversampling_rates_deviation"]): @@ -33,7 +35,7 @@ class PlotOversampling(AbstractPlotClass): self._plot_oversampling_histogram(Y_hist, Y_extreme_hist, bin_edges) self._save() self.plot_name = plot_names[1] - self._plot_oversampling_density_histogram(Y_hist_dens, Y_extreme_hist_dens, bin_edges) + self._plot_oversampling_histogram(Y_hist_dens, Y_extreme_hist_dens, bin_edges) self._save() self.plot_name = plot_names[2] self._plot_oversampling_rates(oversampling_rates, real_oversampling) @@ -56,15 +58,11 @@ class PlotOversampling(AbstractPlotClass): def _plot_oversampling_histogram(self, Y_hist, Y_extreme_hist, bin_edges): fig, ax = plt.subplots(1, 1) - ax.step(bin_edges, np.append(0,Y_hist), label="Before oversampling") - ax.step(bin_edges, np.append(0,Y_extreme_hist), label="After oversampling") - ax.set_title(f"Histogram before-after oversampling") - ax.legend() - - def _plot_oversampling_density_histogram(self, Y_hist_dens, Y_extreme_hist_dens, bin_edges): - fig, ax = plt.subplots(1, 1) - ax.step(bin_edges, np.append(0,Y_hist_dens), label="Before oversampling") - ax.step(bin_edges, np.append(0,Y_extreme_hist_dens), label="After oversampling") + ax.hist(bin_edges[:-1], bin_edges, weights=Y_hist, label="Before oversampling") + ax.hist(bin_edges[:-1], bin_edges, weights=Y_extreme_hist, label="After oversampling") + #ax.plot(bin_edges[:-1] + 0.5 * interval_width, weights, label=f"{subset}", c=colors[subset]) + #ax.step(bin_edges, np.append(0,Y_hist), label="Before oversampling") + #ax.step(bin_edges, np.append(0,Y_extreme_hist), label="After oversampling") #ax.stairs(Y_hist_dens, bin_edges, label="Before oversampling") #ax.stairs(Y_extreme_hist_dens, bin_edges, label="After oversampling") ax.set_title(f"Density Histogram before-after oversampling") diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 67233780..b5e76e55 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -18,7 +18,7 @@ from matplotlib.backends.backend_pdf import PdfPages from mlair import helpers from mlair.data_handler.iterator import DataCollection -from mlair.helpers import TimeTrackingWrapper +from mlair.helpers import TimeTrackingWrapper, to_list from mlair.plotting.abstract_plot_class import AbstractPlotClass logging.getLogger('matplotlib').setLevel(logging.WARNING) @@ -30,53 +30,140 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING) @TimeTrackingWrapper class PlotOversamplingContingency(AbstractPlotClass): + #Todo: 1. Make competitors flexible + # 2. Get min and max_label - def __init__(self, predictions, labels, plot_folder: str = ".", + def __init__(self, station_names, file_path, comp_path, file_name, plot_folder: str = ".", model_name: str = "nn", + obs_name: str = "obs", comp_names: str = "IntelliO3", plot_names=["oversampling_threat_score", "oversampling_hit_rate", "oversampling_false_alarm_rate", "oversampling_all_scores"]): super().__init__(plot_folder, plot_names[0]) - ts = [] - h = [] - f = [] - max_label = 0 - min_label = 0 - for station in station_names: - file = os.path.join(file_path, file_name % station) - forecast = xr.open_dataarray(file) - competitors = extract_method(station) - for threshold in range(min_label, max_label): - true_above = 0 - false_above = 0 - false_below = 0 - true_below = 0 - for prediction, label in predictions, labels: - if prediction >= threshold: - if label >= threshold: - true_above = + 1 - else: - false_above = + 1 - else: - if label >= threshold: - false_below = + 1 - else: - true_below = + 1 - ts.append(true_above/(true_above+false_above+false_below)) - h.append(true_above/(true_above+false_below)) - f.append(false_above/(false_above+true_below)) - plt.plot(range(min_label, max_label), ts) + self._stations = station_names + self._file_path = file_path + self._comp_path = comp_path + self._file_name = file_name + self._model_name = model_name + self._obs_name = obs_name + self._comp_names = to_list(comp_names) + true_above, false_above, false_below, true_below, borders = self._calculate_contingencies() + ts, h, f = self._calculate_scores(true_above, false_above, false_below, true_below) + min_label = borders[0] + max_label = borders[1] + plt.plot(range(min_label, max_label), ts, label="threat score") + plt.legend() self._save() self.plot_name = plot_names[1] - plt.plot(range(min_label, max_label), h) + plt.plot(range(min_label, max_label), h, label="hit rate") + plt.legend() self._save() self.plot_name = plot_names[2] - plt.plot(range(min_label, max_label), f) + plt.plot(range(min_label, max_label), f, label="false alarm rate") + plt.legend() + self._save() self.plot_name = plot_names[3] - plt.plot(range(min_label, max_label), ts) - plt.plot(range(min_label, max_label), h) - plt.plot(range(min_label, max_label), f) + plt.plot(range(min_label, max_label), ts, label="threat score") + plt.plot(range(min_label, max_label), h, label="hit rate") + plt.plot(range(min_label, max_label), f, label="false alarm rate") + plt.legend() self._save() + def _create_competitor_forecast(self, station_name: str, competitor_name: str) -> xr.DataArray: + """ + Load and format the competing forecast of a distinct model indicated by `competitor_name` for a distinct station + indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will + raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either + there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file. + + :param station_name: name of the station to load data for + :param competitor_name: name of the model + :return: the forecast of the given competitor + """ + path = os.path.join(self._comp_path, competitor_name) + file = os.path.join(path, f"forecasts_{station_name}_test.nc") + data = xr.open_dataarray(file) + # data = data.expand_dims(Stations=[station_name]) # ToDo: remove line + forecast = data.sel(type=[self._model_name]) + forecast.coords["type"] = [competitor_name] + return forecast + + def _load_competitors(self, station_name: str, comp) -> xr.DataArray: + """ + Load all requested and available competitors for a given station. Forecasts must be available in the competitor + path like `<competitor_path>/<target_var>/forecasts_<station_name>_test.nc`. The naming style is equal for all + forecasts of MLAir, so that forecasts of a different experiment can easily be copied into the competitor path + without any change. + + :param station_name: station indicator to load competitors for + + :return: a single xarray with all competing forecasts + """ + competing_predictions = [] + for competitor_name in comp: + try: + prediction = self._create_competitor_forecast(station_name, competitor_name) + competing_predictions.append(prediction) + except (FileNotFoundError, KeyError): + logging.debug(f"No competitor found for combination '{station_name}' and '{competitor_name}'.") + continue + return xr.concat(competing_predictions, "type") if len(competing_predictions) > 0 else None + + def _calculate_contingencies(self): + for station in self._stations: + file = os.path.join(self._file_path, self._file_name % station) + forecast_file = xr.open_dataarray(file) + obs = forecast_file.sel(type=self._obs_name) + model = forecast_file.sel(type=self._model_name) + competitors = [self._load_competitors(station, [comp]).sel(type=comp) for comp in self._comp_names] + min_label = 0 + max_label = 100 + borders = [min_label, max_label] + true_above = [] + false_above = [] + false_below = [] + true_below = [] + for threshold in range(min_label, max_label): + ta, fa, fb, tb = self._single_contingency(obs, model, threshold) + true_above.append(ta) + false_above.append(fa) + false_below.append(fb) + true_below.append(tb) + return np.array(true_above), np.array(false_above), np.array(false_below), np.array(true_below), borders + + + def _single_contingency(self, obs, pred, threshold): + ta = 0 + fa = 0 + fb = 0 + tb = 0 + observations = obs.values.flatten() + predictions = pred.values.flatten() + for i in range(len(observations)): + if predictions[i] >= threshold: + if observations[i] >= threshold: + ta += + 1 + else: + fa += + 1 + else: + if observations[i] >= threshold: + fb += 1 + else: + tb += 1 + return ta, fa, fb, tb + + def _calculate_scores(self, true_above, false_above, false_below, true_below): + np.seterr(divide="ignore") + np.seterr(divide="ignore") + ts = true_above/(true_above + false_above + false_below) + h = true_above/(true_above + false_below) + f = false_above/(false_above + true_below) + np.nan_to_num(ts, copy=False) + np.nan_to_num(h, copy=False) + np.nan_to_num(f, copy=False) + return ts, h, f + + + @TimeTrackingWrapper class PlotMonthlySummary(AbstractPlotClass): diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index cefd4505..c5687e37 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -363,7 +363,7 @@ class ExperimentSetup(RunEnvironment): # set competitors self._set_param("competitors", competitors, default=[]) competitor_path_default = os.path.join(self.data_store.get("data_path"), "competitors", - "_".join(self.data_store.get("target_var"))) + "_".join(to_list(self.data_store.get("target_var")))) self._set_param("competitor_path", competitor_path, default=competitor_path_default) # check variables, statistics and target variable diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 2febedb6..8a594808 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -309,11 +309,9 @@ class PostProcessing(RunEnvironment): try: if (self.data_store.get('oversampling_method')=='bin_oversampling') and ( "PlotOversamplingContingency" in plot_list): - predictions = None - labels = None - PlotOversamplingContingency(extract_method=self.load_competitors(), station_names=self.test_data.keys(), - file_path=path, file_name=r"forecasts_%s_test.nc", - plot_folder=self.plot_path) + PlotOversamplingContingency(station_names=self.test_data.keys(), file_path=path, comp_path=self.competitor_path, + comp_names=self.competitors, + file_name=r"forecasts_%s_test.nc", plot_folder=self.plot_path) except Exception as e: logging.error(f"Could not create plot OversamplingContingencyPlots due to the following error: {e}") @@ -321,9 +319,7 @@ class PostProcessing(RunEnvironment): if (self.data_store.get('oversampling_method')=='bin_oversampling') and ( "PlotOversampling" in plot_list): bin_edges = self.data_store.get('oversampling_bin_edges') - oversampling_rates = self.data_store.get('oversampling_rates_capped','train') - #Y = self.data_store.get('Oversampling_Y') - #Y_extreme = self.data_store.get('Oversampling_Y_extreme') + oversampling_rates = self.data_store.get('oversampling_rates_capped', 'train') PlotOversampling(self.train_data, bin_edges, oversampling_rates, plot_folder=self.plot_path) except Exception as e: logging.error(f"Could not create plot OversamplingPlots due to the following error: {e}") diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 2d6dc3b5..c0065523 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -107,16 +107,6 @@ class PreProcessing(RunEnvironment): #Y_extreme = None for station in data: station.apply_oversampling(bin_edges, oversampling_rates_capped) - ''' - if Y is None: - Y = station._Y - Y_extreme = station._Y_extreme - else: - Y = xr.concat([Y, station._Y], dim="Stations") - Y_extreme = xr.concat([Y_extreme, station._Y_extreme], dim="Stations") - self.data_store.set('Oversampling_Y', Y) - self.data_store.set('Oversampling_Y_extreme', Y_extreme) - ''' def report_pre_processing(self): -- GitLab