Skip to content
Snippets Groups Projects
Commit 88041cad authored by Bing Gong's avatar Bing Gong
Browse files

update main_meta_postprocess to unenable the persist prediction in plot

parent daef8d16
Branches
No related tags found
No related merge requests found
Pipeline #92842 canceled
...@@ -31,7 +31,7 @@ def skill_score(tar_score,ref_score,best_score): ...@@ -31,7 +31,7 @@ def skill_score(tar_score,ref_score,best_score):
class MetaPostprocess(object): class MetaPostprocess(object):
def __init__(self, root_dir: str = "/p/project/deepacf/deeprain/video_prediction_shared_folder/", def __init__(self, root_dir: str = "/p/project/deepacf/deeprain/video_prediction_shared_folder/",
analysis_config: str = None, metric: str = "mse", exp_id: str=None, enable_skill_scores:bool=False): analysis_config: str = None, metric: str = "mse", exp_id: str=None, enable_skill_scores:bool=False, enable_persit_plot:bool=False):
""" """
This class is used for calculating the evaluation metric, analyize the models' results and make comparsion This class is used for calculating the evaluation metric, analyize the models' results and make comparsion
args: args:
...@@ -40,13 +40,15 @@ class MetaPostprocess(object): ...@@ -40,13 +40,15 @@ class MetaPostprocess(object):
analysis_dir :str, the path to save the analysis results analysis_dir :str, the path to save the analysis results
metric :str, based on which evalution metric for comparison, "mse","ssim", "texture" and "acc" metric :str, based on which evalution metric for comparison, "mse","ssim", "texture" and "acc"
exp_id :str, the given exp_id which is used as the name of postfix of the folder to store the plot exp_id :str, the given exp_id which is used as the name of postfix of the folder to store the plot
enable_skill_scores:bool, the enable_skill_scores:bool, enable the skill scores plot
enable_persis_plot: bool, enable the persis prediction in the plot
""" """
self.root_dir = root_dir self.root_dir = root_dir
self.analysis_config = analysis_config self.analysis_config = analysis_config
self.analysis_dir = os.path.join(root_dir, "meta_postprocess", exp_id) self.analysis_dir = os.path.join(root_dir, "meta_postprocess", exp_id)
self.metric = metric self.metric = metric
self.exp_id = exp_id self.exp_id = exp_id
self.persist = enable_persit_plot
self.enable_skill_scores = enable_skill_scores self.enable_skill_scores = enable_skill_scores
self.models_type = [] self.models_type = []
self.metric_values = [] # return the shape: [num_results, persi_values, model_values] self.metric_values = [] # return the shape: [num_results, persi_values, model_values]
...@@ -59,8 +61,8 @@ class MetaPostprocess(object): ...@@ -59,8 +61,8 @@ class MetaPostprocess(object):
self.copy_analysis_config() self.copy_analysis_config()
self.load_analysis_config() self.load_analysis_config()
self.get_metrics_values() self.get_metrics_values()
self.calculate_skill_scores()
if self.enable_skill_scores: if self.enable_skill_scores:
self.calculate_skill_scores()
self.plot_skill_scores() self.plot_skill_scores()
else: else:
self.plot_abs_scores() self.plot_abs_scores()
...@@ -129,13 +131,13 @@ class MetaPostprocess(object): ...@@ -129,13 +131,13 @@ class MetaPostprocess(object):
self.get_meta_info() self.get_meta_info()
for i, result_dir in enumerate(self.f["results"].values()): for i, result_dir in enumerate(self.f["results"].values()):
vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i]) vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i],self.enable_skill_scores)
self.metric_values.append(vals) self.metric_values.append(vals)
print("4. Get metrics values success") print("4. Get metrics values success")
return self.metric_values return self.metric_values
@staticmethod @staticmethod
def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None): def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None, enable_skill_scores:bool = False):
""" """
obtain the metric values (persistence and DL model) in the "evaluation_metrics.nc" file obtain the metric values (persistence and DL model) in the "evaluation_metrics.nc" file
...@@ -145,7 +147,10 @@ class MetaPostprocess(object): ...@@ -145,7 +147,10 @@ class MetaPostprocess(object):
filepath = os.path.join(result_dir, filename) filepath = os.path.join(result_dir, filename)
try: try:
with xr.open_dataset(filepath) as dfiles: with xr.open_dataset(filepath) as dfiles:
if enable_skill_scores:
persi = np.array(dfiles['2t_persistence_{}_bootstrapped'.format(metric)][:]) persi = np.array(dfiles['2t_persistence_{}_bootstrapped'.format(metric)][:])
else:
persi = []
model = np.array(dfiles['2t_{}_{}_bootstrapped'.format(model, metric)][:]) model = np.array(dfiles['2t_{}_{}_bootstrapped'.format(model, metric)][:])
print("The values for evaluation metric '{}' values are obtained from file {}".format(metric, filepath)) print("The values for evaluation metric '{}' values are obtained from file {}".format(metric, filepath))
return [persi, model] return [persi, model]
...@@ -179,7 +184,8 @@ class MetaPostprocess(object): ...@@ -179,7 +184,8 @@ class MetaPostprocess(object):
return None return None
def get_lead_time_labels(self): def get_lead_time_labels(self):
leadtimes = self.metric_values[0][0].shape[1] assert len(self.metric_values) == 2
leadtimes = np.array(self.metric_values[0][1]).shape[1]
leadtimelist = ["leadhour" + str(i + 1) for i in range(leadtimes)] leadtimelist = ["leadhour" + str(i + 1) for i in range(leadtimes)]
return leadtimelist return leadtimelist
...@@ -209,16 +215,20 @@ class MetaPostprocess(object): ...@@ -209,16 +215,20 @@ class MetaPostprocess(object):
fig = plt.figure(figsize = (8, 6)) fig = plt.figure(figsize = (8, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8]) ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
for i in range(len(self.metric_values)): for i in range(len(self.metric_values)): #loop number of test samples
assert len(self.metric_values)==2
score_plot = np.nanquantile(self.metric_values[i][1], 0.5, axis = 0) score_plot = np.nanquantile(self.metric_values[i][1], 0.5, axis = 0)
plt.plot(np.arange(1, 1 + self.n_leadtime), score_plot, label = self.labels[i], color = self.colors[i],
assert len(score_plot) == self.n_leadtime
plt.plot(np.arange(1, 1 + self.n_leadtime), list(score_plot),label = self.labels[i], color = self.colors[i],
marker = self.markers[i], markeredgecolor = 'k', linewidth = 1.2) marker = self.markers[i], markeredgecolor = 'k', linewidth = 1.2)
plt.fill_between(np.arange(1, 1 + self.n_leadtime), plt.fill_between(np.arange(1, 1 + self.n_leadtime),
np.nanquantile(self.metric_values[i][1], 0.95, axis = 0), np.nanquantile(self.metric_values[i][1], 0.95, axis = 0),
np.nanquantile(self.metric_values[i][1], 0.05, axis = 0), color = self.colors[i], np.nanquantile(self.metric_values[i][1], 0.05, axis = 0), color = self.colors[i],
alpha = 0.2) alpha = 0.2)
#only plot the persist prediction when the enabled
if self.persist:
if self.models_type[i] == "convLSTM":
score_plot = np.nanquantile(self.metric_values[i][0], 0.5, axis = 0) score_plot = np.nanquantile(self.metric_values[i][0], 0.5, axis = 0)
plt.plot(np.arange(1, 1 + self.n_leadtime), score_plot, label = "Persi_cv" + str(i), plt.plot(np.arange(1, 1 + self.n_leadtime), score_plot, label = "Persi_cv" + str(i),
color = self.colors[i], marker = "D", markeredgecolor = 'k', linewidth = 1.2) color = self.colors[i], marker = "D", markeredgecolor = 'k', linewidth = 1.2)
...@@ -227,7 +237,7 @@ class MetaPostprocess(object): ...@@ -227,7 +237,7 @@ class MetaPostprocess(object):
np.nanquantile(self.metric_values[i][0], 0.05, axis = 0), color = "b", alpha = 0.2) np.nanquantile(self.metric_values[i][0], 0.05, axis = 0), color = "b", alpha = 0.2)
plt.yticks(fontsize = 16) plt.yticks(fontsize = 16)
plt.xticks(np.arange(1, 13), np.arange(1, 13, 1), fontsize = 16) plt.xticks(np.arange(1, self.n_leadtime+1), np.arange(1, self.n_leadtime + 1, 1), fontsize = 16)
legend = ax.legend(loc = 'upper right', bbox_to_anchor = (1.46, 0.95), legend = ax.legend(loc = 'upper right', bbox_to_anchor = (1.46, 0.95),
fontsize = 14) # 'upper right', bbox_to_anchor=(1.38, 0.8), fontsize = 14) # 'upper right', bbox_to_anchor=(1.38, 0.8),
ylabel = MetaPostprocess.map_ylabels(self.metric) ylabel = MetaPostprocess.map_ylabels(self.metric)
...@@ -262,7 +272,7 @@ class MetaPostprocess(object): ...@@ -262,7 +272,7 @@ class MetaPostprocess(object):
legend = ax.legend(loc = 'upper right', bbox_to_anchor = (1.46, 0.95), fontsize = 14) legend = ax.legend(loc = 'upper right', bbox_to_anchor = (1.46, 0.95), fontsize = 14)
plt.yticks(fontsize = 16) plt.yticks(fontsize = 16)
plt.xticks(np.arange(1, 13), np.arange(1, 13, 1), fontsize = 16) plt.xticks(np.arange(1, self.n_leadtime +1), np.arange(1, self.n_leadtime+1, 1), fontsize = 16)
ax.set_xlabel("Lead time (hours)", fontsize = 21) ax.set_xlabel("Lead time (hours)", fontsize = 21)
ax.set_ylabel("Skill scores of {}".format(self.metric), fontsize = 21) ax.set_ylabel("Skill scores of {}".format(self.metric), fontsize = 21)
fig_path = os.path.join(self.analysis_dir, self.metric + "_skill_scores.png") fig_path = os.path.join(self.analysis_dir, self.metric + "_skill_scores.png")
...@@ -279,11 +289,12 @@ def main(): ...@@ -279,11 +289,12 @@ def main():
default="../meta_postprocess_config/meta_config.json") default="../meta_postprocess_config/meta_config.json")
parser.add_argument("--metric", help="Based on which the models are compared, the value should be in one of [mse,ssim,acc,texture]",default="mse") parser.add_argument("--metric", help="Based on which the models are compared, the value should be in one of [mse,ssim,acc,texture]",default="mse")
parser.add_argument("--exp_id", help="The experiment id which will be used as postfix of the output directory",default="exp1") parser.add_argument("--exp_id", help="The experiment id which will be used as postfix of the output directory",default="exp1")
parser.add_argument("--enable_skill_scores", help="compared by skill scores or the absolute evaluation values",default=True) parser.add_argument("--enable_skill_scores", help="compared by skill scores or the absolute evaluation values",default=False)
parser.add_argument("--enable_persit_plot", help="If plot persistent foreasts",default=False)
args = parser.parse_args() args = parser.parse_args()
meta = MetaPostprocess(root_dir=args.root_dir,analysis_config=args.analysis_config, metric=args.metric, exp_id=args.metric, meta = MetaPostprocess(root_dir=args.root_dir,analysis_config=args.analysis_config, metric=args.metric, exp_id=args.exp_id,
enable_skill_scores=args.enable_skill_scores) enable_skill_scores=args.enable_skill_scores,enable_persit_plot=args.enable_persit_plot)
meta() meta()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment