From af678d0cd0123e35991a9cbd3581f88e985e1015 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 30 Jan 2020 11:02:44 +0100
Subject: [PATCH] added docs and updated plot file names in readme

---
 src/plotting/postprocessing_plotting.py | 125 ++++++++++++++++++++----
 src/run_modules/README.md               |   7 +-
 src/run_modules/post_processing.py      |   6 +-
 3 files changed, 113 insertions(+), 25 deletions(-)

diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index b1434cd5..cd49ddd5 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -20,7 +20,7 @@ import cartopy.crs as ccrs
 import cartopy.feature as cfeature
 from matplotlib.backends.backend_pdf import PdfPages
 
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 
@@ -45,15 +45,16 @@ class PlotMonthlySummary(RunEnvironment):
         super().__init__()
         self._data_path = data_path
         self._data_name = name
-        self._data = self._get_data(stations)
+        self._data = self._prepare_data(stations)
         self._window_lead_time = self._get_window_lead_time(window_lead_time)
         self._plot(target_var, plot_folder)
 
-    def _get_data(self, stations):
+    def _prepare_data(self, stations: List) -> xr.DataArray:
         """
-        pre-process data
-        :param stations:
-        :return:
+        Pre-process data required to plot. For each station, load locally saved predictions, extract the CNN and orig
+        prediction and group them into monthly bins (no aggregation, only sorting them).
+        :param stations: all stations to plot
+        :return: The entire data set, flagged with the corresponding month.
         """
         forecasts = None
         for station in stations:
@@ -76,13 +77,26 @@ class PlotMonthlySummary(RunEnvironment):
             forecasts = xr.concat([forecasts, data_concat], 'index') if forecasts is not None else data_concat
         return forecasts
 
-    def _get_window_lead_time(self, window_lead_time):
+    def _get_window_lead_time(self, window_lead_time: int):
+        """
+        Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from
+        data itself by the number of ahead dimensions. If given, check if data supports the give length. If the number
+        of ahead dimensions in data is lower than the given lead time, data's lead time is used.
+        :param window_lead_time: lead time from arguments to validate
+        :return: validated lead time, comes either from given argument or from data itself
+        """
         ahead_steps = len(self._data.ahead)
         if window_lead_time is None:
             window_lead_time = ahead_steps
         return min(ahead_steps, window_lead_time)
 
-    def _plot(self, target_var, plot_folder):
+    def _plot(self, target_var: str, plot_folder: str):
+        """
+        Main plot function that creates a monthly grouped box plot over all stations but with separate boxes for each
+        lead time step.
+        :param target_var: display name of the target variable on plot's axis
+        :param plot_folder: path to save the plot
+        """
         data = self._data.to_dataset(name='values').to_dask_dataframe()
         logging.debug("... start plotting")
         color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex()
@@ -95,6 +109,10 @@ class PlotMonthlySummary(RunEnvironment):
 
     @staticmethod
     def _save(plot_folder):
+        """
+        Standard save method to store plot locally. The name of this plot is static.
+        :param plot_folder: path to save the plot
+        """
         plot_name = os.path.join(os.path.abspath(plot_folder), 'monthly_summary_box_plot.pdf')
         logging.debug(f"... save plot to {plot_name}")
         plt.savefig(plot_name, dpi=500)
