From af678d0cd0123e35991a9cbd3581f88e985e1015 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Thu, 30 Jan 2020 11:02:44 +0100 Subject: [PATCH] added docs and updated plot file names in readme --- src/plotting/postprocessing_plotting.py | 125 ++++++++++++++++++++---- src/run_modules/README.md | 7 +- src/run_modules/post_processing.py | 6 +- 3 files changed, 113 insertions(+), 25 deletions(-) diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index b1434cd5..cd49ddd5 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -20,7 +20,7 @@ import cartopy.crs as ccrs import cartopy.feature as cfeature from matplotlib.backends.backend_pdf import PdfPages -from typing import Dict, List +from typing import Dict, List, Tuple logging.getLogger('matplotlib').setLevel(logging.WARNING) @@ -45,15 +45,16 @@ class PlotMonthlySummary(RunEnvironment): super().__init__() self._data_path = data_path self._data_name = name - self._data = self._get_data(stations) + self._data = self._prepare_data(stations) self._window_lead_time = self._get_window_lead_time(window_lead_time) self._plot(target_var, plot_folder) - def _get_data(self, stations): + def _prepare_data(self, stations: List) -> xr.DataArray: """ - pre-process data - :param stations: - :return: + Pre-process data required to plot. For each station, load locally saved predictions, extract the CNN and orig + prediction and group them into monthly bins (no aggregation, only sorting them). + :param stations: all stations to plot + :return: The entire data set, flagged with the corresponding month. 
""" forecasts = None for station in stations: @@ -76,13 +77,26 @@ class PlotMonthlySummary(RunEnvironment): forecasts = xr.concat([forecasts, data_concat], 'index') if forecasts is not None else data_concat return forecasts - def _get_window_lead_time(self, window_lead_time): + def _get_window_lead_time(self, window_lead_time: int): + """ + Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from + data itself by the number of ahead dimensions. If given, check if data supports the give length. If the number + of ahead dimensions in data is lower than the given lead time, data's lead time is used. + :param window_lead_time: lead time from arguments to validate + :return: validated lead time, comes either from given argument or from data itself + """ ahead_steps = len(self._data.ahead) if window_lead_time is None: window_lead_time = ahead_steps return min(ahead_steps, window_lead_time) - def _plot(self, target_var, plot_folder): + def _plot(self, target_var: str, plot_folder: str): + """ + Main plot function that creates a monthly grouped box plot over all stations but with separate boxes for each + lead time step. + :param target_var: display name of the target variable on plot's axis + :param plot_folder: path to save the plot + """ data = self._data.to_dataset(name='values').to_dask_dataframe() logging.debug("... start plotting") color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex() @@ -95,6 +109,10 @@ class PlotMonthlySummary(RunEnvironment): @staticmethod def _save(plot_folder): + """ + Standard save method to store plot locally. The name of this plot is static. + :param plot_folder: path to save the plot + """ plot_name = os.path.join(os.path.abspath(plot_folder), 'monthly_summary_box_plot.pdf') logging.debug(f"... 
save plot to {plot_name}") plt.savefig(plot_name, dpi=500) @@ -103,10 +121,10 @@ class PlotMonthlySummary(RunEnvironment): class PlotStationMap(RunEnvironment): """ - Plot geographical overview of all used stations. Different data sets can be colorised by its key in the input - dictionary generators. The key represents the color to plot on the map. Currently, there is only a white background, - but this can be adjusted by loading locally stored topography data (not implemented yet). The plot is saved under - plot_path with the name station_map.pdf + Plot geographical overview of all used stations as squares. Different data sets can be colorised by their key in the + input dictionary generators. The key represents the color to plot on the map. Currently, there is only a white + background, but this can be adjusted by loading locally stored topography data (not implemented yet). The plot is + saved under plot_path with the name station_map.pdf """ def __init__(self, generators: Dict, plot_folder: str = "."): """ @@ -120,6 +138,9 @@ class PlotStationMap(RunEnvironment): self._plot(generators, plot_folder) def _draw_background(self): + """ + Draw coastline, lakes, ocean, rivers and country borders as background on the map. + """ self._ax.add_feature(cfeature.COASTLINE.with_scale("10m"), edgecolor='black') self._ax.add_feature(cfeature.LAKES.with_scale("50m")) self._ax.add_feature(cfeature.OCEAN.with_scale("50m")) @@ -127,6 +148,12 @@ class PlotStationMap(RunEnvironment): self._ax.add_feature(cfeature.BORDERS.with_scale("10m"), facecolor='none', edgecolor='black') def _plot_stations(self, generators): + """ + The actual plot function. Loops over all keys in generators dict and its containing stations and plots a square + at the station's position on the map in the given color. + :param generators: dictionary with the plot color of each data set as key and the generator containing all + stations as value.
+ """ if generators is not None: for color, gen in generators.items(): for k, v in enumerate(gen): @@ -136,7 +163,13 @@ class PlotStationMap(RunEnvironment): station_coords.loc['station_lat'].values) self._ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree()) - def _plot(self, generators, plot_folder): + def _plot(self, generators: Dict, plot_folder: str): + """ + Main plot function to create the station map plot. Sets figure and calls all required sub-methods. + :param generators: dictionary with the plot color of each data set as key and the generator containing all + stations as value. + :param plot_folder: path to save the plot + """ fig = plt.figure(figsize=(10, 5)) self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) self._ax.set_extent([0, 20, 42, 58], crs=ccrs.PlateCarree()) @@ -146,6 +179,10 @@ class PlotStationMap(RunEnvironment): @staticmethod def _save(plot_folder): + """ + Standard save method to store plot locally. The name of this plot is static. + :param plot_folder: path to save the plot + """ plot_name = os.path.join(os.path.abspath(plot_folder), 'station_map.pdf') logging.debug(f"... 
save plot to {plot_name}") plt.savefig(plot_name, dpi=500) @@ -303,24 +340,43 @@ class PlotClimatologicalSkillScore(RunEnvironment): :param plot_folder: path to save the plot (default: current directory) :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True) :param extra_name_tag: additional tag that can be included in the plot name (default "") - :param model_setup: architecture type (default "CNN") + :param model_setup: architecture type to specify plot name (default "CNN") """ super().__init__() self._labels = None - self._data = self._process_data(data, score_only) + self._data = self._prepare_data(data, score_only) self._plot(plot_folder, score_only, extra_name_tag, model_setup) - def _process_data(self, data, score_only): + def _prepare_data(self, data: Dict, score_only: bool) -> pd.DataFrame: + """ + Shrink given data, if only scores are relevant. In any case, transform data to a plot friendly format. Also set + plot labels depending on the lead time dimensions. + :param data: dictionary with station names as keys and 2D xarrays as values + :param score_only: if true only scores of CASE I to IV are relevant + :return: pre-processed data set + """ data = helpers.dict_to_xarray(data, "station") self._labels = [str(i) + "d" for i in data.coords["ahead"].values] if score_only: data = data.loc[:, ["CASE I", "CASE II", "CASE III", "CASE IV"], :] return data.to_dataframe("data").reset_index(level=[0, 1, 2]) - def _label_add(self, score_only): + def _label_add(self, score_only: bool): + """ + Adds the phrase "terms and " if score_only is disabled or empty string (if score_only=True). + :param score_only: if false all terms are relevant, otherwise only CASE I to IV + :return: additional label + """ return "" if score_only else "terms and " def _plot(self, plot_folder, score_only, extra_name_tag, model_setup): + """ + Main plot function to plot climatological skill score. 
+ :param plot_folder: path to save the plot + :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms + :param extra_name_tag: additional tag that can be included in the plot name + :param model_setup: architecture type to specify plot name + """ fig, ax = plt.subplots() if not score_only: fig.set_size_inches(11.7, 8.27) @@ -335,6 +391,13 @@ class PlotClimatologicalSkillScore(RunEnvironment): @staticmethod def _save(plot_folder, extra_name_tag, model_setup): + """ + Standard save method to store plot locally. The name of this plot is dynamic. It includes the model setup like + 'CNN' and can additionally be adjusted using an extra name tag. + :param plot_folder: path to save the plot + :param extra_name_tag: additional tag that can be included in the plot name + :param model_setup: architecture type to specify plot name + """ plot_name = os.path.join(plot_folder, f"skill_score_clim_{extra_name_tag}{model_setup}.pdf") logging.debug(f"... save plot to {plot_name}") plt.savefig(plot_name, dpi=500) @@ -359,7 +422,13 @@ class PlotCompetitiveSkillScore(RunEnvironment): self._data = self._prepare_data(data) self._plot(plot_folder, model_setup) - def _prepare_data(self, data): + def _prepare_data(self, data: pd.DataFrame) -> pd.DataFrame: + """ + Reformat given data and create plot labels. Introduces the dimensions stations and comparison + :param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre- + calculated comparisons for cnn, persistence and ols. 
+ :return: processed data + """ data = pd.concat(data, axis=0) data = xr.DataArray(data, dims=["stations", "ahead"]).unstack("stations") data = data.rename({"stations_level_0": "stations", "stations_level_1": "comparison"}) @@ -369,6 +438,12 @@ class PlotCompetitiveSkillScore(RunEnvironment): return data.stack(level=0).reset_index(level=2, drop=True).reset_index(name="data") def _plot(self, plot_folder, model_setup): + """ + Main plot function to plot skill scores of the comparisons cnn-persi, ols-persi and cnn-ols. + :param plot_folder: path to save the plot + :param model_setup: architecture type to specify plot name + :return: None (the plot is stored via _save) + """ fig, ax = plt.subplots() sns.boxplot(x="comparison", y="data", hue="ahead", data=self._data, whis=1., ax=ax, palette="Blues_d", showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."}, @@ -381,14 +456,24 @@ class PlotCompetitiveSkillScore(RunEnvironment): plt.tight_layout() self._save(plot_folder, model_setup) - def _ylim(self): + def _ylim(self) -> Tuple[float, float]: + """ + Calculate y-axis limits from data. Lower is the minimum of either 0 or data's minimum (reduced by small + subtrahend) and upper limit is data's maximum (increased by a small addend). + :return: tuple of lower and upper y-axis limit + """ lower = np.min([0, helpers.float_round(self._data.min()[2], 2) - 0.1]) upper = helpers.float_round(self._data.max()[2], 2) + 0.1 return lower, upper @staticmethod def _save(plot_folder, model_setup): - plot_name = os.path.join(plot_folder, f"skill_score_competitive_{model_setup}2.pdf") + """ + Standard save method to store plot locally. The name of this plot is dynamic by including the model setup. + :param plot_folder: path to save the plot + :param model_setup: architecture type to specify plot name + """ + plot_name = os.path.join(plot_folder, f"skill_score_competitive_{model_setup}.pdf") logging.debug(f"...
save plot to {plot_name}") plt.savefig(plot_name, dpi=500) plt.close() diff --git a/src/run_modules/README.md b/src/run_modules/README.md index 33149220..581811f1 100644 --- a/src/run_modules/README.md +++ b/src/run_modules/README.md @@ -47,8 +47,11 @@ experiment_path └─── plots conditional_quantiles_cali-ref_plot.pdf conditional_quantiles_like-bas_plot.pdf - test_monthly_box.pdf - test_map_plot.pdf + monthly_summary_box_plot.pdf + skill_score_clim_all_terms_<architecture>.pdf + skill_score_clim_<architecture>.pdf + skill_score_competitive_<architecture>.pdf + station_map.pdf <experiment_name>_history_learning_rate.pdf <experiment_name>_history_loss.pdf <experiment_name>_history_main_loss.pdf diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index b935aa83..a9695064 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -227,14 +227,14 @@ class PostProcessing(RunEnvironment): def calculate_skill_scores(self): path = self.data_store.get("forecast_path", "general") window_lead_time = self.data_store.get("window_lead_time", "general") - skill_score_general = {} + skill_score_competitive = {} skill_score_climatological = {} for station in self.test_data.stations: file = os.path.join(path, f"forecasts_{station}_test.nc") data = xr.open_dataarray(file) skill_score = statistics.SkillScores(data) external_data = self._get_external_data(station) - skill_score_general[station] = skill_score.skill_scores(window_lead_time) + skill_score_competitive[station] = skill_score.skill_scores(window_lead_time) skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data, window_lead_time) - return skill_score_general, skill_score_climatological + return skill_score_competitive, skill_score_climatological -- GitLab