From 9bf9966b4786389ccd73c10084cd1979475034be Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 14 Apr 2020 18:50:08 +0200
Subject: [PATCH] legend on top

---
 src/plotting/postprocessing_plotting.py | 31 ++++++++++++++-----------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index 175f7053..14e3074a 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -16,6 +16,7 @@ import pandas as pd
 import seaborn as sns
 import xarray as xr
 from matplotlib.backends.backend_pdf import PdfPages
+import matplotlib.patches as mpatches
 
 from src import helpers
 from src.helpers import TimeTracking, TimeTrackingWrapper
@@ -33,14 +34,14 @@ class AbstractPlotClass:
 
     def _plot(self, *args):
         raise NotImplementedError
-    
-    def _save(self):
+
+    def _save(self, **kwargs):
         """
         Standard save method to store plot locally. Name of and path to plot need to be set on initialisation
         """
         plot_name = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf")
         logging.debug(f"... save plot to {plot_name}")
-        plt.savefig(plot_name, dpi=self.resolution)
+        plt.savefig(plot_name, dpi=self.resolution, **kwargs)
         plt.close('all')
 
 
@@ -632,18 +633,18 @@ class PlotAvailability(AbstractPlotClass):
         super().__init__(plot_folder, "data_availability")
         self.sampling = self._get_sampling(sampling)
         plot_dict = self._prepare_data(generators)
-        self._plot(plot_dict)
-        self._save()
+        lgd = self._plot(plot_dict)
+        self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight")
         # create summary Gantt plot (is data in at least one station available)
         self.plot_name += "_summary"
         plot_dict_summary = self._summarise_data(generators, summary_name)
-        self._plot(plot_dict_summary)
-        self._save()
+        lgd = self._plot(plot_dict_summary)
+        self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight")
         # combination of station and summary plot, last element is summary broken bar
         self.plot_name = "data_availability_combined"
         plot_dict_summary.update(plot_dict)
-        self._plot(plot_dict_summary)
-        self._save()
+        lgd = self._plot(plot_dict_summary)
+        self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight")
 
     @staticmethod
     def _get_sampling(sampling):
@@ -698,14 +699,14 @@ class PlotAvailability(AbstractPlotClass):
 
 
     def _plot(self, plt_dict):
-        # colors = {"train": "orange", "val": "skyblue", "test": "blueishgreen"}  # color names
-        colors = {"train": "#e69f00", "val": "#56b4e9", "test": "#009e73"}  # hex code
-        # colors = {"train": (230, 159, 0), "val": (86, 180, 233), "test": (0, 158, 115)}  # in rgb but as abs values
+        # colors = {"train": "orange", "val": "blueishgreen", "test": "skyblue"}  # color names
+        colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"}  # hex code
+        # colors = {"train": (230, 159, 0), "val": (0, 158, 115), "test": (86, 180, 233)}  # in rgb but as abs values
         pos = 0
         height = 0.8  # should be <= 1
         yticklabels = []
         number_of_stations = len(plt_dict.keys())
-        fig, ax = plt.subplots(figsize=(10, max([number_of_stations/3, 1])))
+        fig, ax = plt.subplots(figsize=(10, number_of_stations/3))
         for station, d in sorted(plt_dict.items(), reverse=True):
             pos += 1
             for subset, color in colors.items():
@@ -718,4 +719,6 @@ class PlotAvailability(AbstractPlotClass):
         ax.set_ylim([height, number_of_stations + 1])
         ax.set_yticks(np.arange(len(plt_dict.keys()))+1+height/2)
         ax.set_yticklabels(yticklabels)
-        plt.tight_layout()
+        handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items()]
+        lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles))
+        return lgd
-- 
GitLab