diff --git a/video_prediction_tools/main_scripts/main_meta_postprocess.py b/video_prediction_tools/main_scripts/main_meta_postprocess.py index aa07ed65849efd7aea6b31d968124e1c1fbc5b46..413c8187ced5b0c85e430c2721e98de7b7812798 100644 --- a/video_prediction_tools/main_scripts/main_meta_postprocess.py +++ b/video_prediction_tools/main_scripts/main_meta_postprocess.py @@ -62,6 +62,7 @@ class MetaPostprocess(object): self.load_analysis_config() self.get_metrics_values() if self.enable_skill_scores: + print("Enable the skill scores") self.calculate_skill_scores() self.plot_skill_scores() else: @@ -80,7 +81,7 @@ class MetaPostprocess(object): Function to create the analysis directory if it does not exist """ if not os.path.exists(self.analysis_dir): os.makedirs(self.analysis_dir) - print("1. Create analysis dir successfully: The result will be stored to the folder:", self.analysis_dir) + print("Create analysis dir successfully: The result will be stored to the folder:", self.analysis_dir) def copy_analysis_config(self): """ @@ -89,7 +90,7 @@ class MetaPostprocess(object): try: shutil.copy(self.analysis_config, os.path.join(self.analysis_dir, "meta_config.json")) self.analysis_config = os.path.join(self.analysis_dir, "meta_config.json") - print("2. Copy analysis config successs ") + print("Copy analysis config successs ") except Exception as e: print("The meta_config.json is not found in the dictory: ", self.analysis_config) return None @@ -104,7 +105,7 @@ class MetaPostprocess(object): print("*****The following results will be compared and ploted*****") [print(i) for i in self.f["results"].values()] print("*******************************************************") - print("3. Loading analysis config success") + print("Loading analysis config success") return None @@ -133,7 +134,7 @@ class MetaPostprocess(object): for i, result_dir in enumerate(self.f["results"].values()): vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i],self.enable_skill_scores) self.metric_values.append(vals) - print("4. Get metrics values success") + print(" Get metrics values success") return self.metric_values @staticmethod @@ -184,7 +185,8 @@ class MetaPostprocess(object): return None def get_lead_time_labels(self): - assert len(self.metric_values) == 2 + assert len(self.metric_values[0]) == 2 + leadtimes = np.array(self.metric_values[0][1]).shape[1] leadtimelist = ["leadhour" + str(i + 1) for i in range(leadtimes)] return leadtimelist @@ -216,7 +218,7 @@ class MetaPostprocess(object): fig = plt.figure(figsize = (8, 6)) ax = fig.add_axes([0.1, 0.1, 0.8, 0.8]) for i in range(len(self.metric_values)): #loop number of test samples - assert len(self.metric_values)==2 + assert len(self.metric_values[0])==2 score_plot = np.nanquantile(self.metric_values[i][1], 0.5, axis = 0) assert len(score_plot) == self.n_leadtime