diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index 5c72b47585ccde6017f2f7769f9e09258d3943ca..6014db100c350198ca70b5b89d3a4a0fb6106670 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -634,7 +634,7 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
     name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi.
     """
 
-    def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = ""):
+    def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = "", separate_vars: List[str] = None):
         """
         Sets attributes and create plot
         :param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms.
@@ -642,11 +642,17 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         :param model_setup: architecture type to specify plot name (default "CNN")
+        :param separate_vars: variables shown in their own subplot of the additional "_separated" plot (default ['o3'])
         """
         super().__init__(plot_folder, f"skill_score_bootstrap_{model_setup}")
+        if separate_vars is None:
+            separate_vars = ['o3']
         self._labels = None
         self._x_name = "boot_var"
         self._data = self._prepare_data(data)
         self._plot()
         self._save()
+        self.plot_name += '_separated'
+        self._plot(separate_vars=separate_vars)
+        self._save(bbox_inches='tight')
 
     def _prepare_data(self, data: Dict) -> pd.DataFrame:
         """
@@ -667,9 +673,112 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         """
         return "" if score_only else "terms and "
 
-    def _plot(self):
+    def _plot(self, separate_vars: List[str] = None):
         """
-        Main plot function to plot climatological skill score.
+        Main plot function to plot bootstrapped skill scores.
+
+        :param separate_vars: variables drawn in their own subplot next to all remaining variables; if None, all
+            variables are drawn on a single axis
+        """
+        if separate_vars is None:
+            self._plot_all_variables()
+        else:
+            self._plot_selected_variables(separate_vars)
+
+    def _plot_selected_variables(self, separate_vars: List[str] = None):
+        """
+        Plot `separate_vars` and all remaining variables as box plots in two side-by-side subplots.
+
+        :param separate_vars: variables shown in the first subplot (default ['o3'])
+        """
+        if separate_vars is None:  # resolved here to avoid a mutable default argument
+            separate_vars = ['o3']
+        data = self._data
+        self.raise_error_if_separate_vars_do_not_exist(data, separate_vars)
+        all_variables = self._get_unique_values_from_column_of_df(data, 'boot_var')
+        remaining_vars = helpers.list_pop(all_variables, separate_vars)
+        data_first = self._select_data(df=data, variables=separate_vars, column_name='boot_var')
+        data_second = self._select_data(df=data, variables=remaining_vars, column_name='boot_var')
+
+        # subplot widths are proportional to the number of variables shown in each subplot
+        fig, ax = plt.subplots(nrows=1, ncols=2,
+                               gridspec_kw={'width_ratios': [len(separate_vars), len(remaining_vars)]})
+        first_box_width = .8 if len(separate_vars) > 1 else 2.
+
+        sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_first, ax=ax[0], whis=1., palette="Blues_d",
+                    showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
+                    flierprops={"marker": "."}, width=first_box_width)
+        ax[0].set(ylabel="skill score", xlabel="")
+
+        sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_second, ax=ax[1], whis=1., palette="Blues_d",
+                    showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
+                    flierprops={"marker": "."})
+        ax[1].set(ylabel="", xlabel="")
+        ax[1].yaxis.tick_right()
+        handles, _ = ax[1].get_legend_handles_labels()
+        for sax in ax:
+            plt.sca(sax)
+            sax.axhline(y=0, color="grey", linewidth=.5)
+            plt.xticks(rotation=45, ha='right')
+            sax.legend_.remove()
+
+        fig.legend(handles, self._labels, loc='upper center', ncol=len(handles) + 1)
+
+        def align_yaxis(ax1, ax2):
+            """
+            Align zeros of the two axes, zooming them out by same ratio
+
+            This function is copy pasted from https://stackoverflow.com/a/41259922
+            """
+            axes = (ax1, ax2)
+            extrema = [axis.get_ylim() for axis in axes]
+            tops = [extr[1] / (extr[1] - extr[0]) for extr in extrema]
+            # Ensure that plots (intervals) are ordered bottom to top:
+            if tops[0] > tops[1]:
+                axes, extrema, tops = [list(reversed(seq)) for seq in (axes, extrema, tops)]
+
+            # How much would the plot overflow if we kept current zoom levels?
+            tot_span = tops[1] + 1 - tops[0]
+
+            b_new_t = extrema[0][0] + tot_span * (extrema[0][1] - extrema[0][0])
+            t_new_b = extrema[1][1] - tot_span * (extrema[1][1] - extrema[1][0])
+            axes[0].set_ylim(extrema[0][0], b_new_t)
+            axes[1].set_ylim(t_new_b, extrema[1][1])
+
+        # NOTE(review): the alignment is applied twice; presumably intentional - confirm before removing the
+        # second call
+        align_yaxis(ax[0], ax[1])
+        align_yaxis(ax[0], ax[1])
+
+    @staticmethod
+    def _select_data(df: pd.DataFrame, variables: List[str], column_name: str) -> pd.DataFrame:
+        """
+        Return the rows of `df` whose `column_name` entry is in `variables`, grouped in the order of `variables`.
+
+        An empty selection of `df` is returned if `variables` is empty.
+        """
+        if len(variables) == 0:
+            return df.iloc[0:0]
+        return pd.concat([df.loc[df[column_name] == variable] for variable in variables], axis=0)
+
+    def raise_error_if_separate_vars_do_not_exist(self, data, separate_vars):
+        """Raise a ValueError if an entry of `separate_vars` is missing in the `boot_var` column of `data`."""
+        if not self._variables_exist_in_df(df=data, variables=separate_vars):
+            raise ValueError("At least one entry of `separate_vars' does not exist in `self.data' ")
+
+    @staticmethod
+    def _get_unique_values_from_column_of_df(df: pd.DataFrame, column_name: str) -> List:
+        """Return the unique values of column `column_name` of `df` as list."""
+        return list(df[column_name].unique())
+
+    def _variables_exist_in_df(self, df: pd.DataFrame, variables: List[str], column_name: str = 'boot_var'):
+        """Check whether all `variables` occur in column `column_name` of `df`."""
+        vars_in_df = set(self._get_unique_values_from_column_of_df(df, column_name))
+        return set(variables).issubset(vars_in_df)
+
+    def _plot_all_variables(self):
+        """
+        Plot the skill scores of all variables as box plots on a single axis.
         """
         fig, ax = plt.subplots()
         sns.boxplot(x=self._x_name, y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d",
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 5731b7a6291b146681e976aa40fcb3d87a464c3b..aa80a4b82c720172c60a5808a04fe82e5e9e543b 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -111,7 +111,10 @@ class PreProcessing(RunEnvironment):
         df_descr = df_nometa.iloc[:-2].astype('float32').describe(
             percentiles=[.05, .1, .25, .5, .75, .9, .95]).astype('int32')
         df_descr = pd.concat([df_nometa.loc[['# Samples']], df_descr]).T
-        df_descr.rename(columns={"# Samples": "sum"}, inplace=True)
+        df_descr.rename(columns={"# Samples": "no. samples", "count": "no. stations"}, inplace=True)
+        df_descr_colnames = list(df_descr.columns)
+        df_descr_colnames = [df_descr_colnames[1]] + [df_descr_colnames[0]] + df_descr_colnames[2:]
+        df_descr = df_descr[df_descr_colnames]
         column_format = self.create_column_format_for_tex(df_descr)
         df_descr.to_latex(os.path.join(path, "station_describe_short.tex"), na_rep='---',
                           column_format=column_format)