diff --git a/video_prediction_tools/postprocess/postprocess_plotting.py b/video_prediction_tools/postprocess/postprocess_plotting.py index cde8f97e8d6756eae3367c7fa26294d46094ce7e..a6aaddd665eea21b07ff284efdcf8042c2dfbaba 100644 --- a/video_prediction_tools/postprocess/postprocess_plotting.py +++ b/video_prediction_tools/postprocess/postprocess_plotting.py @@ -60,6 +60,7 @@ def plot_cond_quantile(quantile_panel: xr.DataArray, data_marginal: xr.DataArray figsize = provide_default(opt, "figsize", (12, 6)) fs_title = provide_default(opt, "fs_axis_title", 16) fs_label = provide_default(opt, "fs_axis_label", fs_title-2) + plt_title = provide_default(opt, "plt_title", "") fig, ax = plt.subplots(figsize=figsize) # plot reference line @@ -73,14 +74,18 @@ def plot_cond_quantile(quantile_panel: xr.DataArray, data_marginal: xr.DataArray xr.plot.hist(data_marginal, ax=ax2, bins=bins, color="k", alpha=0.3) ax2.set_yscale("log") - ylabel = "{0} [{1}]".format(provide_default(quantile_panel.attrs, "data_cond_longname", "conditiong variable"), - provide_default(quantile_panel.attrs, "data_cond_unit", "unknown")) - xlabel = "{0} [{1}]".format(provide_default(quantile_panel.attrs, "data_cond_longname", "target variable"), - provide_default(quantile_panel.attrs, "data_cond_unit", "unknown")) + ylabel = "{0} [{1}]".format(provide_default(quantile_panel.attrs, "cond_var_name", "conditiong variable"), + provide_default(quantile_panel.attrs, "cond_var_unit", "unknown")) + xlabel = "{0} [{1}]".format(provide_default(quantile_panel.attrs, "tar_var_name", "target variable"), + provide_default(quantile_panel.attrs, "tar_var_unit", "unknown")) ax.set_ylabel(ylabel, fontsize=fs_title) ax2.set_ylabel("counts", fontsize=fs_title) ax.set_xlabel(xlabel, fontsize=fs_title) + # ensure that histogram extends to the lower half of the plot + y2_max_power = int(np.log10(ax2.get_ylim()[1])) + ax2.set(ylim=(1.e00, np.power(10, y2_max_power*4)), yticks=np.logspace(0, y2_max_power+1, y2_max_power+2)) + ax2.set_title(plt_title) ax.tick_params(axis="both", labelsize=fs_label) ax2.tick_params(axis="both", labelsize=fs_label)