Skip to content
Snippets Groups Projects
Commit 9bf9966b authored by lukas leufen's avatar lukas leufen
Browse files

legend on top

parent 45ce6195
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!88added summary and combined Gantt plot
Pipeline #34335 passed
...@@ -16,6 +16,7 @@ import pandas as pd ...@@ -16,6 +16,7 @@ import pandas as pd
import seaborn as sns import seaborn as sns
import xarray as xr import xarray as xr
from matplotlib.backends.backend_pdf import PdfPages from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.patches as mpatches
from src import helpers from src import helpers
from src.helpers import TimeTracking, TimeTrackingWrapper from src.helpers import TimeTracking, TimeTrackingWrapper
...@@ -34,13 +35,13 @@ class AbstractPlotClass: ...@@ -34,13 +35,13 @@ class AbstractPlotClass:
def _plot(self, *args): def _plot(self, *args):
raise NotImplementedError raise NotImplementedError
def _save(self): def _save(self, **kwargs):
""" """
Standard save method to store plot locally. Name of and path to plot need to be set on initialisation Standard save method to store plot locally. Name of and path to plot need to be set on initialisation
""" """
plot_name = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf") plot_name = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf")
logging.debug(f"... save plot to {plot_name}") logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=self.resolution) plt.savefig(plot_name, dpi=self.resolution, **kwargs)
plt.close('all') plt.close('all')
...@@ -632,18 +633,18 @@ class PlotAvailability(AbstractPlotClass): ...@@ -632,18 +633,18 @@ class PlotAvailability(AbstractPlotClass):
super().__init__(plot_folder, "data_availability") super().__init__(plot_folder, "data_availability")
self.sampling = self._get_sampling(sampling) self.sampling = self._get_sampling(sampling)
plot_dict = self._prepare_data(generators) plot_dict = self._prepare_data(generators)
self._plot(plot_dict) lgd = self._plot(plot_dict)
self._save() self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight")
# create summary Gantt plot (is data in at least one station available) # create summary Gantt plot (is data in at least one station available)
self.plot_name += "_summary" self.plot_name += "_summary"
plot_dict_summary = self._summarise_data(generators, summary_name) plot_dict_summary = self._summarise_data(generators, summary_name)
self._plot(plot_dict_summary) lgd = self._plot(plot_dict_summary)
self._save() self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight")
# combination of station and summary plot, last element is summary broken bar # combination of station and summary plot, last element is summary broken bar
self.plot_name = "data_availability_combined" self.plot_name = "data_availability_combined"
plot_dict_summary.update(plot_dict) plot_dict_summary.update(plot_dict)
self._plot(plot_dict_summary) lgd = self._plot(plot_dict_summary)
self._save() self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight")
@staticmethod @staticmethod
def _get_sampling(sampling): def _get_sampling(sampling):
...@@ -698,14 +699,14 @@ class PlotAvailability(AbstractPlotClass): ...@@ -698,14 +699,14 @@ class PlotAvailability(AbstractPlotClass):
def _plot(self, plt_dict): def _plot(self, plt_dict):
# colors = {"train": "orange", "val": "skyblue", "test": "blueishgreen"} # color names # colors = {"train": "orange", "val": "blueishgreen", "test": "skyblue"} # color names
colors = {"train": "#e69f00", "val": "#56b4e9", "test": "#009e73"} # hex code colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"} # hex code
# colors = {"train": (230, 159, 0), "val": (86, 180, 233), "test": (0, 158, 115)} # in rgb but as abs values # colors = {"train": (230, 159, 0), "val": (0, 158, 115), "test": (86, 180, 233)} # in rgb but as abs values
pos = 0 pos = 0
height = 0.8 # should be <= 1 height = 0.8 # should be <= 1
yticklabels = [] yticklabels = []
number_of_stations = len(plt_dict.keys()) number_of_stations = len(plt_dict.keys())
fig, ax = plt.subplots(figsize=(10, max([number_of_stations/3, 1]))) fig, ax = plt.subplots(figsize=(10, number_of_stations/3))
for station, d in sorted(plt_dict.items(), reverse=True): for station, d in sorted(plt_dict.items(), reverse=True):
pos += 1 pos += 1
for subset, color in colors.items(): for subset, color in colors.items():
...@@ -718,4 +719,6 @@ class PlotAvailability(AbstractPlotClass): ...@@ -718,4 +719,6 @@ class PlotAvailability(AbstractPlotClass):
ax.set_ylim([height, number_of_stations + 1]) ax.set_ylim([height, number_of_stations + 1])
ax.set_yticks(np.arange(len(plt_dict.keys()))+1+height/2) ax.set_yticks(np.arange(len(plt_dict.keys()))+1+height/2)
ax.set_yticklabels(yticklabels) ax.set_yticklabels(yticklabels)
plt.tight_layout() handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items()]
lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles))
return lgd
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment