Skip to content
Snippets Groups Projects

Resolve "release v1.4.0"

Merged Ghost User requested to merge release_v1.4.0 into master
4 files
+ 543
525
Compare changes
  • Side-by-side
  • Inline
Files
4
+ 101
0
"""Abstract plot class that should be used for preprocessing and postprocessing plots."""
__author__ = "Lukas Leufen"
__date__ = '2021-04-13'
import logging
import os
from matplotlib import pyplot as plt
class AbstractPlotClass:
"""
Abstract class for all plotting routines to unify plot workflow.
Each inheritance requires a _plot method. Create a plot class like:
.. code-block:: python
class MyCustomPlot(AbstractPlotClass):
def __init__(self, plot_folder, *args, **kwargs):
super().__init__(plot_folder, "custom_plot_name")
self._data = self._prepare_data(*args, **kwargs)
self._plot(*args, **kwargs)
self._save()
def _prepare_data(*args, **kwargs):
<your custom data preparation>
return data
def _plot(*args, **kwargs):
<your custom plotting without saving>
The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are
using pdfpages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500dpi per default (can be
set in super class initialisation).
Methods like the shown _prepare_data() are optional. The only method required to implement is _plot.
If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot
class. It will log the spent time if you call your plotting without saving the returned object.
.. code-block:: python
@TimeTrackingWrapper
class MyCustomPlot(AbstractPlotClass):
pass
Let's assume it takes a while to create this very special plot.
>>> MyCustomPlot()
INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss)
"""
def __init__(self, plot_folder, plot_name, resolution=500, rc_params=None):
"""Set up plot folder and name, and plot resolution (default 500dpi)."""
plot_folder = os.path.abspath(plot_folder)
if not os.path.exists(plot_folder):
os.makedirs(plot_folder)
self.plot_folder = plot_folder
self.plot_name = plot_name
self.resolution = resolution
if rc_params is None:
rc_params = {'axes.labelsize': 'large',
'xtick.labelsize': 'large',
'ytick.labelsize': 'large',
'legend.fontsize': 'large',
'axes.titlesize': 'large',
}
self.rc_params = rc_params
self._update_rc_params()
def _plot(self, *args):
"""Abstract plot class needs to be implemented in inheritance."""
raise NotImplementedError
def _save(self, **kwargs):
"""Store plot locally. Name of and path to plot need to be set on initialisation."""
plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")
logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=self.resolution, **kwargs)
plt.close('all')
def _update_rc_params(self):
plt.rcParams.update(self.rc_params)
@staticmethod
def _get_sampling(sampling):
if sampling == "daily":
return "D"
elif sampling == "hourly":
return "h"
@staticmethod
def get_dataset_colors():
"""
Standard colors used for train-, val-, and test-sets during postprocessing
"""
colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"} # hex code
return colors
Loading