@@ -103,10 +121,10 @@ class PlotMonthlySummary(RunEnvironment):
 
 class PlotStationMap(RunEnvironment):
     """
-    Plot geographical overview of all used stations. Different data sets can be colorised by its key in the input
-    dictionary generators. The key represents the color to plot on the map. Currently, there is only a white background,
-    but this can be adjusted by loading locally stored topography data (not implemented yet). The plot is saved under
-    plot_path with the name station_map.pdf
+    Plot geographical overview of all used stations as squares. Different data sets can be colorised by its key in the
+    input dictionary generators. The key represents the color to plot on the map. Currently, there is only a white
+    background, but this can be adjusted by loading locally stored topography data (not implemented yet). The plot is
+    saved under plot_path with the name station_map.pdf
     """
     def __init__(self, generators: Dict, plot_folder: str = "."):
         """
@@ -120,6 +138,9 @@ class PlotStationMap(RunEnvironment):
         self._plot(generators, plot_folder)
 
     def _draw_background(self):
+        """
+        Draw coastline, lakes, ocean, rivers and country borders as background on the map.
+        """
         self._ax.add_feature(cfeature.COASTLINE.with_scale("10m"), edgecolor='black')
         self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
         self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
@@ -127,6 +148,12 @@ class PlotStationMap(RunEnvironment):
         self._ax.add_feature(cfeature.BORDERS.with_scale("10m"), facecolor='none', edgecolor='black')
 
     def _plot_stations(self, generators):
+        """
+        The actual plot function. Loops over all keys in generators dict and its containing stations and plots a square
+        and the stations's position on the map regarding the given color.
+        :param generators: dictionary with the plot color of each data set as key and the generator containing all
+            stations as value.
+        """
         if generators is not None:
             for color, gen in generators.items():
                 for k, v in enumerate(gen):
@@ -136,7 +163,13 @@ class PlotStationMap(RunEnvironment):
                         station_coords.loc['station_lat'].values)
                     self._ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree())
 
-    def _plot(self, generators, plot_folder):
+    def _plot(self, generators: Dict, plot_folder: str):
+        """
+        Main plot function to create the station map plot. Sets figure and calls all required sub-methods.
+        :param generators: dictionary with the plot color of each data set as key and the generator containing all
+            stations as value.
+        :param plot_folder: path to save the plot
+        """
         fig = plt.figure(figsize=(10, 5))
         self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
         self._ax.set_extent([0, 20, 42, 58], crs=ccrs.PlateCarree())
@@ -146,6 +179,10 @@ class PlotStationMap(RunEnvironment):
 
     @staticmethod
     def _save(plot_folder):
+        """
+        Standard save method to store plot locally. The name of this plot is static.
+        :param plot_folder: path to save the plot
+        """
         plot_name = os.path.join(os.path.abspath(plot_folder), 'station_map.pdf')
         logging.debug(f"... save plot to {plot_name}")
         plt.savefig(plot_name, dpi=500)
@@ -303,24 +340,43 @@ class PlotClimatologicalSkillScore(RunEnvironment):
         :param plot_folder: path to save the plot (default: current directory)
         :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True)
         :param extra_name_tag: additional tag that can be included in the plot name (default "")
-        :param model_setup: architecture type (default "CNN")
+        :param model_setup: architecture type to specify plot name (default "CNN")
         """
         super().__init__()
         self._labels = None
-        self._data = self._process_data(data, score_only)
+        self._data = self._prepare_data(data, score_only)
         self._plot(plot_folder, score_only, extra_name_tag, model_setup)
 
-    def _process_data(self, data, score_only):
+    def _prepare_data(self, data: Dict, score_only: bool) -> pd.DataFrame:
+        """
+        Shrink given data, if only scores are relevant. In any case, transform data to a plot friendly format. Also set
+        plot labels depending on the lead time dimensions.
+        :param data: dictionary with station names as keys and 2D xarrays as values
+        :param score_only: if true only scores of CASE I to IV are relevant
+        :return: pre-processed data set
+        """
         data = helpers.dict_to_xarray(data, "station")
         self._labels = [str(i) + "d" for i in data.coords["ahead"].values]
         if score_only:
             data = data.loc[:, ["CASE I", "CASE II", "CASE III", "CASE IV"], :]
         return data.to_dataframe("data").reset_index(level=[0, 1, 2])
 
-    def _label_add(self, score_only):
+    def _label_add(self, score_only: bool):
+        """
+        Adds the phrase "terms and " if score_only is disabled or empty string (if score_only=True).
+        :param score_only: if false all terms are relevant, otherwise only CASE I to IV
+        :return: additional label
+        """
         return "" if score_only else "terms and "
 
     def _plot(self, plot_folder, score_only, extra_name_tag, model_setup):
+        """
+        Main plot function to plot climatological skill score.
+        :param plot_folder: path to save the plot
+        :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms
+        :param extra_name_tag: additional tag that can be included in the plot name
+        :param model_setup: architecture type to specify plot name
+        """
         fig, ax = plt.subplots()
         if not score_only:
             fig.set_size_inches(11.7, 8.27)
@@ -335,6 +391,13 @@ class PlotClimatologicalSkillScore(RunEnvironment):
 
     @staticmethod
     def _save(plot_folder, extra_name_tag, model_setup):
+        """
+        Standard save method to store plot locally. The name of this plot is dynamic. It includes the model setup like
+        'CNN' and can additionally be adjusted using an extra name tag.
+        :param plot_folder: path to save the plot
+        :param extra_name_tag: additional tag that can be included in the plot name
+        :param model_setup: architecture type to specify plot name
+        """
         plot_name = os.path.join(plot_folder, f"skill_score_clim_{extra_name_tag}{model_setup}.pdf")
         logging.debug(f"... save plot to {plot_name}")
         plt.savefig(plot_name, dpi=500)
