diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index 370a6009586dfc62f5201c6e422bb88764ff0d12..1a8e75e194bafa6aad3adec42cc35edc466072c7 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -18,14 +18,15 @@
 import datetime as dt
 import json
 from typing import Union, List
 # own modules
+from general_utils import get_era5_varatts
 from normalization import Norm_data
 from general_utils import check_dir
 from metadata import MetaData as MetaData
 from main_scripts.main_train_models import *
 from data_preprocess.preprocess_data_step2 import *
 from model_modules.video_prediction import datasets, models, metrics
-from statistical_evaluation import perform_block_bootstrap_metric, avg_metrics, Scores
-from postprocess_plotting import plot_avg_eval_metrics, create_geo_contour_plot
+from statistical_evaluation import perform_block_bootstrap_metric, avg_metrics, calculate_cond_quantiles, Scores
+from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, create_geo_contour_plot
 
 class Postprocess(TrainModel):
@@ -726,6 +727,105 @@ class Postprocess(TrainModel):
         _ = plot_avg_eval_metrics(self.eval_metrics_ds, self.eval_metrics, self.fcst_products,
                                   self.vars_in[self.channel], self.results_dir)
 
+    def plot_example_forecasts(self, metric="mse", channel=0):
+        """
+        Plots example forecasts. The forecasts are chosen from the complete pool of the test dataset
+        according to their accuracy in terms of the chosen metric. In addition to the best and worst forecasts,
+        every decile of the chosen metric is retrieved to cover the whole bandwidth of forecasts.
+        :param metric: The metric which is used for measuring accuracy
+        :param channel: The channel index of the forecasted variable to plot (corresponding to self.vars_in)
+        :return: 11 exemplary forecast plots are created
+        """
+        method = Postprocess.plot_example_forecasts.__name__
+
+        metric_name = "{0}_{1}_{2}".format(self.vars_in[channel], self.model, metric)
+        if metric_name not in self.eval_metrics_ds:
+            raise ValueError("%{0}: Cannot find requested evaluation metric '{1}'".format(method, metric_name) +
+                             " which is needed to select the plotted forecasts.")
+        # average metric of interest and obtain quantiles incl. indices
+        metric_mean = self.eval_metrics_ds[metric_name].mean(dim="fcst_hour")
+        quantiles = np.arange(0., 1.01, .1)
+        quantiles_val = metric_mean.quantile(quantiles, interpolation="nearest")
+        quantiles_inds = self.get_matching_indices(metric_mean.values, quantiles_val)
+
+        for i, ifcst in enumerate(quantiles_inds):
+            date_init = pd.to_datetime(metric_mean.coords["init_time"][ifcst].data)
+            nc_fname = os.path.join(self.results_dir, "vfp_date_{0}_sample_ind_{1:d}.nc"
+                                    .format(date_init.strftime("%Y%m%d%H"), ifcst))
+            if not os.path.isfile(nc_fname):
+                raise FileNotFoundError("%{0}: Could not find requested file '{1}'".format(method, nc_fname))
+            else:
+                # get the data
+                varname = self.vars_in[channel]
+                with xr.open_dataset(nc_fname) as dfile:
+                    data_fcst = dfile["{0}_{1}_fcst".format(varname, self.model)]
+                    data_ref = dfile["{0}_ref".format(varname)]
+
+                data_diff = data_fcst - data_ref
+                # name of plot
+                plt_fname_base = os.path.join(self.output_dir, "forecast_{0}_{1}_{2}_{3:d}percentile.png"
+                                              .format(varname, date_init.strftime("%Y%m%dT%H00"), metric,
+                                                      int(quantiles[i] * 100.)))
+
+                create_geo_contour_plot(data_fcst, data_diff, varname, plt_fname_base)
+
+    def plot_conditional_quantiles(self):
+        """Creates conditional quantile plots for both factorizations of the forecast-observation distribution."""
+        # release some memory
+        Postprocess.clean_obj_attribute(self, "eval_metrics_ds")
+
+        # variable names of the forecast and reference data in the conditional quantile dataset
+        cond_quantile_vars = ["{0}_{1}_fcst".format(self.vars_in[self.channel], self.model),
+                              "{0}_ref".format(self.vars_in[self.channel])]
+        var_fcst = cond_quantile_vars[0]
+        var_ref = cond_quantile_vars[1]
+
+        get_era5_varatts(self.cond_quantiple_ds[var_fcst], self.cond_quantiple_ds[var_fcst].name)
+        get_era5_varatts(self.cond_quantiple_ds[var_ref], self.cond_quantiple_ds[var_ref].name)
+
+        # create plots
+        # calibration refinement factorization
+        plt_fname_cf = os.path.join(self.results_dir, "cond_quantile_{0}_{1}_calibration_refinement.png"
+                                    .format(self.vars_in[self.channel], self.model))
+
+        quantile_panel_cf, cond_variable_cf = calculate_cond_quantiles(self.cond_quantiple_ds[var_fcst],
+                                                                       self.cond_quantiple_ds[var_ref],
+                                                                       factorization="calibration_refinement",
+                                                                       quantiles=(0.05, 0.5, 0.95))
+
+        plot_cond_quantile(quantile_panel_cf, cond_variable_cf, plt_fname_cf)
+
+        # likelihood-base rate factorization
+        plt_fname_lbr = plt_fname_cf.replace("calibration_refinement", "likelihood-base_rate")
+        quantile_panel_lbr, cond_variable_lbr = calculate_cond_quantiles(self.cond_quantiple_ds[var_fcst],
+                                                                         self.cond_quantiple_ds[var_ref],
+                                                                         factorization="likelihood-base_rate",
+                                                                         quantiles=(0.05, 0.5, 0.95))
+
+        plot_cond_quantile(quantile_panel_lbr, cond_variable_lbr, plt_fname_lbr)
+
+
+    @staticmethod
+    def clean_obj_attribute(obj, attr_name, lremove=False):
+        """
+        Cleans an object's attribute by removing it or setting it to None (can be used to release memory)
+        :param obj: the object/class instance
+        :param attr_name: the attribute of the object to be cleaned
+        :param lremove: flag whether the attribute is removed (True) or set to None (False)
+        :return: the object/class instance with the attribute removed or set to None
+        """
+        method = Postprocess.clean_obj_attribute.__name__
+
+        if not hasattr(obj, attr_name):
+            print("%{0}: Class attribute '{1}' does not exist. Nothing to do...".format(method, attr_name))
+        else:
+            if lremove:
+                delattr(obj, attr_name)
+            else:
+                setattr(obj, attr_name, None)
+
+        return obj
+
     # auxiliary methods (not necessarily bound to class instance)
     @staticmethod
     def get_norm(varnames, stat_fl, norm_method):
@@ -1010,48 +1110,6 @@ class Postprocess(TrainModel):
 
         return ds_preexist
 
-    def plot_example_forecasts(self, metric="mse", channel=0):
-        """
-        Plots example forecasts. The forecasts are chosen from the complete pool of the test dataset and are chosen
-        according to the accuracy in terms of the chosen metric. In add ition, to the best and worst forecast,
-        every decil of the chosen metric is retrieved to cover the whole bandwith of forecasts.
-        :param metric: The metric which is used for measuring accuracy
-        :param channel: The channel index of the forecasted variable to plot (correspondong to self.vars_in)
-        :return: 11 exemplary forecast plots are created
-        """
-        method = Postprocess.plot_example_forecasts.__name__
-
-        metric_name = "{0}_{1}_{2}".format(self.vars_in[channel], self.model, metric)
-        if not metric_name in self.eval_metrics_ds:
-            raise ValueError("%{0}: Cannot find requested evaluation metric '{1}'".format(method, metric_name) +
-                             " onto which selection of plotted forecast is done.")
-        # average metric of interest and obtain quantiles incl. indices
-        metric_mean = self.eval_metrics_ds[metric_name].mean(dim="fcst_hour")
-        quantiles = np.arange(0., 1.01, .1)
-        quantiles_val = metric_mean.quantile(quantiles, interpolation="nearest")
-        quantiles_inds = self.get_matching_indices(metric_mean.values, quantiles_val)
-
-        for i, ifcst in enumerate(quantiles_inds):
-            date_init = pd.to_datetime(metric_mean.coords["init_time"][ifcst].data)
-            nc_fname = os.path.join(self.results_dir, "vfp_date_{0}_sample_ind_{1:d}.nc"
-                                    .format(date_init.strftime("%Y%m%d%H"), ifcst))
-            if not os.path.isfile(nc_fname):
-                raise FileNotFoundError("%{0}: Could not find requested file '{1}'".format(method, nc_fname))
-            else:
-                # get the data
-                varname = self.vars_in[channel]
-                with xr.open_dataset(nc_fname) as dfile:
-                    data_fcst = dfile["{0}_{1}_fcst".format(varname, self.model)]
-                    data_ref = dfile["{0}_ref".format(varname)]
-
-                data_diff = data_fcst - data_ref
-                # name of plot
-                plt_fname_base = os.path.join(self.output_dir, "forecast_{0}_{1}_{2}_{3:d}percentile.png"
-                                              .format(varname, date_init.strftime("%Y%m%dT%H00"), metric,
-                                                      int(quantiles[i]*100.)))
-
-                create_geo_contour_plot(data_fcst, data_diff, varname, plt_fname_base)
-
     @staticmethod
     def init_metric_ds(fcst_products, eval_metrics, varname, nsamples, nlead_steps):
         """
@@ -1102,6 +1160,8 @@ def main():
     parser.add_argument("--seed", type=int, default=7)
     parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+", default=["mse", "psnr"],
                         help="Metrics used to evaluate the trained model. Must be known metrics, see Scores-class.")
+    parser.add_argument("--channel", "-channel", dest="channel", type=int, default=0,
+                        help="Channel which is used for evaluation.")
     args = parser.parse_args()
 
     print('----------------------------------- Options ------------------------------------')
@@ -1113,11 +1173,11 @@ def main():
     postproc_instance = Postprocess(results_dir=args.results_dir, checkpoint=args.checkpoint, mode="test",
                                     batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples,
                                     gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args,
-                                    eval_metrics=args.eval_metrics)
+                                    eval_metrics=args.eval_metrics, channel=args.channel)
     # run the postprocessing
     postproc_instance.run()
     postproc_instance.handle_eval_metrics()
-    postproc_instance.plot_example_forecasts(metric=args.eval_metrics[0])
+    postproc_instance.plot_example_forecasts(metric=args.eval_metrics[0], channel=args.channel)
 
 
 if __name__ == '__main__':
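
The decile-based selection in plot_example_forecasts pins each quantile to a metric value that actually occurs in the test data (interpolation="nearest", spelled method="nearest" in current xarray) and then maps it back to a sample index. Since get_matching_indices is not part of this diff, the following is a minimal, self-contained sketch of its assumed behaviour on synthetic data:

    import numpy as np
    import xarray as xr

    # Synthetic stand-in for the per-sample metric: one MSE value per forecast
    # initialization time, already averaged over lead time (fcst_hour).
    rng = np.random.default_rng(0)
    metric_mean = xr.DataArray(rng.random(100), dims="init_time",
                               coords={"init_time": np.arange(100)})

    # Deciles from the best (0th percentile) to the worst (100th percentile)
    # forecast; "nearest" guarantees that every quantile value occurs in the
    # data and can therefore be traced back to a concrete sample.
    quantiles = np.arange(0., 1.01, .1)
    quantiles_val = metric_mean.quantile(quantiles, method="nearest")

    # Assumed behaviour of Postprocess.get_matching_indices: the index of the
    # sample whose metric value matches each quantile value.
    quantiles_inds = np.array([int(np.argmin(np.abs(metric_mean.values - qv)))
                               for qv in quantiles_val.values])
    print(quantiles_inds)  # 11 sample indices spanning the accuracy range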
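
calculate_cond_quantiles and plot_cond_quantile are project-internal helpers whose implementation is not part of this diff. Conceptually they follow the two factorizations of the joint distribution of forecasts and observations (Murphy and Winkler, 1987): calibration-refinement conditions the observations on binned forecast values, while likelihood-base rate swaps the roles. A rough sketch of the idea on synthetic data (the binning strategy and edge handling of the actual implementation may differ):

    import numpy as np

    def cond_quantiles_sketch(cond_var, target_var, bins, quantiles=(0.05, 0.5, 0.95)):
        # Quantiles of target_var conditioned on binned values of cond_var.
        # Calibration-refinement: cond_var = forecast, target_var = observation.
        # Likelihood-base rate factorization: the roles are swapped.
        inds = np.digitize(cond_var, bins)
        panel = np.full((len(bins) + 1, len(quantiles)), np.nan)
        for b in range(len(bins) + 1):
            members = target_var[inds == b]
            if members.size > 0:  # leave empty bins as NaN
                panel[b, :] = np.quantile(members, quantiles)
        return panel

    rng = np.random.default_rng(1)
    obs = 285. + 5. * rng.standard_normal(10000)  # e.g. 2m temperature in K
    fcst = obs + rng.standard_normal(10000)       # slightly noisy "forecast"
    bins = np.linspace(270., 300., 16)
    print(cond_quantiles_sketch(fcst, obs, bins))  # one quantile triple per bin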