Select Git revision
Data pipeline.ipynb
main_meta_postprocess.py 12.23 KiB
# coding=utf-8
# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
#
# SPDX-License-Identifier: MIT
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__email__ = "b.gong@fz-juelich.de"
__author__ = "Bing Gong, Yan Ji"
__date__ = "2020-12-04"
__updatedate__ = "2022-02-02"
import argparse
import os
from matplotlib.pylab import plt
import json
import numpy as np
import shutil
import glob
from netCDF4 import Dataset
import xarray as xr
def skill_score(tar_score, ref_score, best_score):
    """
    Skill score of a target forecast relative to a reference forecast.

    ss = (tar_score - ref_score) / (best_score - ref_score)

    A value of 1 means the target reaches the best possible score, 0 means it is no better
    than the reference, and negative values mean it is worse than the reference.
    Works element-wise on numpy arrays as well as on scalars.
    """
    improvement = tar_score - ref_score
    possible_improvement = best_score - ref_score
    return improvement / possible_improvement
class MetaPostprocess(object):
    """
    Compare the evaluation metrics of several postprocessed experiments and plot them, either as
    absolute metric values or as skill scores of the DL model relative to persistence.
    """

    # Evaluation metrics supported for the comparison.
    available_metrics = ["mse", "ssim", "texture", "acc"]

    def __init__(self, root_dir: str = "/p/project/deepacf/deeprain/video_prediction_shared_folder/",
                 analysis_config: str = None, metric: str = "mse", exp_id: str = None,
                 enable_skill_scores: bool = False):
        """
        This class is used for calculating the evaluation metric, analyzing the models' results and
        making comparisons.
        args:
            root_dir           :str, the root directory for the shared folder
            analysis_config    :str, the path pointing to the analysis configuration json file
            metric             :str, evaluation metric used for the comparison: "mse", "ssim",
                                "texture" or "acc"
            exp_id             :str, name of the sub-folder <root_dir>/meta_postprocess/<exp_id>
                                in which the configuration copy and the plots are stored
            enable_skill_scores:bool, if True plot skill scores (DL model vs. persistence),
                                otherwise plot the absolute metric values
        raises:
            TypeError: if exp_id is not a str (it is required to build the analysis directory)
        """
        # Validate before os.path.join, which would otherwise fail with an obscure message.
        if not isinstance(exp_id, str):
            raise TypeError("'exp_id' must be 'str' type ")
        self.root_dir = root_dir
        self.analysis_config = analysis_config
        self.analysis_dir = os.path.join(root_dir, "meta_postprocess", exp_id)
        self.metric = metric
        self.exp_id = exp_id
        self.enable_skill_scores = enable_skill_scores
        self.models_type = []    # model type of each result, read from options_checkpoints.json
        self.metric_values = []  # one [persistence_values, model_values] pair per result
        self.skill_scores = []   # one skill-score array per result (filled if enable_skill_scores)

    def __call__(self):
        """Run the whole meta-postprocessing workflow and produce the comparison plot."""
        self.sanity_check()
        self.create_analysis_dir()
        self.copy_analysis_config()
        self.load_analysis_config()
        self.get_metrics_values()
        self.calculate_skill_scores()
        if self.enable_skill_scores:
            self.plot_skill_scores()
        else:
            self.plot_abs_scores()

    def sanity_check(self):
        """
        Validate the chosen metric and the experiment id.

        raises:
            ValueError: if self.metric is not one of available_metrics
            TypeError: if self.exp_id is not a str
        """
        if self.metric not in self.available_metrics:
            raise ValueError("The 'metric' must be one of the following: {}".format(self.available_metrics))
        if not isinstance(self.exp_id, str):
            raise TypeError("'exp_id' must be 'str' type ")

    def create_analysis_dir(self):
        """
        Function to create the analysis directory if it does not exist
        """
        os.makedirs(self.analysis_dir, exist_ok=True)
        print("1. Create analysis dir successfully: The result will be stored to the folder:", self.analysis_dir)

    def copy_analysis_config(self):
        """
        Copy the analysis configuration json to the analysis directory as "meta_config.json" and
        point self.analysis_config at the copy.

        Best effort: if the copy fails, the error is printed and self.analysis_config keeps
        pointing at the original file, which load_analysis_config then tries to read.
        """
        target = os.path.join(self.analysis_dir, "meta_config.json")
        try:
            shutil.copy(self.analysis_config, target)
        except (OSError, TypeError) as e:
            # OSError covers a missing source and shutil.SameFileError (the source already is
            # the target); TypeError covers analysis_config=None.
            print("Copying the meta config failed for: ", self.analysis_config)
            print(e)
            return None
        self.analysis_config = target
        print("2. Copy analysis config successs ")
        return None

    def load_analysis_config(self):
        """
        Read the configuration values from the analysis configuration json file into self.f
        """
        with open(self.analysis_config) as f:
            self.f = json.load(f)
        print("*****The following results will be compared and ploted*****")
        for result_dir in self.f["results"].values():
            print(result_dir)
        print("*******************************************************")
        print("3. Loading analysis config success")
        return None

    def get_labels(self):
        """Return the result names (the keys of the "results" section of the configuration)."""
        return list(self.f["results"].keys())

    def get_meta_info(self):
        """
        Read the model type of every result from the options_checkpoints.json file written by the
        postprocess stage. The list is rebuilt on every call so repeated runs do not duplicate it.
        """
        self.models_type = []
        for result_dir in self.f["results"].values():
            option_checkpoints = os.path.join(result_dir, "options_checkpoints.json")
            with open(option_checkpoints) as f:
                self.models_type.append(json.load(f)["model"])
        return None

    def get_metrics_values(self):
        """
        Get the evaluation metric values of all the results.

        return: list with one [persistence_values, model_values] pair per result
        raises:
            RuntimeError: if the metric values of a result could not be read
        """
        self.get_meta_info()
        self.metric_values = []
        for result_dir, model_type in zip(self.f["results"].values(), self.models_type):
            vals = self.get_one_metric_values(result_dir, self.metric, model_type)
            if vals is None:
                # Fail here with a clear message instead of a later "NoneType" subscript error.
                raise RuntimeError("Could not retrieve the '{}' values from {}".format(self.metric, result_dir))
            self.metric_values.append(vals)
        print("4. Get metrics values success")
        return self.metric_values

    @staticmethod
    def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None):
        """
        Obtain the bootstrapped metric values (persistence and DL model) stored in the
        "evaluation_metrics.nc" file of one result directory.

        args:
            result_dir: directory containing "evaluation_metrics.nc"
            metric    : metric name used in the variable names
            model     : model type used in the DL-model variable name
        return: [persi, model] arrays, or None (after printing the error) if reading fails
        """
        filepath = os.path.join(result_dir, 'evaluation_metrics.nc')
        try:
            with xr.open_dataset(filepath) as dfiles:
                persi_vals = np.array(dfiles['2t_persistence_{}_bootstrapped'.format(metric)][:])
                model_vals = np.array(dfiles['2t_{}_{}_bootstrapped'.format(model, metric)][:])
            print("The values for evaluation metric '{}' values are obtained from file {}".format(metric, filepath))
            return [persi_vals, model_vals]
        except Exception as e:
            # Best effort: report the failure and let the caller decide how to proceed.
            print("!! The evallution metrics retrive from the {} fails".format(filepath))
            print(e)
            return None

    def calculate_skill_scores(self):
        """
        Calculate the skill scores of the DL model relative to persistence for every result.
        The best achievable score is 0 for "mse" and 1 for "ssim", "acc" and "texture".

        return: the list of skill scores if enable_skill_scores is True, otherwise None
        raises:
            ValueError: if no metric values are available or the metric is unknown
        """
        if not self.metric_values:
            raise ValueError("metric_values should be a non-empty list but {} is provided".format(self.metric_values))
        if self.metric == "mse":
            best_score = 0
        elif self.metric in ["ssim", "acc", "texture"]:
            best_score = 1
        else:
            raise ValueError("The metric should be one of the following available metrics: {}".format(
                self.available_metrics))
        if not self.enable_skill_scores:
            return None
        # Rebuild instead of appending so repeated calls do not duplicate entries.
        self.skill_scores = [skill_score(model_vals, persi_vals, best_score)
                             for persi_vals, model_vals in self.metric_values]
        return self.skill_scores

    def get_lead_time_labels(self):
        """Return ["leadhour1", ..., "leadhourN"], N being the size of axis 1 of the metric values."""
        n_leadtime = self.metric_values[0][0].shape[1]
        return ["leadhour" + str(i + 1) for i in range(n_leadtime)]

    def config_plots(self):
        """Collect lead-time labels, result labels, markers and colors needed by the plots."""
        self.leadtimelist = self.get_lead_time_labels()
        self.labels = self.get_labels()
        self.markers = self.f["markers"]
        self.colors = self.f["colors"]
        self.n_leadtime = len(self.leadtimelist)

    @staticmethod
    def map_ylabels(metric):
        """
        Map a metric name to its y-axis label.

        raises:
            ValueError: if the metric is unknown
        """
        ylabels = {
            "mse": "MSE",
            "acc": "ACC",
            "ssim": "SSIM",
            "texture": "Ratio of gradient ($r_G$)",
        }
        if metric not in ylabels:
            raise ValueError("The metric is not correct!")
        return ylabels[metric]

    def _finalize_plot(self, ax, ylabel, fig_name):
        """Apply the shared axis styling, save the figure to the analysis directory and close it."""
        lead_times = np.arange(1, self.n_leadtime + 1)
        ax.legend(loc='upper right', bbox_to_anchor=(1.46, 0.95), fontsize=14)
        plt.yticks(fontsize=16)
        # One tick per available lead time instead of a hard-coded 12 hours.
        plt.xticks(lead_times, lead_times, fontsize=16)
        ax.set_xlabel("Lead time (hours)", fontsize=21)
        ax.set_ylabel(ylabel, fontsize=21)
        fig_path = os.path.join(self.analysis_dir, fig_name)
        plt.savefig(fig_path, bbox_inches="tight")
        plt.show()
        plt.close()
        print("The plot saved to {}".format(fig_path))
        return fig_path

    def plot_abs_scores(self):
        """
        Plot the median (line) and the 5-95% quantile band (shaded) over axis 0 of the absolute
        metric values of every result against lead time. For convLSTM results the persistence
        values are drawn as well. Saved as "<metric>_abs_values.png" in the analysis directory.
        """
        self.config_plots()
        lead_times = np.arange(1, 1 + self.n_leadtime)
        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
        for i, (persi_vals, model_vals) in enumerate(self.metric_values):
            ax.plot(lead_times, np.nanquantile(model_vals, 0.5, axis=0), label=self.labels[i],
                    color=self.colors[i], marker=self.markers[i], markeredgecolor='k', linewidth=1.2)
            ax.fill_between(lead_times, np.nanquantile(model_vals, 0.95, axis=0),
                            np.nanquantile(model_vals, 0.05, axis=0), color=self.colors[i], alpha=0.2)
            if self.models_type[i] == "convLSTM":
                # The persistence baseline is only drawn alongside convLSTM results.
                ax.plot(lead_times, np.nanquantile(persi_vals, 0.5, axis=0), label="Persi_cv" + str(i),
                        color=self.colors[i], marker="D", markeredgecolor='k', linewidth=1.2)
                ax.fill_between(lead_times, np.nanquantile(persi_vals, 0.95, axis=0),
                                np.nanquantile(persi_vals, 0.05, axis=0), color="b", alpha=0.2)
        self._finalize_plot(ax, MetaPostprocess.map_ylabels(self.metric), self.metric + "_abs_values.png")

    def plot_skill_scores(self):
        """
        Plot the skill scores (box plot per lead time plus a median line) of every result once
        enable_skill_scores is True. convLSTM results are drawn in red, savp results in blue.
        Saved as "<metric>_skill_scores.png" in the analysis directory.

        raises:
            ValueError: if a result has a model type other than convLSTM or savp
        """
        self.config_plots()
        model_colors = {"convLSTM": "r", "savp": "b"}
        lead_times = np.arange(1, self.n_leadtime + 1)
        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
        for i, scores in enumerate(self.skill_scores):
            c = model_colors.get(self.models_type[i])
            if c is None:
                raise ValueError("Currently only convLSTM and SAVP are supported for plotting the skill scores")
            ax.boxplot(scores, positions=lead_times, medianprops={'color': c},
                       capprops={'color': c}, boxprops={'color': c}, showfliers=False)
            ax.plot(lead_times, np.nanquantile(scores, 0.5, axis=0), color=c, linewidth=1.2,
                    label=self.labels[i])
        self._finalize_plot(ax, "Skill scores of {}".format(self.metric), self.metric + "_skill_scores.png")
