Skip to content
Snippets Groups Projects
Commit af678d0c authored by lukas leufen's avatar lukas leufen
Browse files

added docs and updated plot file names in readme

parent e217eeb1
No related branches found
No related tags found
2 merge requests!37include new development,!28Lukas issue32 refac restructure plot routines in modules
Pipeline #28790 passed
......@@ -20,7 +20,7 @@ import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.backends.backend_pdf import PdfPages
from typing import Dict, List
from typing import Dict, List, Tuple
logging.getLogger('matplotlib').setLevel(logging.WARNING)
......@@ -45,15 +45,16 @@ class PlotMonthlySummary(RunEnvironment):
super().__init__()
self._data_path = data_path
self._data_name = name
self._data = self._get_data(stations)
self._data = self._prepare_data(stations)
self._window_lead_time = self._get_window_lead_time(window_lead_time)
self._plot(target_var, plot_folder)
def _get_data(self, stations):
def _prepare_data(self, stations: List) -> xr.DataArray:
"""
pre-process data
:param stations:
:return:
Pre-process data required to plot. For each station, load locally saved predictions, extract the CNN and orig
prediction and group them into monthly bins (no aggregation, only sorting them).
:param stations: all stations to plot
:return: The entire data set, flagged with the corresponding month.
"""
forecasts = None
for station in stations:
......@@ -76,13 +77,26 @@ class PlotMonthlySummary(RunEnvironment):
forecasts = xr.concat([forecasts, data_concat], 'index') if forecasts is not None else data_concat
return forecasts
def _get_window_lead_time(self, window_lead_time):
def _get_window_lead_time(self, window_lead_time: int):
"""
Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from
data itself by the number of ahead dimensions. If given, check if data supports the give length. If the number
of ahead dimensions in data is lower than the given lead time, data's lead time is used.
:param window_lead_time: lead time from arguments to validate
:return: validated lead time, comes either from given argument or from data itself
"""
ahead_steps = len(self._data.ahead)
if window_lead_time is None:
window_lead_time = ahead_steps
return min(ahead_steps, window_lead_time)
def _plot(self, target_var, plot_folder):
def _plot(self, target_var: str, plot_folder: str):
"""
Main plot function that creates a monthly grouped box plot over all stations but with separate boxes for each
lead time step.
:param target_var: display name of the target variable on plot's axis
:param plot_folder: path to save the plot
"""
data = self._data.to_dataset(name='values').to_dask_dataframe()
logging.debug("... start plotting")
color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex()
......@@ -95,6 +109,10 @@ class PlotMonthlySummary(RunEnvironment):
@staticmethod
def _save(plot_folder):
"""
Standard save method to store plot locally. The name of this plot is static.
:param plot_folder: path to save the plot
"""
plot_name = os.path.join(os.path.abspath(plot_folder), 'monthly_summary_box_plot.pdf')
logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=500)
......@@ -103,10 +121,10 @@ class PlotMonthlySummary(RunEnvironment):
class PlotStationMap(RunEnvironment):
"""
Plot geographical overview of all used stations. Different data sets can be colorised by its key in the input
dictionary generators. The key represents the color to plot on the map. Currently, there is only a white background,
but this can be adjusted by loading locally stored topography data (not implemented yet). The plot is saved under
plot_path with the name station_map.pdf
Plot geographical overview of all used stations as squares. Different data sets can be colorised by its key in the
input dictionary generators. The key represents the color to plot on the map. Currently, there is only a white
background, but this can be adjusted by loading locally stored topography data (not implemented yet). The plot is
saved under plot_path with the name station_map.pdf
"""
def __init__(self, generators: Dict, plot_folder: str = "."):
"""
......@@ -120,6 +138,9 @@ class PlotStationMap(RunEnvironment):
self._plot(generators, plot_folder)
def _draw_background(self):
"""
Draw coastline, lakes, ocean, rivers and country borders as background on the map.
"""
self._ax.add_feature(cfeature.COASTLINE.with_scale("10m"), edgecolor='black')
self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
......@@ -127,6 +148,12 @@ class PlotStationMap(RunEnvironment):
self._ax.add_feature(cfeature.BORDERS.with_scale("10m"), facecolor='none', edgecolor='black')
def _plot_stations(self, generators):
"""
The actual plot function. Loops over all keys in generators dict and its containing stations and plots a square
and the stations's position on the map regarding the given color.
:param generators: dictionary with the plot color of each data set as key and the generator containing all
stations as value.
"""
if generators is not None:
for color, gen in generators.items():
for k, v in enumerate(gen):
......@@ -136,7 +163,13 @@ class PlotStationMap(RunEnvironment):
station_coords.loc['station_lat'].values)
self._ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree())
def _plot(self, generators, plot_folder):
def _plot(self, generators: Dict, plot_folder: str):
"""
Main plot function to create the station map plot. Sets figure and calls all required sub-methods.
:param generators: dictionary with the plot color of each data set as key and the generator containing all
stations as value.
:param plot_folder: path to save the plot
"""
fig = plt.figure(figsize=(10, 5))
self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
self._ax.set_extent([0, 20, 42, 58], crs=ccrs.PlateCarree())
......@@ -146,6 +179,10 @@ class PlotStationMap(RunEnvironment):
@staticmethod
def _save(plot_folder):
"""
Standard save method to store plot locally. The name of this plot is static.
:param plot_folder: path to save the plot
"""
plot_name = os.path.join(os.path.abspath(plot_folder), 'station_map.pdf')
logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=500)
......@@ -303,24 +340,43 @@ class PlotClimatologicalSkillScore(RunEnvironment):
:param plot_folder: path to save the plot (default: current directory)
:param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True)
:param extra_name_tag: additional tag that can be included in the plot name (default "")
:param model_setup: architecture type (default "CNN")
:param model_setup: architecture type to specify plot name (default "CNN")
"""
super().__init__()
self._labels = None
self._data = self._process_data(data, score_only)
self._data = self._prepare_data(data, score_only)
self._plot(plot_folder, score_only, extra_name_tag, model_setup)
def _process_data(self, data, score_only):
def _prepare_data(self, data: Dict, score_only: bool) -> pd.DataFrame:
"""
Shrink given data, if only scores are relevant. In any case, transform data to a plot friendly format. Also set
plot labels depending on the lead time dimensions.
:param data: dictionary with station names as keys and 2D xarrays as values
:param score_only: if true only scores of CASE I to IV are relevant
:return: pre-processed data set
"""
data = helpers.dict_to_xarray(data, "station")
self._labels = [str(i) + "d" for i in data.coords["ahead"].values]
if score_only:
data = data.loc[:, ["CASE I", "CASE II", "CASE III", "CASE IV"], :]
return data.to_dataframe("data").reset_index(level=[0, 1, 2])
def _label_add(self, score_only):
def _label_add(self, score_only: bool):
"""
Adds the phrase "terms and " if score_only is disabled or empty string (if score_only=True).
:param score_only: if false all terms are relevant, otherwise only CASE I to IV
:return: additional label
"""
return "" if score_only else "terms and "
def _plot(self, plot_folder, score_only, extra_name_tag, model_setup):
"""
Main plot function to plot climatological skill score.
:param plot_folder: path to save the plot
:param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms
:param extra_name_tag: additional tag that can be included in the plot name
:param model_setup: architecture type to specify plot name
"""
fig, ax = plt.subplots()
if not score_only:
fig.set_size_inches(11.7, 8.27)
......@@ -335,6 +391,13 @@ class PlotClimatologicalSkillScore(RunEnvironment):
@staticmethod
def _save(plot_folder, extra_name_tag, model_setup):
"""
Standard save method to store plot locally. The name of this plot is dynamic. It includes the model setup like
'CNN' and can additionally be adjusted using an extra name tag.
:param plot_folder: path to save the plot
:param extra_name_tag: additional tag that can be included in the plot name
:param model_setup: architecture type to specify plot name
"""
plot_name = os.path.join(plot_folder, f"skill_score_clim_{extra_name_tag}{model_setup}.pdf")
logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=500)
......@@ -359,7 +422,13 @@ class PlotCompetitiveSkillScore(RunEnvironment):
self._data = self._prepare_data(data)
self._plot(plot_folder, model_setup)
def _prepare_data(self, data):
def _prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
"""
Reformat given data and create plot labels. Introduces the dimensions stations and comparison
:param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre-
calculated comparisons for cnn, persistence and ols.
:return: processed data
"""
data = pd.concat(data, axis=0)
data = xr.DataArray(data, dims=["stations", "ahead"]).unstack("stations")
data = data.rename({"stations_level_0": "stations", "stations_level_1": "comparison"})
......@@ -369,6 +438,12 @@ class PlotCompetitiveSkillScore(RunEnvironment):
return data.stack(level=0).reset_index(level=2, drop=True).reset_index(name="data")
def _plot(self, plot_folder, model_setup):
"""
Main plot function to plot skill scores of the comparisons cnn-persi, ols-persi and cnn-ols.
:param plot_folder: path to save the plot
:param model_setup:
:return: architecture type to specify plot name
"""
fig, ax = plt.subplots()
sns.boxplot(x="comparison", y="data", hue="ahead", data=self._data, whis=1., ax=ax, palette="Blues_d",
showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
......@@ -381,14 +456,24 @@ class PlotCompetitiveSkillScore(RunEnvironment):
plt.tight_layout()
self._save(plot_folder, model_setup)
def _ylim(self):
def _ylim(self) -> Tuple[float, float]:
"""
Calculate y-axis limits from data. Lower is the minimum of either 0 or data's minimum (reduced by small
subtrahend) and upper limit is data's maximum (increased by a small addend).
:return:
"""
lower = np.min([0, helpers.float_round(self._data.min()[2], 2) - 0.1])
upper = helpers.float_round(self._data.max()[2], 2) + 0.1
return lower, upper
@staticmethod
def _save(plot_folder, model_setup):
plot_name = os.path.join(plot_folder, f"skill_score_competitive_{model_setup}2.pdf")
"""
Standard save method to store plot locally. The name of this plot is dynamic by including the model setup.
:param plot_folder: path to save the plot
:param model_setup: architecture type to specify plot name
"""
plot_name = os.path.join(plot_folder, f"skill_score_competitive_{model_setup}.pdf")
logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=500)
plt.close()
......@@ -47,8 +47,11 @@ experiment_path
└─── plots
conditional_quantiles_cali-ref_plot.pdf
conditional_quantiles_like-bas_plot.pdf
test_monthly_box.pdf
test_map_plot.pdf
monthly_summary_box_plot.pdf
skill_score_clim_all_terms_<architecture>.pdf
skill_score_clim_<architecture>.pdf
skill_score_competitive_<architecture>.pdf
station_map.pdf
<experiment_name>_history_learning_rate.pdf
<experiment_name>_history_loss.pdf
<experiment_name>_history_main_loss.pdf
......
......@@ -227,14 +227,14 @@ class PostProcessing(RunEnvironment):
def calculate_skill_scores(self):
path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general")
skill_score_general = {}
skill_score_competitive = {}
skill_score_climatological = {}
for station in self.test_data.stations:
file = os.path.join(path, f"forecasts_{station}_test.nc")
data = xr.open_dataarray(file)
skill_score = statistics.SkillScores(data)
external_data = self._get_external_data(station)
skill_score_general[station] = skill_score.skill_scores(window_lead_time)
skill_score_competitive[station] = skill_score.skill_scores(window_lead_time)
skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data,
window_lead_time)
return skill_score_general, skill_score_climatological
return skill_score_competitive, skill_score_climatological
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment