diff --git a/docs/_source/_plots/conditional_quantiles_cali-ref_plot-1.png b/docs/_source/_plots/conditional_quantiles_cali-ref_plot-1.png new file mode 100644 index 0000000000000000000000000000000000000000..94373ab2b71a2a719fbeac84a5e6b5230f93909c Binary files /dev/null and b/docs/_source/_plots/conditional_quantiles_cali-ref_plot-1.png differ diff --git a/docs/_source/_plots/conditional_quantiles_cali-ref_plot-2.png b/docs/_source/_plots/conditional_quantiles_cali-ref_plot-2.png new file mode 100644 index 0000000000000000000000000000000000000000..bedc075b8cc3bc75e1dabfbbec02cbdb6c69123a Binary files /dev/null and b/docs/_source/_plots/conditional_quantiles_cali-ref_plot-2.png differ diff --git a/docs/_source/_plots/conditional_quantiles_cali-ref_plot-3.png b/docs/_source/_plots/conditional_quantiles_cali-ref_plot-3.png new file mode 100644 index 0000000000000000000000000000000000000000..ccc454211e5dbf16374ebbee522ea584e24a4fbd Binary files /dev/null and b/docs/_source/_plots/conditional_quantiles_cali-ref_plot-3.png differ diff --git a/docs/_source/_plots/conditional_quantiles_like-bas_plot-1.png b/docs/_source/_plots/conditional_quantiles_like-bas_plot-1.png new file mode 100644 index 0000000000000000000000000000000000000000..1641a12f678028c96646b2daabbf06c599cfb86a Binary files /dev/null and b/docs/_source/_plots/conditional_quantiles_like-bas_plot-1.png differ diff --git a/docs/_source/_plots/conditional_quantiles_like-bas_plot-2.png b/docs/_source/_plots/conditional_quantiles_like-bas_plot-2.png new file mode 100644 index 0000000000000000000000000000000000000000..c851f8f58a33cc2b37917e8964faa65243b3e8a6 Binary files /dev/null and b/docs/_source/_plots/conditional_quantiles_like-bas_plot-2.png differ diff --git a/docs/_source/_plots/conditional_quantiles_like-bas_plot-3.png b/docs/_source/_plots/conditional_quantiles_like-bas_plot-3.png new file mode 100644 index 0000000000000000000000000000000000000000..302862bc61d881f879a4bb7c860a2a55d46a76af Binary files /dev/null and b/docs/_source/_plots/conditional_quantiles_like-bas_plot-3.png differ diff --git a/docs/_source/_plots/data_availability.png b/docs/_source/_plots/data_availability.png new file mode 100644 index 0000000000000000000000000000000000000000..a2350c4f57befb65b5d90721b9ae51257b59c4a5 Binary files /dev/null and b/docs/_source/_plots/data_availability.png differ diff --git a/docs/_source/_plots/data_availability_combined.png b/docs/_source/_plots/data_availability_combined.png new file mode 100644 index 0000000000000000000000000000000000000000..ae8fa5c034b3694171ec348cdc20fa3f73795691 Binary files /dev/null and b/docs/_source/_plots/data_availability_combined.png differ diff --git a/docs/_source/_plots/data_availability_summary.png b/docs/_source/_plots/data_availability_summary.png new file mode 100644 index 0000000000000000000000000000000000000000..db88b4d1ea4b5d22b8c04143da0824beef41eff9 Binary files /dev/null and b/docs/_source/_plots/data_availability_summary.png differ diff --git a/docs/_source/_plots/monthly_summary_box_plot.png b/docs/_source/_plots/monthly_summary_box_plot.png new file mode 100644 index 0000000000000000000000000000000000000000..f7447d8283adeb62d43d322769bf08925c0e2d89 Binary files /dev/null and b/docs/_source/_plots/monthly_summary_box_plot.png differ diff --git a/docs/_source/_plots/skill_score_bootstrap.png b/docs/_source/_plots/skill_score_bootstrap.png new file mode 100644 index 0000000000000000000000000000000000000000..844bf7f48cd32d588363b75623c7b7d5691a9988 Binary files /dev/null and b/docs/_source/_plots/skill_score_bootstrap.png differ diff --git a/docs/_source/_plots/skill_score_clim_CNN.png b/docs/_source/_plots/skill_score_clim_CNN.png new file mode 100644 index 0000000000000000000000000000000000000000..28a66b5c43b71c39a57d81123dfca7e3158dd8ce Binary files /dev/null and b/docs/_source/_plots/skill_score_clim_CNN.png differ diff --git a/docs/_source/_plots/skill_score_clim_all_terms_CNN.png b/docs/_source/_plots/skill_score_clim_all_terms_CNN.png new file mode 100644 index 0000000000000000000000000000000000000000..000b942154dbe9dde9f48f64ab1b967a6811907d Binary files /dev/null and b/docs/_source/_plots/skill_score_clim_all_terms_CNN.png differ diff --git a/docs/_source/_plots/skill_score_competitive.png b/docs/_source/_plots/skill_score_competitive.png new file mode 100644 index 0000000000000000000000000000000000000000..6b5342c31579c9c6c59ebacded8a92d02cb7c1f4 Binary files /dev/null and b/docs/_source/_plots/skill_score_competitive.png differ diff --git a/docs/_source/_plots/station_map.png b/docs/_source/_plots/station_map.png new file mode 100644 index 0000000000000000000000000000000000000000..181440f4003a65cdacfae66309fb981f3bb420b8 Binary files /dev/null and b/docs/_source/_plots/station_map.png differ diff --git a/docs/_source/_plots/testrun_network_daily_history_learning_rate-1.png b/docs/_source/_plots/testrun_network_daily_history_learning_rate-1.png new file mode 100644 index 0000000000000000000000000000000000000000..c433a6431fb84322ca0097cb5b567aec1d063661 Binary files /dev/null and b/docs/_source/_plots/testrun_network_daily_history_learning_rate-1.png differ diff --git a/docs/_source/_plots/testrun_network_daily_history_loss-1.png b/docs/_source/_plots/testrun_network_daily_history_loss-1.png new file mode 100644 index 0000000000000000000000000000000000000000..3a2234e4b39036f843396f2538ebbe5d4ec8ed5b Binary files /dev/null and b/docs/_source/_plots/testrun_network_daily_history_loss-1.png differ diff --git a/docs/_source/_plots/testrun_network_daily_history_main_mse-1.png b/docs/_source/_plots/testrun_network_daily_history_main_mse-1.png new file mode 100644 index 0000000000000000000000000000000000000000..71f2f2cea3e55d5c3cd404187d95e3255aea4e63 Binary files /dev/null and b/docs/_source/_plots/testrun_network_daily_history_main_mse-1.png differ diff --git a/docs/_source/conf.py b/docs/_source/conf.py index 24c87532958b1ed6a315b29a7fa890bb077e66b4..8a3181cba2431679f74dccd3f32cb94713e2f230 100644 --- a/docs/_source/conf.py +++ b/docs/_source/conf.py @@ -45,6 +45,7 @@ extensions = [ 'sphinx_rtd_theme', 'sphinx.ext.githubpages', 'recommonmark', + 'sphinx.ext.autosectionlabel', ] # 2020-02-19 Begin diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index 4de0bf5a6cd8f13898e9684f29e966ecd349c5be..d516ab77781221d72be0e209133b8b78170259f3 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -8,7 +8,7 @@ In this module, you can find some exemplary model classes that have been build a * `MyLittleModel`: small model implementation with a single 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time). * `MyBranchedModel`: a model with single 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time), it has three - output branches from different layers of the model. + output branches from different layers of the model. * `MyTowerModel`: a more complex model with inception blocks (called towers) * `MyPaperModel`: A model used for the publication: <Add Publication Title / Citation> @@ -38,9 +38,9 @@ How to create a customised model? * Make sure to add the `super().__init__()` and at least `set_model()` and `set_loss()` to your custom init method. * If you have custom objects in your model, that are not part of keras, you need to add them to custom objects. To do - this, call `set_custom_objects` with arbitrarily kwargs. In the shown example, the loss has been added, because it - wasn't a standard loss. Apart from this, we always encourage you to add the loss as custom object, to prevent - potential errors when loading an already created model instead of training a new one. + this, call `set_custom_objects` with arbitrarily kwargs. In the shown example, the loss has been added, because it + wasn't a standard loss. Apart from this, we always encourage you to add the loss as custom object, to prevent + potential errors when loading an already created model instead of training a new one. * Build your model inside `set_model()`, e.g. .. code-block:: python @@ -69,7 +69,8 @@ How to create a customised model? def set_loss(self): self.loss = keras.losses.mean_squared_error -* If you have a branched model with multiple outputs, you need to consider the right ordering. E.g. +* If you have a branched model with multiple outputs, you need either set only a single loss for all branch outputs or + to provide the same number of loss functions considering the right order. E.g. .. code-block:: python @@ -102,11 +103,8 @@ You can treat the instance of your model as instance but also as the model itsel the model instead of the model instance, you can directly apply the command on the instance instead of adding the model parameter call. ->>> MyCustomisedModel().model.compile() - -is therefore equal to the command - ->>> MyCustomisedModel().compile() +>>> MyCustomisedModel().model.compile(**kwargs) == MyCustomisedModel().compile(**kwargs) +True """ diff --git a/src/plotting/__init__.py b/src/plotting/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..cc92014bb42fcf43b983d576fe6d88aeb2dd797b 100644 --- a/src/plotting/__init__.py +++ b/src/plotting/__init__.py @@ -0,0 +1 @@ +"""Collection of all plots that can be used during experiment for monitoring and evaluation.""" diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 92ff9f2bb2137440cfbf74041ab5187967f8b934..8efd54bb23035eed4cf51e94235bf8de7ff2a481 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -1,3 +1,4 @@ +"""Collection of plots to evaluate a model, create overviews on data or forecasts.""" __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2019-12-17' @@ -10,13 +11,13 @@ from typing import Dict, List, Tuple import cartopy.crs as ccrs import cartopy.feature as cfeature import matplotlib +import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns import xarray as xr from matplotlib.backends.backend_pdf import PdfPages -import matplotlib.patches as mpatches from src import helpers from src.data_handling.data_generator import DataGenerator @@ -26,19 +27,61 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING) class AbstractPlotClass: + """ + Abstract class for all plotting routines to unify plot workflow. + + Each inheritance requires a _plot method. Create a plot class like: + + .. code-block:: python + + class MyCustomPlot(AbstractPlotClass): + + def __init__(self, plot_folder, *args, **kwargs): + super().__init__(plot_folder, "custom_plot_name") + self._data = self._prepare_data(*args, **kwargs) + self._plot(*args, **kwargs) + self._save() + + def _prepare_data(*args, **kwargs): + <your custom data preparation> + return data + + def _plot(*args, **kwargs): + <your custom plotting without saving> + + The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are + using pdfpages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500dpi per default (can be + set in super class initialisation). + + Methods like the shown _prepare_data() are optional. The only method required to implement is _plot. + + If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot + class. It will log the spent time if you call your plotting without saving the returned object. + + .. code-block:: python + + @TimeTrackingWrapper + class MyCustomPlot(AbstractPlotClass): + pass + + Let's assume it takes a while to create this very special plot. + >>> MyCustomPlot() + INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss) + + """ def __init__(self, plot_folder, plot_name, resolution=500): + """Set up plot folder and name, and plot resolution (default 500dpi)""" self.plot_folder = plot_folder self.plot_name = plot_name self.resolution = resolution def _plot(self, *args): + """Abstract plot class needs to be implemented in inheritance.""" raise NotImplementedError def _save(self, **kwargs): - """ - Standard save method to store plot locally. Name of and path to plot need to be set on initialisation - """ + """Standard save method to store plot locally. Name of and path to plot need to be set on initialisation.""" plot_name = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf") logging.debug(f"... save plot to {plot_name}") plt.savefig(plot_name, dpi=self.resolution, **kwargs) @@ -50,11 +93,17 @@ class PlotMonthlySummary(AbstractPlotClass): """ Show a monthly summary over all stations for each lead time ("ahead") as box and whiskers plot. The plot is saved in data_path with name monthly_summary_box_plot.pdf and 500dpi resolution. + + .. image:: ../../../../../_source/_plots/monthly_summary_box_plot.png + :width: 400 + """ + def __init__(self, stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None, plot_folder: str = "."): """ Sets attributes and create plot + :param stations: all stations to plot :param data_path: path, where the data is located :param name: full name of the local files with a % as placeholder for the station name @@ -75,6 +124,7 @@ class PlotMonthlySummary(AbstractPlotClass): """ Pre-process data required to plot. For each station, load locally saved predictions, extract the CNN prediction and the observation and group them into monthly bins (no aggregation, only sorting them). + :param stations: all stations to plot :return: The entire data set, flagged with the corresponding month. """ @@ -105,6 +155,7 @@ class PlotMonthlySummary(AbstractPlotClass): Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from data itself by the number of ahead dimensions. If given, check if data supports the give length. If the number of ahead dimensions in data is lower than the given lead time, data's lead time is used. + :param window_lead_time: lead time from arguments to validate :return: validated lead time, comes either from given argument or from data itself """ @@ -117,11 +168,13 @@ class PlotMonthlySummary(AbstractPlotClass): """ Main plot function that creates a monthly grouped box plot over all stations but with separate boxes for each lead time step. + :param target_var: display name of the target variable on plot's axis """ data = self._data.to_dataset(name='values').to_dask_dataframe() logging.debug("... start plotting") - color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex() + color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", + self._window_lead_time).as_hex() ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1., palette=color_palette, flierprops={'marker': '.', 'markersize': 1}, showmeans=True, meanprops={'markersize': 1, 'markeredgecolor': 'k'}) @@ -136,10 +189,15 @@ class PlotStationMap(AbstractPlotClass): input dictionary generators. The key represents the color to plot on the map. Currently, there is only a white background, but this can be adjusted by loading locally stored topography data (not implemented yet). The plot is saved under plot_path with the name station_map.pdf + + .. image:: ../../../../../_source/_plots/station_map.png + :width: 400 """ + def __init__(self, generators: Dict, plot_folder: str = "."): """ Sets attributes and create plot + :param generators: dictionary with the plot color of each data set as key and the generator containing all stations as value. :param plot_folder: path to save the plot (default: current directory) @@ -163,6 +221,7 @@ class PlotStationMap(AbstractPlotClass): """ The actual plot function. Loops over all keys in generators dict and its containing stations and plots a square and the stations's position on the map regarding the given color. + :param generators: dictionary with the plot color of each data set as key and the generator containing all stations as value. """ @@ -178,6 +237,7 @@ class PlotStationMap(AbstractPlotClass): def _plot(self, generators: Dict): """ Main plot function to create the station map plot. Sets figure and calls all required sub-methods. + :param generators: dictionary with the plot color of each data set as key and the generator containing all stations as value. """ @@ -207,8 +267,6 @@ def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_w :param plot_name_affix: name to specify this plot (e.g. 'cali-ref', default: '') :param units: units of the forecasted values (default: ppb) """ - # time = TimeTracking() - logging.debug(f"started plot_conditional_quantiles()") # ignore warnings if nans appear in quantile grouping warnings.filterwarnings("ignore", message="All-NaN slice encountered") # ignore warnings if mean is calculated on nans @@ -265,7 +323,10 @@ def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_w # set name and path of the plot base_name = "conditional_quantiles" - def add_affix(x): return f"_{x}" if len(x) > 0 else "" + + def add_affix(x): + return f"_{x}" if len(x) > 0 else "" + plot_name = f"{base_name}{add_affix(season)}{add_affix(plot_name_affix)}_plot.pdf" plot_path = os.path.join(os.path.abspath(plot_folder), plot_name) @@ -321,7 +382,6 @@ def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_w # close all open figures / plots pdf_pages.close() plt.close('all') - #logging.info(f"plot_conditional_quantiles() finished after {time}") @TimeTrackingWrapper @@ -332,11 +392,20 @@ class PlotClimatologicalSkillScore(AbstractPlotClass): term is plotted (score_only=False) or only the resulting scores CASE I to IV are displayed (score_only=True, default). Y-axis is adjusted following the data and not hard coded. The plot is saved under plot_folder path with name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi. + + .. image:: ../../../../../_source/_plots/skill_score_clim_all_terms_CNN.png + :width: 400 + + .. image:: ../../../../../_source/_plots/skill_score_clim_CNN.png + :width: 400 + """ + def __init__(self, data: Dict, plot_folder: str = ".", score_only: bool = True, extra_name_tag: str = "", model_setup: str = ""): """ Sets attributes and create plot + :param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms. :param plot_folder: path to save the plot (default: current directory) :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True) @@ -353,6 +422,7 @@ class PlotClimatologicalSkillScore(AbstractPlotClass): """ Shrink given data, if only scores are relevant. In any case, transform data to a plot friendly format. Also set plot labels depending on the lead time dimensions. + :param data: dictionary with station names as keys and 2D xarrays as values :param score_only: if true only scores of CASE I to IV are relevant :return: pre-processed data set @@ -366,6 +436,7 @@ class PlotClimatologicalSkillScore(AbstractPlotClass): def _label_add(self, score_only: bool): """ Adds the phrase "terms and " if score_only is disabled or empty string (if score_only=True). + :param score_only: if false all terms are relevant, otherwise only CASE I to IV :return: additional label """ @@ -374,6 +445,7 @@ class PlotClimatologicalSkillScore(AbstractPlotClass): def _plot(self, score_only): """ Main plot function to plot climatological skill score. + :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms """ fig, ax = plt.subplots() @@ -394,7 +466,12 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): Create competitive skill score for the given model setup and the reference models ordinary least squared ("ols") and the persistence forecast ("persi") for all lead times ("ahead"). The plot is saved under plot_folder with the name skill_score_competitive_{model_setup}.pdf and resolution of 500dpi. + + .. image:: ../../../../../_source/_plots/skill_score_competitive.png + :width: 400 + """ + def __init__(self, data: pd.DataFrame, plot_folder=".", model_setup="CNN"): """ :param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre- @@ -411,6 +488,7 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): def _prepare_data(self, data: pd.DataFrame) -> pd.DataFrame: """ Reformat given data and create plot labels. Introduces the dimensions stations and comparison + :param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre- calculated comparisons for cnn, persistence and ols. :return: processed data @@ -442,6 +520,7 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): """ Calculate y-axis limits from data. Lower is the minimum of either 0 or data's minimum (reduced by small subtrahend) and upper limit is data's maximum (increased by a small addend). + :return: """ lower = np.min([0, helpers.float_round(self._data.min()[2], 2) - 0.1]) @@ -457,11 +536,16 @@ class PlotBootstrapSkillScore(AbstractPlotClass): term is plotted (score_only=False) or only the resulting scores CASE I to IV are displayed (score_only=True, default). Y-axis is adjusted following the data and not hard coded. The plot is saved under plot_folder path with name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi. + + .. image:: ../../../../../_source/_plots/skill_score_bootstrap.png + :width: 400 + """ def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = ""): """ Sets attributes and create plot + :param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms. :param plot_folder: path to save the plot (default: current directory) :param model_setup: architecture type to specify plot name (default "CNN") @@ -477,6 +561,7 @@ class PlotBootstrapSkillScore(AbstractPlotClass): """ Shrink given data, if only scores are relevant. In any case, transform data to a plot friendly format. Also set plot labels depending on the lead time dimensions. + :param data: dictionary with station names as keys and 2D xarrays as values :return: pre-processed data set """ @@ -487,6 +572,7 @@ class PlotBootstrapSkillScore(AbstractPlotClass): def _label_add(self, score_only: bool): """ Adds the phrase "terms and " if score_only is disabled or empty string (if score_only=True). + :param score_only: if false all terms are relevant, otherwise only CASE I to IV :return: additional label """ @@ -530,6 +616,7 @@ class PlotTimeSeries: Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from data itself by the number of ahead dimensions. If given, check if data supports the give length. If the number of ahead dimensions in data is lower than the given lead time, data's lead time is used. + :param window_lead_time: lead time from arguments to validate :return: validated lead time, comes either from given argument or from data itself """ @@ -598,7 +685,7 @@ class PlotTimeSeries: for ahead in data.coords["ahead"].values: plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze().shift(index=ahead) label = f"{ahead}{self._sampling}" - ax.plot(plot_data, color=color[ahead-1], label=label) + ax.plot(plot_data, color=color[ahead - 1], label=label) def _plot_obs(self, ax, data): ahead = 1 @@ -611,12 +698,14 @@ class PlotTimeSeries: def _get_time_range(data): def f(x, f_x): return pd.to_datetime(f_x(x.index.values)).year + return f(data, min), f(data, max) @staticmethod def _create_pdf_pages(plot_folder): """ Standard save method to store plot locally. The name of this plot is static. + :param plot_folder: path to save the plot """ plot_name = os.path.join(os.path.abspath(plot_folder), 'timeseries_plot.pdf') @@ -626,6 +715,17 @@ class PlotTimeSeries: @TimeTrackingWrapper class PlotAvailability(AbstractPlotClass): + """ + .. image:: ../../../../../_source/_plots/data_availability.png + :width: 400 + + .. image:: ../../../../../_source/_plots/data_availability_summary.png + :width: 400 + + .. image:: ../../../../../_source/_plots/data_availability_combined.png + :width: 400 + + """ def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily", summary_name="data availability"): @@ -634,17 +734,17 @@ class PlotAvailability(AbstractPlotClass): self.sampling = self._get_sampling(sampling) plot_dict = self._prepare_data(generators) lgd = self._plot(plot_dict) - self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight") + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") # create summary Gantt plot (is data in at least one station available) self.plot_name += "_summary" plot_dict_summary = self._summarise_data(generators, summary_name) lgd = self._plot(plot_dict_summary) - self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight") + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") # combination of station and summary plot, last element is summary broken bar self.plot_name = "data_availability_combined" plot_dict_summary.update(plot_dict) lgd = self._plot(plot_dict_summary) - self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight") + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") @staticmethod def _get_sampling(sampling): @@ -662,7 +762,8 @@ class PlotAvailability(AbstractPlotClass): labels = station_data.get_transposed_label().resample(datetime=self.sampling, skipna=True).mean() labels_bool = labels.sel(window=1).notnull() group = (labels_bool != labels_bool.shift(datetime=1)).cumsum() - plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values}, index=labels.datetime.values) + plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values}, + index=labels.datetime.values) t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0])) t2 = [i[1:] for i in t if i[0]] @@ -685,7 +786,8 @@ class PlotAvailability(AbstractPlotClass): all_data = labels_bool else: tmp = all_data.combine_first(labels_bool) # expand dims to merged datetime coords - all_data = np.logical_or(tmp, labels_bool).combine_first(all_data) # apply logical on merge and fill missing with all_data + all_data = np.logical_or(tmp, labels_bool).combine_first( + all_data) # apply logical on merge and fill missing with all_data group = (all_data != all_data.shift(datetime=1)).cumsum() plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values}, index=all_data.datetime.values) @@ -697,7 +799,6 @@ class PlotAvailability(AbstractPlotClass): plt_dict[summary_name].update({subset: t2}) return plt_dict - def _plot(self, plt_dict): # colors = {"train": "orange", "val": "blueishgreen", "test": "skyblue"} # color names colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"} # hex code @@ -706,7 +807,7 @@ class PlotAvailability(AbstractPlotClass): height = 0.8 # should be <= 1 yticklabels = [] number_of_stations = len(plt_dict.keys()) - fig, ax = plt.subplots(figsize=(10, number_of_stations/3)) + fig, ax = plt.subplots(figsize=(10, number_of_stations / 3)) for station, d in sorted(plt_dict.items(), reverse=True): pos += 1 for subset, color in colors.items(): @@ -717,7 +818,7 @@ class PlotAvailability(AbstractPlotClass): yticklabels.append(station) ax.set_ylim([height, number_of_stations + 1]) - ax.set_yticks(np.arange(len(plt_dict.keys()))+1+height/2) + ax.set_yticks(np.arange(len(plt_dict.keys())) + 1 + height / 2) ax.set_yticklabels(yticklabels) handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items()] lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles)) diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py index 7e656895c5eecdabe1ef26869b68fb9494ed4c8c..473b966ce52ee7e2885bc14beef2e68b8835b15e 100644 --- a/src/plotting/training_monitoring.py +++ b/src/plotting/training_monitoring.py @@ -1,7 +1,8 @@ +"""Plots to monitor training.""" + __author__ = 'Felix Kleinert, Lukas Leufen' __date__ = '2019-12-11' - from typing import Union, Dict, List import keras @@ -18,15 +19,18 @@ lr_object = Union[Dict, LearningRateDecay] class PlotModelHistory: """ - Plots history of all plot_metrics (default: loss) for a training event. For default plot_metric and val_plot_metric - are plotted. If further metrics are provided (name must somehow include the word `<plot_metric>`), this additional - information is added to the plot with an separate y-axis scale on the right side (shared for all additional - metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the absolute - path for the plot. + Plot history of all plot_metrics (default: loss) for a training event. + + For default plot_metric and val_plot_metric are plotted. If further metrics are provided (name must somehow include + the word `<plot_metric>`), this additional information is added to the plot with an separate y-axis scale on the + right side (shared for all additional metrics). The plot is saved locally. For a proper saving behaviour, the + parameter filename must include the absolute path for the plot. """ + def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False): """ - Sets attributes and create plot + Set attributes and create plot. + :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a format ending like .pdf or .png to work. :param history: the history object (or a dict with at least 'loss' and 'val_loss' as keys) to plot loss from @@ -47,16 +51,20 @@ class PlotModelHistory: plot_metric = "mean_squared_error" elif plot_metric.lower() == "mae": plot_metric = "mean_absolute_error" - available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k.lower() if main_branch else True)] + available_keys = [k for k in history.keys() if + plot_metric in k and ("main" in k.lower() if main_branch else True)] available_keys.sort(key=len) return available_keys[0] def _filter_columns(self, history: Dict) -> List[str]: """ - Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>' are - also removed. + Select only columns named like %<plot_metric>%. + + The default metrics '<plot_metric>' and 'val_<plot_metric>' are removed too. + :param history: a dict with at least '<plot_metric>' and 'val_<plot_metric>' as keys (can be derived from keras History.history) + :return: filtered columns including all plot_metric variations except <plot_metric> and val_<plot_metric>. """ cols = list(filter(lambda x: self._plot_metric in x, history.keys())) @@ -69,8 +77,11 @@ class PlotModelHistory: def _plot(self, filename: str) -> None: """ - Actual plot routine. Plots <plot_metric> and val_<plot_metric> as default. If more plot_metrics are provided, - they will be added with an additional yaxis on the right side. The plot is saved in filename. + Create plot. + + Plots <plot_metric> and val_<plot_metric> as default. If more plot_metrics are provided, they will be added with + an additional yaxis on the right side. The plot is saved in filename. + :param filename: name (including total path) of the plot to save. """ ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7) @@ -86,12 +97,16 @@ class PlotModelHistory: class PlotModelLearningRate: """ - Plots the behaviour of the learning rate in dependence of the number of epochs. The plot is saved locally as pdf. - For a proper saving behaviour, the parameter filename must include the absolute path for the plot. + Plot the behaviour of the learning rate in dependence of the number of epochs. + + The plot is saved locally as pdf. For a proper saving behaviour, the parameter filename must include the absolute + path for the plot. """ + def __init__(self, filename: str, lr_sc: lr_object): """ - Sets attributes and create plot + Set attributes and create plot. + :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a format ending like .pdf or .png to work. :param lr_sc: the learning rate object (or a dict with `lr` as key) to plot from @@ -103,7 +118,10 @@ class PlotModelLearningRate: def _plot(self, filename: str) -> None: """ - Actual plot routine. Plots the learning rate in dependence of epoch. + Create plot. + + Plot the learning rate in dependence of epoch. + :param filename: name (including total path) of the plot to save. """ ax = self._data.plot(linewidth=0.7)