diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index 48606d4f8531812672391304b41885555608473b..175f70537f995f7e2eff84d2fb5e8275dc692c4d 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -626,12 +626,32 @@ class PlotTimeSeries:
 @TimeTrackingWrapper
 class PlotAvailability(AbstractPlotClass):
 
-    def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily"):
+    def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily",
+                 summary_name="data availability"):
+        """
+        Create three Gantt-style availability plots: per station, summary, and both combined.
+
+        :param generators: mapping from subset name (e.g. "train") to the data generator of that subset
+        :param plot_folder: folder handed to the parent class together with the plot name
+        :param sampling: sampling of the data, translated by `_get_sampling` before use
+        :param summary_name: label of the summary entry (data available in at least one station)
+        """
+        # create standard Gantt plot for all stations (currently in single pdf file with single page)
         super().__init__(plot_folder, "data_availability")
         self.sampling = self._get_sampling(sampling)
         plot_dict = self._prepare_data(generators)
         self._plot(plot_dict)
         self._save()
+        # create summary Gantt plot (is data available in at least one station)
+        self.plot_name += "_summary"
+        plot_dict_summary = self._summarise_data(generators, summary_name)
+        self._plot(plot_dict_summary)
+        self._save()
+        # combination of station and summary plot, last element is summary broken bar
+        self.plot_name = "data_availability_combined"
+        plot_dict_summary.update(plot_dict)
+        self._plot(plot_dict_summary)
+        self._save()
 
     @staticmethod
     def _get_sampling(sampling):
@@ -659,16 +679,52 @@ class PlotAvailability(AbstractPlotClass):
                     plt_dict[station].update({subset: t2})
         return plt_dict
 
+    def _summarise_data(self, generators: Dict[str, DataGenerator], summary_name: str):
+        """
+        Summarise availability over all stations of each subset.
+
+        A time step counts as available as soon as at least one station of the subset has a non-null label
+        at `window=1` after resampling to `self.sampling`.
+
+        :param generators: mapping from subset name to the data generator of that subset
+        :param summary_name: key under which the summary is stored in the returned dictionary
+        :return: `{summary_name: {subset: [(start, length), ...]}}` where each tuple describes one run of
+            consecutive available time steps; subsets without any station are left out
+        """
+        plt_dict = {}
+        for subset, generator in generators.items():
+            all_data = None
+            for station in generator.stations:
+                station_data = generator.get_data_generator(station)
+                labels = station_data.get_transposed_label().resample(datetime=self.sampling, skipna=True).mean()
+                labels_bool = labels.sel(window=1).notnull()
+                if all_data is None:
+                    all_data = labels_bool
+                else:
+                    tmp = all_data.combine_first(labels_bool)  # expand dims to merged datetime coords
+                    # apply logical or on the merge and fill time steps missing in this station with all_data
+                    all_data = np.logical_or(tmp, labels_bool).combine_first(all_data)
+            if all_data is None:  # subset without stations, nothing to summarise
+                continue
+
+            # consecutive time steps with the same availability share one group id
+            group = (all_data != all_data.shift(datetime=1)).cumsum()
+            plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values}, index=all_data.datetime.values)
+            # use positional access (iloc): the frame is indexed by datetime values, so `[0]` is not a valid label
+            t = plot_data.groupby("group").apply(lambda x: (x["avail"].iloc[0], x.index[0], x.shape[0]))
+            t2 = [i[1:] for i in t if i[0]]
+            plt_dict.setdefault(summary_name, {})[subset] = t2
+        return plt_dict
+
     def _plot(self, plt_dict):
-        # colors = {"train": "orange", "val": "skyblue", "test": "blueishgreen"}
-        colors = {"train": "#e69f00", "val": "#56b4e9", "test": "#009e73"}
-        # colors = {"train": (230, 159, 0), "val": (86, 180, 233), "test": (0, 158, 115)}
+        # colors = {"train": "orange", "val": "skyblue", "test": "blueishgreen"}  # color names
+        colors = {"train": "#e69f00", "val": "#56b4e9", "test": "#009e73"}  # hex code
+        # colors = {"train": (230, 159, 0), "val": (86, 180, 233), "test": (0, 158, 115)}  # in rgb but as abs values
         pos = 0
-        count = 0
         height = 0.8  # should be <= 1
         yticklabels = []
         number_of_stations = len(plt_dict.keys())
-        fig, ax = plt.subplots(figsize=(10, number_of_stations/3))
+        fig, ax = plt.subplots(figsize=(10, max([number_of_stations/3, 1])))
         for station, d in sorted(plt_dict.items(), reverse=True):
             pos += 1
             for subset, color in colors.items():