Commit 73a7c6f3 authored by BING GONG

Implement the skill score plots

parent 4fb114c9
Pipeline #90897 failed
@@ -26,7 +26,7 @@ import xarray as xr

class MetaPostprocess(object):
    def __init__(self, root_dir: str = "/p/project/deepacf/deeprain/video_prediction_shared_folder/",
                 analysis_config: str = None, metric: str = "mse", exp_id: str = None,
                 enable_skill_scores: bool = False):
        """
        This class is used to calculate the evaluation metrics, analyze the models' results and compare them.
        args:
@@ -35,6 +35,7 @@ class MetaPostprocess(object):
            analysis_dir        :str, the path to save the analysis results
            metric              :str, the evaluation metric used for comparison, one of "mse", "ssim", "texture" and "acc"
            exp_id              :str, the experiment id, used as postfix of the folder that stores the plots
            enable_skill_scores :bool, if True the skill scores are calculated and plotted, otherwise the absolute metric values are plotted
        """
        self.root_dir = root_dir
        self.analysis_config = analysis_config
@@ -43,21 +44,26 @@ class MetaPostprocess(object):
        self.exp_id = exp_id
        self.enable_skill_scores = enable_skill_scores
        self.models_type = []
        self.metric_values = []  # shape: [num_results, persi_values, model_values]
        self.skill_scores = []   # contains the calculated skill scores, shape: [num_results, skill_scores_values]
    def __call__(self):
        self.sanity_check()
        self.create_analysis_dir()
        self.copy_analysis_config()
        self.load_analysis_config()
        self.get_metrics_values()
        self.calculate_skill_scores()
        if self.enable_skill_scores:
            self.plot_skill_scores()
        else:
            self.plot_abs_scores()
    def sanity_check(self):
        self.available_metrics = ["mse", "ssim", "texture", "acc"]
        if self.metric not in self.available_metrics:
            raise ValueError("The 'metric' must be one of the following: {}".format(self.available_metrics))
        if not isinstance(self.exp_id, str):
            raise TypeError("'exp_id' must be of type 'str'")
@@ -112,18 +118,23 @@ class MetaPostprocess(object):
        return None

    def get_metrics_values(self):
        """
        Get the evaluation metric values of all results.
        return: a list with shape [num_results, persi_values, model_values]
        """
        self.get_meta_info()
        for i, result_dir in enumerate(self.f["results"].values()):
            vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i])
            self.metric_values.append(vals)
        print("4. Get metric values: success")
        return self.metric_values
    @staticmethod
    def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None):
        """
        Obtain the metric values (persistence and DL model) from the "evaluation_metrics.nc" file.
        return: a list containing the evaluation metrics of one result: [persi, model]
        """
        filename = 'evaluation_metrics.nc'
        filepath = os.path.join(result_dir, filename)
@@ -138,22 +149,41 @@ class MetaPostprocess(object):
            print(e)

    def calculate_skill_scores(self):
        """
        Calculate the skill scores with persistence as the reference forecast.
        """
        if self.metric_values is None:
            raise ValueError("metric_values should be a list, but None was provided")
        best_score = 0
        if self.metric == "mse":
            pass  # for MSE the best (perfect) score is 0
        elif self.metric in ["ssim", "acc", "texture"]:
            best_score = 1
        else:
            raise ValueError("The metric should be one of the following available metrics: {}".format(self.available_metrics))
        if self.enable_skill_scores:
            for i in range(len(self.metric_values)):
                skill_val = skill_score(self.metric_values[i][1], self.metric_values[i][0], best_score)
                self.skill_scores.append(skill_val)
            return self.skill_scores
        else:
            return None
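    # `skill_score` is imported from a helper module elsewhere in the repository;
    # it is called as skill_score(model_vals, ref_vals, best_score). A minimal
    # sketch of the conventional definition it presumably implements (names are
    # assumptions, not the repository's actual code):
    #
    #   def skill_score(model_vals, ref_vals, best_score):
    #       # skill = (score_model - score_ref) / (score_best - score_ref):
    #       # 1 is a perfect forecast, 0 is no improvement over the reference
    #       # (here: persistence), negative values are worse than the reference.
    #       return (model_vals - ref_vals) / (best_score - ref_vals)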
    def get_lead_time_labels(self):
        leadtimes = self.metric_values[0][0].shape[1]
        leadtimelist = ["leadhour" + str(i + 1) for i in range(leadtimes)]
        return leadtimelist
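    # Example: with 12 lead times this returns ["leadhour1", "leadhour2", ..., "leadhour12"].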
    def config_plots(self):
        self.leadtimelist = self.get_lead_time_labels()
        self.labels = self.get_labels()
        self.markers = self.f["markers"]
        self.colors = self.f["colors"]
        self.n_leadtime = len(self.leadtimelist)
    @staticmethod
    def map_ylabels(metric):
@@ -169,35 +199,27 @@ class MetaPostprocess(object):
            raise ValueError("The metric is not correct!")
        return ylabel
    def plot_abs_scores(self):
        self.config_plots()
        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
        for i in range(len(self.metric_values)):
            # median of the model scores over all samples, plus the 5%-95% quantile band
            score_plot = np.nanquantile(self.metric_values[i][1], 0.5, axis=0)
            plt.plot(np.arange(1, 1 + self.n_leadtime), score_plot, label=self.labels[i], color=self.colors[i],
                     marker=self.markers[i], markeredgecolor='k', linewidth=1.2)
            plt.fill_between(np.arange(1, 1 + self.n_leadtime),
                             np.nanquantile(self.metric_values[i][1], 0.95, axis=0),
                             np.nanquantile(self.metric_values[i][1], 0.05, axis=0), color=self.colors[i],
                             alpha=0.2)
            if self.models_type[i] == "convLSTM":
                # also plot the persistence reference for the convLSTM runs
                score_plot = np.nanquantile(self.metric_values[i][0], 0.5, axis=0)
                plt.plot(np.arange(1, 1 + self.n_leadtime), score_plot, label="Persi_cv" + str(i),
                         color=self.colors[i], marker="D", markeredgecolor='k', linewidth=1.2)
                plt.fill_between(np.arange(1, 1 + self.n_leadtime),
                                 np.nanquantile(self.metric_values[i][0], 0.95, axis=0),
                                 np.nanquantile(self.metric_values[i][0], 0.05, axis=0), color="b", alpha=0.2)
        plt.yticks(fontsize=16)
        plt.xticks(np.arange(1, 13), np.arange(1, 13, 1), fontsize=16)  # note: tick labels are hard-coded for 12 lead hours
@@ -206,18 +228,59 @@ class MetaPostprocess(object):
        ylabel = MetaPostprocess.map_ylabels(self.metric)
        ax.set_xlabel("Lead time (hours)", fontsize=21)
        ax.set_ylabel(ylabel, fontsize=21)
        fig_path = os.path.join(self.analysis_dir, self.metric + "_abs_values.png")
        plt.savefig(fig_path, bbox_inches="tight")
        plt.show()
        plt.close()
        print("The plot is saved to {}".format(fig_path))
    def plot_skill_scores(self):
        """
        Plot the skill scores when enable_skill_scores is True.
        """
        self.config_plots()
        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
        for i in range(len(self.skill_scores)):
            if self.models_type[i] == "convLSTM":
                c = "r"
            elif self.models_type[i] == "savp":
                c = "b"
            else:
                raise ValueError("Currently only convLSTM and SAVP are supported for plotting the skill scores")
            # box plots of the per-sample skill scores plus a line through the medians
            plt.boxplot(self.skill_scores[i], positions=np.arange(1, self.n_leadtime + 1),
                        medianprops={'color': c}, capprops={'color': c}, boxprops={'color': c}, showfliers=False)
            score_plot = np.nanquantile(self.skill_scores[i], 0.5, axis=0)
            plt.plot(np.arange(1, 1 + self.n_leadtime), score_plot, color=c, linewidth=1.2, label=self.labels[i])
        legend = ax.legend(loc='upper right', bbox_to_anchor=(1.46, 0.95), fontsize=14)
        plt.yticks(fontsize=16)
        plt.xticks(np.arange(1, 13), np.arange(1, 13, 1), fontsize=16)  # note: tick labels are hard-coded for 12 lead hours
        ax.set_xlabel("Lead time (hours)", fontsize=21)
        ax.set_ylabel("Skill scores of {}".format(self.metric), fontsize=21)
        fig_path = os.path.join(self.analysis_dir, self.metric + "_skill_scores.png")
        plt.savefig(fig_path, bbox_inches="tight")
        plt.show()
        plt.close()
        print("The plot is saved to {}".format(fig_path))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--analysis_config", type=str,
                        help="The path to the meta_postprocess configuration file.",
                        default="../meta_postprocess_config/meta_config.json")
    parser.add_argument("--metric", help="The metric on which the models are compared, one of [mse, ssim, acc, texture]",
                        default="mse")
    parser.add_argument("--exp_id", help="The experiment id, used as postfix of the output directory", default="exp1")
    parser.add_argument("--enable_skill_scores", action="store_true",
                        help="Compare by skill scores instead of the absolute evaluation values")
    args = parser.parse_args()

    meta = MetaPostprocess(analysis_config=args.analysis_config, metric=args.metric, exp_id=args.exp_id,
                           enable_skill_scores=args.enable_skill_scores)
    meta()


if __name__ == '__main__':
    main()
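# Example invocation (hypothetical script name; paths follow the argument defaults above):
#   python meta_postprocess.py --analysis_config ../meta_postprocess_config/meta_config.json \
#       --metric mse --exp_id exp1 --enable_skill_scores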