From 9bf9966b4786389ccd73c10084cd1979475034be Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Tue, 14 Apr 2020 18:50:08 +0200 Subject: [PATCH] legend on top --- src/plotting/postprocessing_plotting.py | 31 ++++++++++++++----------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 175f7053..14e3074a 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -16,6 +16,7 @@ import pandas as pd import seaborn as sns import xarray as xr from matplotlib.backends.backend_pdf import PdfPages +import matplotlib.patches as mpatches from src import helpers from src.helpers import TimeTracking, TimeTrackingWrapper @@ -33,14 +34,14 @@ class AbstractPlotClass: def _plot(self, *args): raise NotImplementedError - - def _save(self): + + def _save(self, **kwargs): """ Standard save method to store plot locally. Name of and path to plot need to be set on initialisation """ plot_name = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf") logging.debug(f"... save plot to {plot_name}") - plt.savefig(plot_name, dpi=self.resolution) + plt.savefig(plot_name, dpi=self.resolution, **kwargs) plt.close('all') @@ -632,18 +633,18 @@ class PlotAvailability(AbstractPlotClass): super().__init__(plot_folder, "data_availability") self.sampling = self._get_sampling(sampling) plot_dict = self._prepare_data(generators) - self._plot(plot_dict) - self._save() + lgd = self._plot(plot_dict) + self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight") # create summary Gantt plot (is data in at least one station available) self.plot_name += "_summary" plot_dict_summary = self._summarise_data(generators, summary_name) - self._plot(plot_dict_summary) - self._save() + lgd = self._plot(plot_dict_summary) + self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight") # combination of station and summary plot, last element is summary broken bar self.plot_name = "data_availability_combined" plot_dict_summary.update(plot_dict) - self._plot(plot_dict_summary) - self._save() + lgd = self._plot(plot_dict_summary) + self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight") @staticmethod def _get_sampling(sampling): @@ -698,14 +699,14 @@ class PlotAvailability(AbstractPlotClass): def _plot(self, plt_dict): - # colors = {"train": "orange", "val": "skyblue", "test": "blueishgreen"} # color names - colors = {"train": "#e69f00", "val": "#56b4e9", "test": "#009e73"} # hex code - # colors = {"train": (230, 159, 0), "val": (86, 180, 233), "test": (0, 158, 115)} # in rgb but as abs values + # colors = {"train": "orange", "val": "blueishgreen", "test": "skyblue"} # color names + colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"} # hex code + # colors = {"train": (230, 159, 0), "val": (0, 158, 115), "test": (86, 180, 233)} # in rgb but as abs values pos = 0 height = 0.8 # should be <= 1 yticklabels = [] number_of_stations = len(plt_dict.keys()) - fig, ax = plt.subplots(figsize=(10, max([number_of_stations/3, 1]))) + fig, ax = plt.subplots(figsize=(10, number_of_stations/3)) for station, d in sorted(plt_dict.items(), reverse=True): pos += 1 for subset, color in colors.items(): @@ -718,4 +719,6 @@ class PlotAvailability(AbstractPlotClass): ax.set_ylim([height, number_of_stations + 1]) ax.set_yticks(np.arange(len(plt_dict.keys()))+1+height/2) ax.set_yticklabels(yticklabels) - plt.tight_layout() + handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items()] + lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles)) + return lgd -- GitLab