From 33433be319f08a79f5a30d3f34a21808269e9c56 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 14 Apr 2020 13:22:20 +0200
Subject: [PATCH] added summary and combined Gantt plot, also name of summary
 can be set from outside

---
 src/plotting/postprocessing_plotting.py | 49 ++++++++++++++++++++++---
 1 file changed, 43 insertions(+), 6 deletions(-)

diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index 48606d4f..175f7053 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -626,12 +626,24 @@ class PlotTimeSeries:
 @TimeTrackingWrapper
 class PlotAvailability(AbstractPlotClass):
 
-    def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily"):
+    def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily",
+                 summary_name="data availability"):
+        # create standard Gantt plot for all stations (currently in single pdf file with single page)
         super().__init__(plot_folder, "data_availability")
         self.sampling = self._get_sampling(sampling)
         plot_dict = self._prepare_data(generators)
         self._plot(plot_dict)
         self._save()
+        # create summary Gantt plot (is data in at least one station available)
+        self.plot_name += "_summary"
+        plot_dict_summary = self._summarise_data(generators, summary_name)
+        self._plot(plot_dict_summary)
+        self._save()
+        # combination of station and summary plot, last element is summary broken bar
+        self.plot_name = "data_availability_combined"
+        plot_dict_summary.update(plot_dict)
+        self._plot(plot_dict_summary)
+        self._save()
 
     @staticmethod
     def _get_sampling(sampling):
@@ -659,16 +671,41 @@ class PlotAvailability(AbstractPlotClass):
                     plt_dict[station].update({subset: t2})
         return plt_dict
 
+    def _summarise_data(self, generators: Dict[str, DataGenerator], summary_name: str):
+        plt_dict = {}
+        for subset, generator in generators.items():
+            all_data = None
+            stations = generator.stations
+            for station in stations:
+                station_data = generator.get_data_generator(station)
+                labels = station_data.get_transposed_label().resample(datetime=self.sampling, skipna=True).mean()
+                labels_bool = labels.sel(window=1).notnull()
+                if all_data is None:
+                    all_data = labels_bool
+                else:
+                    tmp = all_data.combine_first(labels_bool)  # expand dims to merged datetime coords
+                    all_data = np.logical_or(tmp, labels_bool).combine_first(all_data)  # apply logical on merge and fill missing with all_data
+
+            group = (all_data != all_data.shift(datetime=1)).cumsum()
+            plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values}, index=all_data.datetime.values)
+            t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0]))
+            t2 = [i[1:] for i in t if i[0]]
+            if plt_dict.get(summary_name) is None:
+                plt_dict[summary_name] = {subset: t2}
+            else:
+                plt_dict[summary_name].update({subset: t2})
+        return plt_dict
+
+
     def _plot(self, plt_dict):
-        # colors = {"train": "orange", "val": "skyblue", "test": "blueishgreen"}
-        colors = {"train": "#e69f00", "val": "#56b4e9", "test": "#009e73"}
-        # colors = {"train": (230, 159, 0), "val": (86, 180, 233), "test": (0, 158, 115)}
+        # colors = {"train": "orange", "val": "skyblue", "test": "blueishgreen"}  # color names
+        colors = {"train": "#e69f00", "val": "#56b4e9", "test": "#009e73"}  # hex code
+        # colors = {"train": (230, 159, 0), "val": (86, 180, 233), "test": (0, 158, 115)}  # in rgb but as abs values
         pos = 0
-        count = 0
         height = 0.8  # should be <= 1
         yticklabels = []
         number_of_stations = len(plt_dict.keys())
-        fig, ax = plt.subplots(figsize=(10, number_of_stations/3))
+        fig, ax = plt.subplots(figsize=(10, max([number_of_stations/3, 1])))
         for station, d in sorted(plt_dict.items(), reverse=True):
             pos += 1
             for subset, color in colors.items():
-- 
GitLab