diff --git a/video_prediction_tools/postprocess/postprocess_plotting.py b/video_prediction_tools/postprocess/postprocess_plotting.py index bd21c535f7eadccf8f0998a46dbf7955e5145bb5..84b880f7f665acd549d8459376af4a1e3a9ac7ec 100644 --- a/video_prediction_tools/postprocess/postprocess_plotting.py +++ b/video_prediction_tools/postprocess/postprocess_plotting.py @@ -18,7 +18,7 @@ from mpl_toolkits.basemap import Basemap from general_utils import provide_default -def plot_cond_quantile(quantile_panel: xr.DataArray, data_marginal: xr.DataArray, plt_fname: str, opt: dict): +def plot_cond_quantile(quantile_panel: xr.DataArray, data_marginal: xr.DataArray, plt_fname: str, opt: dict = None): """ Creates conditional quantile plot :param quantile_panel: quantile panel created by calculate_cond_quantiles @@ -39,6 +39,9 @@ def plot_cond_quantile(quantile_panel: xr.DataArray, data_marginal: xr.DataArray if list(quantile_panel.coords) != ["bin_center", "quantile"]: raise ValueError("%{0}: The coordinates of quantile_panel must be ['bin_center', 'quantile']".format(method)) + if opt is None: + opt = {} + bins_c = quantile_panel["bin_center"] bin_width = bins_c[1] - bins_c[0] bins = np.arange(bins_c[0]-bin_width/2., bins_c+1.5*bin_width/2, bin_width) @@ -53,6 +56,8 @@ def plot_cond_quantile(quantile_panel: xr.DataArray, data_marginal: xr.DataArray # start plotting figsize = provide_default(opt, "figsize", (12, 6)) + fs_title = provide_default(opt, "fs_axis_title", 16) + fs_label = provide_default(opt, "fs_axis_label", fs_title-2) fig, ax = plt.subplots(figsize=figsize) # plot reference line @@ -71,12 +76,12 @@ def plot_cond_quantile(quantile_panel: xr.DataArray, data_marginal: xr.DataArray xlabel = "{0} [{1}]".format(provide_default(quantile_panel.attr, "data_cond_longname", "target variable"), provide_default(quantile_panel.attr, "data_cond_unit", "unknown")) - ax.set_ylabel(ylabel, fontsize=16) - ax2.set_ylabel("counts", fontsize=16) - ax.set_xlabel(xlabel, fontsize=16) + ax.set_ylabel(ylabel, fontsize=fs_title) + ax2.set_ylabel("counts", fontsize=fs_title) + ax.set_xlabel(xlabel, fontsize=fs_title) - ax.tick_params(axis="both", labelsize=14) - ax2.tick_params(axis="both", labelsize=14) + ax.tick_params(axis="both", labelsize=fs_label) + ax2.tick_params(axis="both", labelsize=fs_label) fig.savefig(plt_fname) plt.close("all")