Skip to content
Snippets Groups Projects
Commit 9bf9966b authored by lukas leufen's avatar lukas leufen
Browse files

legend on top

parent 45ce6195
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!88added summary and combined Gantt plot
Pipeline #34335 passed
......@@ -16,6 +16,7 @@ import pandas as pd
import seaborn as sns
import xarray as xr
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.patches as mpatches
from src import helpers
from src.helpers import TimeTracking, TimeTrackingWrapper
......@@ -33,14 +34,14 @@ class AbstractPlotClass:
def _plot(self, *args):
raise NotImplementedError
def _save(self):
def _save(self, **kwargs):
"""
Standard save method to store plot locally. Name of and path to plot need to be set on initialisation
"""
plot_name = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf")
logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=self.resolution)
plt.savefig(plot_name, dpi=self.resolution, **kwargs)
plt.close('all')
......@@ -632,18 +633,18 @@ class PlotAvailability(AbstractPlotClass):
super().__init__(plot_folder, "data_availability")
self.sampling = self._get_sampling(sampling)
plot_dict = self._prepare_data(generators)
self._plot(plot_dict)
self._save()
lgd = self._plot(plot_dict)
self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight")
# create summary Gantt plot (is data in at least one station available)
self.plot_name += "_summary"
plot_dict_summary = self._summarise_data(generators, summary_name)
self._plot(plot_dict_summary)
self._save()
lgd = self._plot(plot_dict_summary)
self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight")
# combination of station and summary plot, last element is summary broken bar
self.plot_name = "data_availability_combined"
plot_dict_summary.update(plot_dict)
self._plot(plot_dict_summary)
self._save()
lgd = self._plot(plot_dict_summary)
self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight")
@staticmethod
def _get_sampling(sampling):
......@@ -698,14 +699,14 @@ class PlotAvailability(AbstractPlotClass):
def _plot(self, plt_dict):
# colors = {"train": "orange", "val": "skyblue", "test": "blueishgreen"} # color names
colors = {"train": "#e69f00", "val": "#56b4e9", "test": "#009e73"} # hex code
# colors = {"train": (230, 159, 0), "val": (86, 180, 233), "test": (0, 158, 115)} # in rgb but as abs values
# colors = {"train": "orange", "val": "blueishgreen", "test": "skyblue"} # color names
colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"} # hex code
# colors = {"train": (230, 159, 0), "val": (0, 158, 115), "test": (86, 180, 233)} # in rgb but as abs values
pos = 0
height = 0.8 # should be <= 1
yticklabels = []
number_of_stations = len(plt_dict.keys())
fig, ax = plt.subplots(figsize=(10, max([number_of_stations/3, 1])))
fig, ax = plt.subplots(figsize=(10, number_of_stations/3))
for station, d in sorted(plt_dict.items(), reverse=True):
pos += 1
for subset, color in colors.items():
......@@ -718,4 +719,6 @@ class PlotAvailability(AbstractPlotClass):
ax.set_ylim([height, number_of_stations + 1])
ax.set_yticks(np.arange(len(plt_dict.keys()))+1+height/2)
ax.set_yticklabels(yticklabels)
plt.tight_layout()
handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items()]
lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles))
return lgd
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment