diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index 6557c454284c47614db96ad8fdf0f12d4f3ba018..1a8f1a704de274e2ecc821dabca1d09155b62303 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -16,12 +16,8 @@ import tensorflow as tf
 import pickle
 import datetime as dt
 import json
-import matplotlib
 from typing import Union, List
-
-matplotlib.use('Agg')
-import matplotlib.pyplot as plt
-from mpl_toolkits.basemap import Basemap
+# own modules
 from normalization import Norm_data
 from general_utils import check_dir
 from metadata import MetaData as MetaData
@@ -29,6 +25,7 @@ from main_scripts.main_train_models import *
 from data_preprocess.preprocess_data_step2 import *
 from model_modules.video_prediction import datasets, models, metrics
 from statistical_evaluation import perform_block_bootstrap_metric, avg_metrics, Scores
+from plotting import plot_avg_eval_metrics, create_plot
 
 
 class Postprocess(TrainModel):
@@ -726,8 +723,8 @@ class Postprocess(TrainModel):
         Postprocess.save_ds_to_netcdf(self.eval_metrics_ds, nc_fname)
 
         # also save averaged metrics to JSON-file and plot it for diagnosis
-        _ = Postprocess.plot_avg_eval_metrics(self.eval_metrics_ds, self.eval_metrics, self.fcst_products,
-                                              self.vars_in[self.channel], self.results_dir)
+        _ = plot_avg_eval_metrics(self.eval_metrics_ds, self.eval_metrics, self.fcst_products,
+                                  self.vars_in[self.channel], self.results_dir)
 
     # auxiliary methods (not necessarily bound to class instance)
     @staticmethod
@@ -1053,7 +1050,7 @@ class Postprocess(TrainModel):
                                           .format(varname, date_init.strftime("%Y%m%dT%H00"), metric,
                                                   int(quantiles[i]*100.)))
 
-            Postprocess.create_plot(data_fcst, data_diff, varname, plt_fname_base)
+            create_plot(data_fcst, data_diff, varname, plt_fname_base)
 
     @staticmethod
     def init_metric_ds(fcst_products, eval_metrics, varname, nsamples, nlead_steps):
@@ -1090,151 +1087,6 @@ class Postprocess(TrainModel):
 
         return indexes
 
