Commit 7fb1d8aa authored by masak1112

add metrics_filename as argument in the meta-postprocessing step

parent 50e26227
Pipeline #104589 passed
@@ -31,7 +31,8 @@ def skill_score(tar_score,ref_score,best_score):
class MetaPostprocess(object):
def __init__(self, root_dir: str = "/p/project/deepacf/deeprain/video_prediction_shared_folder/",
analysis_config: str = None, metric: str = "mse", exp_id: str=None, enable_skill_scores:bool=False, enable_persit_plot:bool=False):
analysis_config: str = None, metric: str = "mse", exp_id: str=None,
enable_skill_scores:bool=False, enable_persit_plot:bool=False, metrics_filename="evaluation_metrics.nc"):
"""
This class is used to calculate the evaluation metrics, analyse the models' results and make comparisons
args:
@@ -42,6 +43,7 @@ class MetaPostprocess(object):
exp_id: str, the given exp_id, used as the postfix of the folder name that stores the plots
enable_skill_scores:bool, enable the skill scores plot
enable_persit_plot: bool, enable the persistence forecast in the plot
metrics_filename: str, the .nc file that stores the evaluation metrics
"""
self.root_dir = root_dir
self.analysis_config = analysis_config
@@ -50,6 +52,7 @@
self.exp_id = exp_id
self.persist = enable_persit_plot
self.enable_skill_scores = enable_skill_scores
self.metrics_filename = metrics_filename
self.models_type = []
self.metric_values = [] # return the shape: [num_results, persi_values, model_values]
self.skill_scores = [] # contain the calculated skill scores [num_results, skill_scores_values]
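For illustration, a minimal usage sketch of the extended constructor; the module name and the analysis config file name are assumptions, and "evaluation_metrics_72x44.nc" is just an example of a non-default metrics file (taken from the comment further down in this diff):

from meta_postprocess import MetaPostprocess  # module name is an assumption

meta = MetaPostprocess(
    root_dir="/p/project/deepacf/deeprain/video_prediction_shared_folder/",
    analysis_config="meta_config.json",        # hypothetical analysis config
    metric="mse",
    exp_id="exp1",
    enable_skill_scores=False,
    enable_persit_plot=False,
    metrics_filename="evaluation_metrics_72x44.nc",  # argument added in this commit
)
meta()  # runs the full meta-postprocessing, as in main()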
@@ -132,27 +135,31 @@ class MetaPostprocess(object):
self.get_meta_info()
for i, result_dir in enumerate(self.f["results"].values()):
vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i],self.enable_skill_scores)
vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i],self.enable_skill_scores,self.metrics_filename)
self.metric_values.append(vals)
print(" Get metrics values success")
return self.metric_values
@staticmethod
def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None, enable_skill_scores:bool = False):
def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None, enable_skill_scores:bool = False, metrics_filename: str = "evaluation_metrics.nc"):
"""
obtain the metric values (persistence and DL model) stored in the metrics file (by default "evaluation_metrics.nc")
return: list containing the evaluation metrics of one result: [persi, model]
"""
filename = 'evaluation_metrics.nc'
filename = metrics_filename
filepath = os.path.join(result_dir, filename)
try:
with xr.open_dataset(filepath) as dfiles:
with xr.open_dataset(filepath,engine="netcdf4") as dfiles:
if enable_skill_scores:
persi = np.array(dfiles['2t_persistence_{}_bootstrapped'.format(metric)][:])
if persi.shape[0]<30: #20210713T143850_gong1_savp_t2opt_3vars/evaluation_metrics_72x44.nc shape is not correct
persi = np.transpose(persi)
else:
persi = []
model = np.array(dfiles['2t_{}_{}_bootstrapped'.format(model, metric)][:])
if model.shape[0]<30:
model = np.transpose(model)
print("The values for evaluation metric '{}' values are obtained from file {}".format(metric, filepath))
return [persi, model]
except Exception as e:
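As a reference for what get_one_metric_values expects to read, here is a hedged sketch that writes a metrics file using the variable naming from the code above ('2t_persistence_<metric>_bootstrapped' and '2t_<model>_<metric>_bootstrapped'); the dimension order (bootstrap samples, lead times), the sizes, and the model name 'savp' are assumptions inferred from the shape checks and the example path in the comment:

import numpy as np
import xarray as xr

n_boot, n_leadtime = 1000, 12  # assumed sizes; shape[0] >= 30 avoids the transpose branch above
ds = xr.Dataset({
    "2t_persistence_mse_bootstrapped": (("boot", "leadtime"), np.random.rand(n_boot, n_leadtime)),
    "2t_savp_mse_bootstrapped": (("boot", "leadtime"), np.random.rand(n_boot, n_leadtime)),
})
ds.to_netcdf("evaluation_metrics.nc", engine="netcdf4")  # same engine as the reader above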
@@ -201,7 +208,7 @@ class MetaPostprocess(object):
@staticmethod
def map_ylabels(metric):
if metric == "mse":
ylabel = "MSE"
ylabel = "MSE[K$^2$]"
elif metric == "acc":
ylabel = "ACC"
elif metric == "ssim":
@@ -220,7 +227,8 @@ class MetaPostprocess(object):
for i in range(len(self.metric_values)): #loop number of test samples
assert len(self.metric_values[0])==2
score_plot = np.nanquantile(self.metric_values[i][1], 0.5, axis = 0)
print("score_plot",len(score_plot))
print("self.n_leadtime",self.n_leadtime)
assert len(score_plot) == self.n_leadtime
plt.plot(np.arange(1, 1 + self.n_leadtime), list(score_plot),label = self.labels[i], color = self.colors[i],
marker = self.markers[i], markeredgecolor = 'k', linewidth = 1.2)
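The aggregation above can be illustrated with a small sketch; the array shape (bootstrap samples, lead times) is an assumption consistent with the transpose checks in get_one_metric_values:

import numpy as np

metric_boot = np.random.rand(1000, 12)  # hypothetical bootstrapped metric values
score_plot = np.nanquantile(metric_boot, 0.5, axis=0)  # median over the bootstrap axis
assert score_plot.shape == (12,)  # one value per lead time, matching n_leadtime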
@@ -240,11 +248,12 @@ class MetaPostprocess(object):
plt.yticks(fontsize = 16)
plt.xticks(np.arange(1, self.n_leadtime+1), np.arange(1, self.n_leadtime + 1, 1), fontsize = 16)
legend = ax.legend(loc = 'upper right', bbox_to_anchor = (1.46, 0.95),
fontsize = 14) # 'upper right', bbox_to_anchor=(1.38, 0.8),
legend = ax.legend(loc = 'upper right', bbox_to_anchor = (0.92, 0.40),
fontsize = 12) # 'upper right', bbox_to_anchor=(1.38, 0.8),
ylabel = MetaPostprocess.map_ylabels(self.metric)
ax.set_xlabel("Lead time (hours)", fontsize = 21)
ax.set_ylabel(ylabel, fontsize = 21)
plt.title("Sensitivity analysis for domain sizes",fontsize=16)
fig_path = os.path.join(self.analysis_dir, self.metric + "_abs_values.png")
# fig_path = os.path.join(prefix,fig_name)
plt.savefig(fig_path, bbox_inches = "tight")
@@ -293,10 +302,11 @@ def main():
parser.add_argument("--exp_id", help="The experiment id which will be used as postfix of the output directory",default="exp1")
parser.add_argument("--enable_skill_scores", help="compared by skill scores or the absolute evaluation values",default=False)
parser.add_argument("--enable_persit_plot", help="If plot persistent foreasts",default=False)
parser.add_argument("--metrics_filename", help="The .nc file contain the evaluation metrics",default="evaluation_metrics.nc")
args = parser.parse_args()
meta = MetaPostprocess(root_dir=args.root_dir,analysis_config=args.analysis_config, metric=args.metric, exp_id=args.exp_id,
enable_skill_scores=args.enable_skill_scores,enable_persit_plot=args.enable_persit_plot)
enable_skill_scores=args.enable_skill_scores,enable_persit_plot=args.enable_persit_plot, metrics_filename=args.metrics_filename)
meta()
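Assuming the module is run directly as a script (the script name and the analysis config name here are assumptions), the new option can be supplied on the command line, e.g.:

python meta_postprocess.py --root_dir /p/project/deepacf/deeprain/video_prediction_shared_folder/ --analysis_config meta_config.json --metric mse --exp_id exp1 --metrics_filename evaluation_metrics_72x44.nc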