From 73a7c6f338b51badbf802d65f4968f58ae6a4cf8 Mon Sep 17 00:00:00 2001
From: BING GONG <b.gong@fz-juelich.de>
Date: Fri, 4 Feb 2022 11:27:40 -0500
Subject: [PATCH] Implement the skill score plots

---
 .../main_scripts/main_meta_postprocess.py | 149 +++++++++++++-----
 1 file changed, 106 insertions(+), 43 deletions(-)

diff --git a/video_prediction_tools/main_scripts/main_meta_postprocess.py b/video_prediction_tools/main_scripts/main_meta_postprocess.py
index 4bad7702..95f6e54c 100644
--- a/video_prediction_tools/main_scripts/main_meta_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_meta_postprocess.py
@@ -26,7 +26,7 @@ import xarray as xr
 class MetaPostprocess(object):
 
     def __init__(self, root_dir: str = "/p/project/deepacf/deeprain/video_prediction_shared_folder/",
-                 analysis_config: str = None, metric: str = "mse", exp_id=None, enable_skill_scores=False):
+                 analysis_config: str = None, metric: str = "mse", exp_id: str = None, enable_skill_scores: bool = False):
         """
         This class is used for calculating the evaluation metrics, analysing the models' results and making comparisons
         args:
@@ -35,6 +35,7 @@ class MetaPostprocess(object):
             analysis_dir        :str, the path to save the analysis results
             metric              :str, the evaluation metric used for the comparison, one of "mse", "ssim", "texture" and "acc"
             exp_id              :str, the given exp_id which is used as the postfix of the folder to store the plot
+            enable_skill_scores :bool, if True the models are compared by their skill scores, otherwise by absolute metric values
         """
         self.root_dir = root_dir
         self.analysis_config = analysis_config
@@ -43,21 +44,26 @@ class MetaPostprocess(object):
         self.exp_id = exp_id
         self.enable_skill_scores = enable_skill_scores
         self.models_type = []
+        self.metric_values = []  # shape: [num_results, persi_values, model_values]
+        self.skill_scores = []  # the calculated skill scores, shape: [num_results, skill_scores_values]
+
     def __call__(self):
         self.sanity_check()
         self.create_analysis_dir()
         self.copy_analysis_config()
         self.load_analysis_config()
-        metric_values = self.get_metrics_values()
-        self.plot_scores(metric_values)
-        # self.calculate_skill_scores()
-        # self.plot_scores()
+        self.get_metrics_values()
+        self.calculate_skill_scores()
+        if self.enable_skill_scores:
+            self.plot_skill_scores()
+        else:
+            self.plot_abs_scores()
 
     def sanity_check(self):
-        available_metrics = ["mse", "ssim", "texture", "acc"]
-        if self.metric not in available_metrics:
-            raise ("The 'metric' must be one of the following:", available_metrics)
+        self.available_metrics = ["mse", "ssim", "texture", "acc"]
+        if self.metric not in self.available_metrics:
+            raise ValueError("'metric' must be one of the following: {}".format(self.available_metrics))
         if type(self.exp_id) is not str:
-            raise ("'exp_id' must be 'str' type ")
+            raise TypeError("'exp_id' must be of type 'str'")
@@ -112,18 +118,23 @@ class MetaPostprocess(object):
         return None
 
     def get_metrics_values(self):
+        """
+        Get the evaluation metric values of all the results; fills self.metric_values, a list of shape [results, persi, model].
+        """
         self.get_meta_info()
-        metric_values = []
+
         for i, result_dir in enumerate(self.f["results"].values()):
             vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i])
-            metric_values.append(vals)  # return the shape: [result_id, persi_values,model_values]
+            self.metric_values.append(vals)
         print("4. Get metrics values success")
-        return metric_values
+        return self.metric_values
 
     @staticmethod
     def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None):
+        """ Obtain the metric values (persistence and DL model) from the "evaluation_metrics.nc" file.
+        return: a list with the evaluation metrics of one result: [persi, model] """
         filename = 'evaluation_metrics.nc'
         filepath = os.path.join(result_dir, filename)
@@ -138,22 +149,41 @@ class MetaPostprocess(object):
             print(e)
 
     def calculate_skill_scores(self):
-        if self.enable_skill_scores:
+        """
+        Calculate the skill scores of the models, using persistence as the reference forecast.
+        """
+        if self.metric_values is None:
+            raise ValueError("'metric_values' should be a list, but None was provided")
+
+        best_score = 0
+        if self.metric == "mse":
             pass
-            # do sometthing
+
+        elif self.metric in ["ssim", "acc", "texture"]:
+            best_score = 1
         else:
-            pass
+            raise ValueError("The metric must be one of the following available metrics: {}".format(self.available_metrics))
+
+        if self.enable_skill_scores:
+            for i in range(len(self.metric_values)):
+                skill_val = skill_score(self.metric_values[i][1], self.metric_values[i][0], best_score)
+                self.skill_scores.append(skill_val)
+
+            return self.skill_scores
+        else:
+            return None
 
-    def get_lead_time_labels(metric_values: list = None):
-        leadtimes = metric_values[0][0].shape[1]
+    def get_lead_time_labels(self):
+        leadtimes = self.metric_values[0][0].shape[1]
         leadtimelist = ["leadhour" + str(i + 1) for i in range(leadtimes)]
         return leadtimelist
 
-    def config_plots(self, metric_values):
-        self.leadtimelist = MetaPostprocess.get_lead_time_labels(metric_values)
+    def config_plots(self):
+        self.leadtimelist = self.get_lead_time_labels()
         self.labels = self.get_labels()
         self.markers = self.f["markers"]
         self.colors = self.f["colors"]
+        self.n_leadtime = len(self.leadtimelist)
 
     @staticmethod
     def map_ylabels(metric):
@@ -169,35 +199,27 @@ class MetaPostprocess(object):
             raise ("The metric is not correct!")
         return ylabel
 
-    def plot_scores(self, metric_values):
-
-        self.config_plots(metric_values)
-
-        if self.enable_skill_scores:
-            self.plot_skill_scores(metric_values)
-        else:
-            self.plot_abs_scores(metric_values)
-
-    def plot_abs_scores(self, metric_values: list = None):
-        n_leadtime = len(self.leadtimelist)
+    def plot_abs_scores(self):
+        self.config_plots()
         fig = plt.figure(figsize = (8, 6))
         ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
-        for i in range(len(metric_values)):
-            score_plot = np.nanquantile(metric_values[i][1], 0.5, axis = 0)
-            plt.plot(np.arange(1, 1 + n_leadtime), score_plot, label = self.labels[i], color = self.colors[i],
+        for i in range(len(self.metric_values)):
+            score_plot = np.nanquantile(self.metric_values[i][1], 0.5, axis = 0)
+            plt.plot(np.arange(1, 1 + self.n_leadtime), score_plot, label = self.labels[i], color = self.colors[i],
                      marker = self.markers[i], markeredgecolor = 'k', linewidth = 1.2)
-            plt.fill_between(np.arange(1, 1 + n_leadtime),
-                             np.nanquantile(metric_values[i][1], 0.95, axis = 0),
-                             np.nanquantile(metric_values[i][1], 0.05, axis = 0), color = self.colors[i], alpha = 0.2)
+            plt.fill_between(np.arange(1, 1 + self.n_leadtime),
+                             np.nanquantile(self.metric_values[i][1], 0.95, axis = 0),
+                             np.nanquantile(self.metric_values[i][1], 0.05, axis = 0), color = self.colors[i],
+                             alpha = 0.2)
 
             if self.models_type[i] == "convLSTM":
-                score_plot = np.nanquantile(metric_values[i][0], 0.5, axis = 0)
-                plt.plot(np.arange(1, 1 + n_leadtime), score_plot, label = "Persi_cv" + str(i), color = self.colors[i],
-                         marker = "D", markeredgecolor = 'k', linewidth = 1.2)
-                plt.fill_between(np.arange(1, 1 + n_leadtime),
-                                 np.nanquantile(metric_values[i][0], 0.95, axis = 0),
-                                 np.nanquantile(metric_values[i][0], 0.05, axis = 0), color = "b", alpha = 0.2)
+                score_plot = np.nanquantile(self.metric_values[i][0], 0.5, axis = 0)
+                plt.plot(np.arange(1, 1 + self.n_leadtime), score_plot, label = "Persi_cv" + str(i),
+                         color = self.colors[i], marker = "D", markeredgecolor = 'k', linewidth = 1.2)
+                plt.fill_between(np.arange(1, 1 + self.n_leadtime),
+                                 np.nanquantile(self.metric_values[i][0], 0.95, axis = 0),
+                                 np.nanquantile(self.metric_values[i][0], 0.05, axis = 0), color = "b", alpha = 0.2)
 
         plt.yticks(fontsize = 16)
         plt.xticks(np.arange(1, 13), np.arange(1, 13, 1), fontsize = 16)
@@ -206,18 +228,59 @@ class MetaPostprocess(object):
         ylabel = MetaPostprocess.map_ylabels(self.metric)
         ax.set_xlabel("Lead time (hours)", fontsize = 21)
         ax.set_ylabel(ylabel, fontsize = 21)
-        fig_path = os.path.join(self.analysis_dir, self.metric + "abs_values.png")
+        fig_path = os.path.join(self.analysis_dir, self.metric + "_abs_values.png")
         # fig_path = os.path.join(prefix,fig_name)
         plt.savefig(fig_path, bbox_inches = "tight")
         plt.show()
         plt.close()
         print("The plot saved to {}".format(fig_path))
-
+
+    def plot_skill_scores(self):
+        """
+        Plot the skill scores when enable_skill_scores is True.
+        """
+        self.config_plots()
+        fig = plt.figure(figsize = (8, 6))
+        ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
+        for i in range(len(self.skill_scores)):
+            if self.models_type[i] == "convLSTM":
+                c = "r"
+            elif self.models_type[i] == "savp":
+                c = "b"
+            else:
+                raise ValueError("Currently only convLSTM and SAVP are supported for plotting the skill scores")
+
+            plt.boxplot(self.skill_scores[i], positions = np.arange(1, self.n_leadtime + 1), medianprops = {'color': c},
+                        capprops = {'color': c}, boxprops = {'color': c}, showfliers = False)
+            score_plot = np.nanquantile(self.skill_scores[i], 0.5, axis = 0)
+            plt.plot(np.arange(1, 1 + self.n_leadtime), score_plot, color = c, linewidth = 1.2, label = self.labels[i])
+
+        legend = ax.legend(loc = 'upper right', bbox_to_anchor = (1.46, 0.95), fontsize = 14)
+        plt.yticks(fontsize = 16)
+        plt.xticks(np.arange(1, 13), np.arange(1, 13, 1), fontsize = 16)
+        ax.set_xlabel("Lead time (hours)", fontsize = 21)
+        ax.set_ylabel("Skill scores of {}".format(self.metric), fontsize = 21)
+        fig_path = os.path.join(self.analysis_dir, self.metric + "_skill_scores.png")
+        plt.savefig(fig_path, bbox_inches = "tight")
+        plt.show()
+        plt.close()
+        print("The plot saved to {}".format(fig_path))
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--analysis_config", type=str, default="../meta_postprocess_config/meta_config.json",
+                        help="Path to the meta_postprocess configuration file.")
+    parser.add_argument("--metric", default="mse",
+                        help="The evaluation metric on which the models are compared, one of [mse, ssim, acc, texture]")
+    parser.add_argument("--exp_id", default="exp1",
+                        help="The experiment id which is used as the postfix of the output directory")
+    parser.add_argument("--enable_skill_scores", action="store_true",
+                        help="Compare the models by skill scores instead of by the absolute evaluation values")
+    args = parser.parse_args()
+
+    meta = MetaPostprocess(analysis_config=args.analysis_config, metric=args.metric, exp_id=args.exp_id,
+                           enable_skill_scores=args.enable_skill_scores)
+    meta()
+
+
+if __name__ == '__main__':
+    main()
--
GitLab
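
Note: the patch calls a skill_score() helper whose definition is not shown (presumably imported elsewhere in main_meta_postprocess.py). A minimal sketch that is consistent with the call site skill_score(model_vals, persi_vals, best_score) and with the best_score values chosen above (0 for "mse", 1 for "ssim"/"acc"/"texture"), assuming the conventional (model - reference) / (best - reference) definition with persistence as the reference:

    import numpy as np

    def skill_score(model_vals, ref_vals, best_score):
        # Conventional forecast skill score: 1 = perfect forecast, 0 = no
        # improvement over the reference (persistence), negative = worse
        # than the reference.
        model_vals = np.asarray(model_vals, dtype=float)
        ref_vals = np.asarray(ref_vals, dtype=float)
        denom = best_score - ref_vals
        denom = np.where(denom == 0, np.nan, denom)  # avoid division by zero
        return (model_vals - ref_vals) / denom

For "mse" (best_score = 0) this reduces to 1 - mse_model / mse_persistence, the usual MSE skill score; plot_skill_scores() then draws a boxplot of these values per lead time.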
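With the argument parser above, the script can be invoked as follows (paths are illustrative; passing --enable_skill_scores draws the skill-score boxplots, omitting it draws the absolute metric curves):

    python main_meta_postprocess.py \
        --analysis_config ../meta_postprocess_config/meta_config.json \
        --metric mse --exp_id exp1 --enable_skill_scores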