def str_to_bool(value):
    """
    Convert a command-line string such as "True"/"false"/"1"/"no" to a bool.

    Needed because argparse would otherwise pass the raw string, and any non-empty string
    (including "False") is truthy.
    raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean spelling
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got '{}'".format(value))


def main():
    """Parse the command-line arguments and run the meta-postprocessing."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--root_dir", type=str, required=True, help="The root path for output dir")
    parser.add_argument("--analysis_config", type=str,
                        help="The path points to the meta_postprocess configuration file.",
                        default="../meta_postprocess_config/meta_config.json")
    parser.add_argument("--metric", choices=["mse", "ssim", "acc", "texture"],
                        help="Based on which the models are compared, the value should be in one of [mse,ssim,acc,texture]",
                        default="mse")
    parser.add_argument("--exp_id",
                        help="The experiment id which will be used as postfix of the output directory",
                        default="exp1")
    parser.add_argument("--enable_skill_scores", type=str_to_bool,
                        help="compared by skill scores or the absolute evaluation values", default=True)
    args = parser.parse_args()
    meta = MetaPostprocess(root_dir=args.root_dir, analysis_config=args.analysis_config, metric=args.metric,
                           exp_id=args.exp_id, enable_skill_scores=args.enable_skill_scores)
    meta()
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not on import.
    main()