-    @staticmethod
-    def plot_avg_eval_metrics(eval_ds, eval_metrics, fcst_prod_dict, varname, out_dir):
-        """
-        Plots error-metrics averaged over all predictions to file incl. 90%-confidence interval that is estimated by
-        block bootstrapping.
-        :param eval_ds: The dataset storing all evaluation metrics for each forecast (produced by init_metric_ds-method)
-        :param eval_metrics: list of evaluation metrics
-        :param fcst_prod_dict: dictionary of forecast products, e.g. {"persistence": "pfcst"}
-        :param varname: the variable name for which the evaluation metrics are available
-        :param out_dir: output directory to save the lots
-        :return: a bunch of plots as png-files
-        """
-        method = Postprocess.plot_avg_eval_metrics.__name__
-
-        # settings for block bootstrapping
-        # sanity checks
-        if not isinstance(eval_ds, xr.Dataset):
-            raise ValueError("%{0}: Argument 'eval_ds' must be a xarray dataset.".format(method))
-
-        if not isinstance(fcst_prod_dict, dict):
-            raise ValueError("%{0}: Argument 'fcst_prod_dict' must be dictionary with short names of forecast product" +
-                             "as key and long names as value.".format(method))
-
-        try:
-            nhours = np.shape(eval_ds.coords["fcst_hour"])[0]
-        except Exception as err:
-            print("%{0}: Input argument 'eval_ds' appears to be unproper.".format(method))
-            raise err
-
-        nmodels = len(fcst_prod_dict.values())
-        colors = ["blue", "red", "black", "grey"]
-        for metric in eval_metrics:
-            # create a new figure object
-            fig = plt.figure(figsize=(6, 4))
-            ax = plt.axes([0.1, 0.15, 0.75, 0.75])
-            hours = np.arange(1, nhours+1)
-
-            for ifcst, fcst_prod in enumerate(fcst_prod_dict.keys()):
-                metric_name = "{0}_{1}_{2}".format(varname, fcst_prod, metric)
-                try:
-                    metric2plt = eval_ds[metric_name+"_avg"]
-                    metric_boot = eval_ds[metric_name+"_bootstrapped"]
-                except Exception as err:
-                    print("%{0}: Could not retrieve {1} and/or {2} from evaluation metric dataset."
-                          .format(method, metric_name, metric_name+"_boot"))
-                    raise err
-                # plot the data
-                metric2plt_min = metric_boot.quantile(0.05, dim="iboot")
-                metric2plt_max = metric_boot.quantile(0.95, dim="iboot")
-                plt.plot(hours, metric2plt, label=fcst_prod, color=colors[ifcst], marker="o")
-                plt.fill_between(hours, metric2plt_min, metric2plt_max, facecolor=colors[ifcst], alpha=0.3)
-            # configure plot
-            plt.xticks(hours)
-            # automatic y-limits for PSNR wich can be negative and positive
-            if metric != "psnr": ax.set_ylim(0., None)
-            legend = ax.legend(loc="upper right", bbox_to_anchor=(1.15, 1))
-            ax.set_xlabel("Lead time [hours]")
-            ax.set_ylabel(metric.upper())
-            plt_fname = os.path.join(out_dir, "evaluation_{0}".format(metric))
-            print("Saving basic evaluation plot in terms of {1} to '{2}'".format(method, metric, plt_fname))
-            plt.savefig(plt_fname)
-
-        plt.close()
-
-        return True
-
-    @staticmethod
-    def create_plot(data, data_diff, varname, plt_fname):
-        """
-        Creates filled contour plot of forecast data and also draws contours for differences.
-        ML: So far, only plotting of the 2m temperature is supported (with 12 predicted hours/frames)
-        :param data: the forecasted data array to be plotted
-        :param data_diff: the reference data ('ground truth')
-        :param varname: the name of the variable
-        :param plt_fname: the filename to the store the plot
-        :return: -
-        """
-        method = Postprocess.create_plot.__name__
-
-        try:
-            coords = data.coords
-            # handle coordinates and forecast times
-            lat, lon = coords["lat"], coords["lon"]
-            date0 = pd.to_datetime(coords["init_time"].data)
-            fhhs = coords["fcst_hour"].data
-        except Exception as err:
-            print("%{0}: Could not retrieve expected coordinates lat, lon and time_forecast from data.".format(method))
-            raise err
-
-        lons, lats = np.meshgrid(lon, lat)
-
-        date0_str = date0.strftime("%Y-%m-%d %H:%M UTC")
-
-        # check data to be plotted since programme is not generic so far
-        if np.shape(fhhs)[0] != 12:
-            raise ValueError("%{0}: Currently, only 12 hour forecast can be handled properly.".format(method))
-
-        if varname != "2t":
-            raise ValueError("%{0}: Currently, only 2m temperature is plotted nicely properly.".format(method))
-
-        # define levels
-        clevs = np.arange(-10., 40., 1.)
-        clevs_diff = np.arange(0.5, 10.5, 2.)
-        clevs_diff2 = np.arange(-10.5, -0.5, 2.)
-
-        # create fig and subplot axes
-        fig, axes = plt.subplots(2, 6, sharex=True, sharey=True, figsize=(12, 6))
-        axes = axes.flatten()
-
-        # create all subplots
-        for t, fhh in enumerate(fhhs):
-            m = Basemap(projection='cyl', llcrnrlat=np.min(lat), urcrnrlat=np.max(lat),
-                        llcrnrlon=np.min(lon), urcrnrlon=np.max(lon), resolution='l', ax=axes[t])
-            m.drawcoastlines()
-            x, y = m(lons, lats)
-            if t%6 == 0:
-                lat_lab = [1, 0, 0, 0]
-                axes[t].set_ylabel(u'Latitude', labelpad=30)
-            else:
-                lat_lab = list(np.zeros(4))
-            if t/6 >= 1:
-                lon_lab = [0, 0, 0, 1]
-                axes[t].set_xlabel(u'Longitude', labelpad=15)
-            else:
-                lon_lab = list(np.zeros(4))
-            m.drawmapboundary()
-            m.drawparallels(np.arange(0, 90, 5),labels=lat_lab, xoffset=1.)
-            m.drawmeridians(np.arange(5, 355, 10),labels=lon_lab, yoffset=1.)
-            cs = m.contourf(x, y, data.isel(fcst_hour=t)-273.15, clevs, cmap=plt.get_cmap("jet"), ax=axes[t],
-                            extend="both")
-            cs_c_pos = m.contour(x, y, data_diff.isel(fcst_hour=t), clevs_diff, linewidths=0.5, ax=axes[t],
-                                 colors="black")
-            cs_c_neg = m.contour(x, y, data_diff.isel(fcst_hour=t), clevs_diff2, linewidths=1, linestyles="dotted",
-                                 ax=axes[t], colors="black")
-            axes[t].set_title("{0} +{1:02d}:00".format(date0_str, int(fhh)), fontsize=7.5, pad=4)
-
-        fig.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=-0.7,
-                            wspace=0.05)
-        # add colorbar.
-        cbar_ax = fig.add_axes([0.3, 0.22, 0.4, 0.02])
-        cbar = fig.colorbar(cs, cax=cbar_ax, orientation="horizontal")
-        cbar.set_label('°C')
-        # save to disk
-        plt.savefig(plt_fname, bbox_inches="tight")
-
 
 def main():
     parser = argparse.ArgumentParser()
diff --git a/video_prediction_tools/postprocess/plotting.py b/video_prediction_tools/postprocess/plotting.py
index 67e68171b87ceb181d32314a271f8493b9c0c460..3f19d0c3db7df021b7b5214daa6be1bdba9267f8 100644
--- a/video_prediction_tools/postprocess/plotting.py
+++ b/video_prediction_tools/postprocess/plotting.py
@@ -83,7 +83,6 @@ def plot_avg_eval_metrics(eval_ds, eval_metrics, fcst_prod_dict, varname, out_di
     return True
 
 
-
 def create_plot(data, data_diff, varname, plt_fname):
     """
     Creates filled contour plot of forecast data and also draws contours for differences.