@@ -359,7 +422,13 @@ class PlotCompetitiveSkillScore(RunEnvironment):
         self._data = self._prepare_data(data)
         self._plot(plot_folder, model_setup)
 
-    def _prepare_data(self, data):
+    def _prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
+        """
+        Reformat given data and create plot labels. Introduces the dimensions stations and comparison
+        :param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre-
+            calculated comparisons for cnn, persistence and ols.
+        :return: processed data
+        """
         data = pd.concat(data, axis=0)
         data = xr.DataArray(data, dims=["stations", "ahead"]).unstack("stations")
         data = data.rename({"stations_level_0": "stations", "stations_level_1": "comparison"})
@@ -369,6 +438,12 @@ class PlotCompetitiveSkillScore(RunEnvironment):
         return data.stack(level=0).reset_index(level=2, drop=True).reset_index(name="data")
 
     def _plot(self, plot_folder, model_setup):
+        """
+        Main plot function to plot skill scores of the comparisons cnn-persi, ols-persi and cnn-ols.
+        :param plot_folder: path to save the plot
+        :param model_setup:
+        :return: architecture type to specify plot name
+        """
         fig, ax = plt.subplots()
         sns.boxplot(x="comparison", y="data", hue="ahead", data=self._data, whis=1., ax=ax, palette="Blues_d",
                     showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
@@ -381,14 +456,24 @@ class PlotCompetitiveSkillScore(RunEnvironment):
         plt.tight_layout()
         self._save(plot_folder, model_setup)
 
-    def _ylim(self):
+    def _ylim(self) -> Tuple[float, float]:
+        """
+        Calculate y-axis limits from data. Lower is the minimum of either 0 or data's minimum (reduced by small
+        subtrahend) and upper limit is data's maximum (increased by a small addend).
+        :return:
+        """
         lower = np.min([0, helpers.float_round(self._data.min()[2], 2) - 0.1])
         upper = helpers.float_round(self._data.max()[2], 2) + 0.1
         return lower, upper
 
     @staticmethod
     def _save(plot_folder, model_setup):
-        plot_name = os.path.join(plot_folder, f"skill_score_competitive_{model_setup}2.pdf")
+        """
+        Standard save method to store plot locally. The name of this plot is dynamic by including the model setup.
+        :param plot_folder: path to save the plot
+        :param model_setup: architecture type to specify plot name
+        """
+        plot_name = os.path.join(plot_folder, f"skill_score_competitive_{model_setup}.pdf")
         logging.debug(f"... save plot to {plot_name}")
         plt.savefig(plot_name, dpi=500)
         plt.close()
diff --git a/src/run_modules/README.md b/src/run_modules/README.md
index 33149220..581811f1 100644
--- a/src/run_modules/README.md
+++ b/src/run_modules/README.md
@@ -47,8 +47,11 @@ experiment_path
 └─── plots
         conditional_quantiles_cali-ref_plot.pdf
         conditional_quantiles_like-bas_plot.pdf
-        test_monthly_box.pdf
-        test_map_plot.pdf
+        monthly_summary_box_plot.pdf
+        skill_score_clim_all_terms_<architecture>.pdf
+        skill_score_clim_<architecture>.pdf
+        skill_score_competitive_<architecture>.pdf
+        station_map.pdf
         <experiment_name>_history_learning_rate.pdf
         <experiment_name>_history_loss.pdf
         <experiment_name>_history_main_loss.pdf
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index b935aa83..a9695064 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -227,14 +227,14 @@ class PostProcessing(RunEnvironment):
     def calculate_skill_scores(self):
         path = self.data_store.get("forecast_path", "general")
         window_lead_time = self.data_store.get("window_lead_time", "general")
-        skill_score_general = {}
+        skill_score_competitive = {}
         skill_score_climatological = {}
         for station in self.test_data.stations:
             file = os.path.join(path, f"forecasts_{station}_test.nc")
             data = xr.open_dataarray(file)
             skill_score = statistics.SkillScores(data)
             external_data = self._get_external_data(station)
-            skill_score_general[station] = skill_score.skill_scores(window_lead_time)
+            skill_score_competitive[station] = skill_score.skill_scores(window_lead_time)
             skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data,
                                                                                           window_lead_time)
-        return skill_score_general, skill_score_climatological
+        return skill_score_competitive, skill_score_climatological
-- 
GitLab