Commit 50e26227 authored by masak1112

fix the bug for the assertion in main_meta_postprocess

parent f32446fb
Pipeline #102134 passed
@@ -62,6 +62,7 @@ class MetaPostprocess(object):
         self.load_analysis_config()
         self.get_metrics_values()
         if self.enable_skill_scores:
+            print("Enable the skill scores")
             self.calculate_skill_scores()
             self.plot_skill_scores()
         else:
@@ -80,7 +81,7 @@ class MetaPostprocess(object):
         Function to create the analysis directory if it does not exist
         """
         if not os.path.exists(self.analysis_dir): os.makedirs(self.analysis_dir)
-        print("1. Create analysis dir successfully: The result will be stored to the folder:", self.analysis_dir)
+        print("Create analysis dir successfully: The result will be stored to the folder:", self.analysis_dir)

     def copy_analysis_config(self):
         """
@@ -89,7 +90,7 @@ class MetaPostprocess(object):
         try:
             shutil.copy(self.analysis_config, os.path.join(self.analysis_dir, "meta_config.json"))
             self.analysis_config = os.path.join(self.analysis_dir, "meta_config.json")
-            print("2. Copy analysis config successs ")
+            print("Copy analysis config successs ")
         except Exception as e:
             print("The meta_config.json is not found in the dictory: ", self.analysis_config)
         return None
@@ -104,7 +105,7 @@ class MetaPostprocess(object):
         print("*****The following results will be compared and ploted*****")
         [print(i) for i in self.f["results"].values()]
         print("*******************************************************")
-        print("3. Loading analysis config success")
+        print("Loading analysis config success")
         return None
@@ -133,7 +134,7 @@ class MetaPostprocess(object):
         for i, result_dir in enumerate(self.f["results"].values()):
             vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i],self.enable_skill_scores)
             self.metric_values.append(vals)
-        print("4. Get metrics values success")
+        print(" Get metrics values success")
         return self.metric_values

     @staticmethod
@@ -184,7 +185,8 @@ class MetaPostprocess(object):
         return None

     def get_lead_time_labels(self):
-        assert len(self.metric_values) == 2
+        assert len(self.metric_values[0]) == 2
         leadtimes = np.array(self.metric_values[0][1]).shape[1]
         leadtimelist = ["leadhour" + str(i + 1) for i in range(leadtimes)]
         return leadtimelist
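The assertion fix above lines up with how self.metric_values is filled in the get_metrics_values hunk: the list holds one entry per result directory, so its length equals the number of compared experiments and is not necessarily 2, while each entry itself is expected to have exactly two elements (the second one, indexed as [0][1] on the next line, carrying the per-lead-time scores). A minimal standalone sketch of the layout this seems to assume; the pair structure, shapes, and values are illustrative, not taken from the repository:

import numpy as np

# Hypothetical layout: one entry per result directory, each entry a
# (timestamps, scores) pair with scores of shape (n_samples, n_leadtimes).
metric_values = [
    (["t0", "t1"], np.random.rand(2, 12)),  # result dir 1
    (["t0", "t1"], np.random.rand(2, 12)),  # result dir 2
    (["t0", "t1"], np.random.rand(2, 12)),  # result dir 3 -> len(metric_values) != 2
]

assert len(metric_values[0]) == 2           # each entry is a two-element pair
leadtimes = np.array(metric_values[0][1]).shape[1]
leadtimelist = ["leadhour" + str(i + 1) for i in range(leadtimes)]
print(leadtimelist[:3])                     # ['leadhour1', 'leadhour2', 'leadhour3']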
@@ -216,7 +218,7 @@ class MetaPostprocess(object):
         fig = plt.figure(figsize = (8, 6))
         ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
         for i in range(len(self.metric_values)): #loop number of test samples
-            assert len(self.metric_values)==2
+            assert len(self.metric_values[0])==2
             score_plot = np.nanquantile(self.metric_values[i][1], 0.5, axis = 0)
             assert len(score_plot) == self.n_leadtime
...
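For context on the plotting hunk: np.nanquantile(values, 0.5, axis=0) reduces over the sample axis and returns one median per lead time while ignoring NaN entries, which is why the following assertion compares its length against self.n_leadtime. A small self-contained example with made-up numbers:

import numpy as np

scores = np.array([[0.2, 0.4, np.nan],
                   [0.6, 0.8, 0.5]])        # shape (n_samples=2, n_leadtimes=3)

score_plot = np.nanquantile(scores, 0.5, axis=0)
print(score_plot)                           # ~[0.4 0.6 0.5], one median per lead time
print(len(score_plot))                      # 3, i.e. the number of lead times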