diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 75450d939a9d56a87cd7ba955299807b7c91bfee..a62b44bb4305ee14e68510bdbf61318111bda40f 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -16,9 +16,10 @@ import seaborn as sns import xarray as xr from matplotlib.backends.backend_pdf import PdfPages import matplotlib.patches as mpatches +import matplotlib.dates as mdates from src import helpers -from src.helpers import TimeTracking, TimeTrackingWrapper +from src.helpers import TimeTrackingWrapper from src.data_handling.data_generator import DataGenerator logging.getLogger('matplotlib').setLevel(logging.WARNING) @@ -43,6 +44,21 @@ class AbstractPlotClass: plt.savefig(plot_name, dpi=self.resolution, **kwargs) plt.close('all') + @staticmethod + def _get_sampling(sampling): + if sampling == "daily": + return "D" + elif sampling == "hourly": + return "h" + + @staticmethod + def get_dataset_colors(): + """ + Standard colors used for train-, val-, and test-sets during postprocessing + """ + colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"} # hex code + return colors + @TimeTrackingWrapper class PlotMonthlySummary(AbstractPlotClass): @@ -773,12 +789,6 @@ class PlotAvailability(AbstractPlotClass): lgd = self._plot(plot_dict_summary) self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight") - @staticmethod - def _get_sampling(sampling): - if sampling == "daily": - return "D" - elif sampling == "hourly": - return "h" def _prepare_data(self, generators: Dict[str, DataGenerator]): plt_dict = {} @@ -825,9 +835,7 @@ class PlotAvailability(AbstractPlotClass): return plt_dict def _plot(self, plt_dict): - # colors = {"train": "orange", "val": "blueishgreen", "test": "skyblue"} # color names - colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"} # hex code - # colors = {"train": (230, 159, 0), "val": (0, 158, 115), "test": (86, 180, 233)} # in rgb but as abs values + colors = self.get_dataset_colors() pos = 0 height = 0.8 # should be <= 1 yticklabels = [] @@ -850,6 +858,87 @@ class PlotAvailability(AbstractPlotClass): return lgd +@TimeTrackingWrapper +class PlotAvailabilityHistogram(AbstractPlotClass): + """ + + + """ + + def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily", + summary_name="data availability"): + + super().__init__(plot_folder, "data_availability_histogram") + self.freq = self._get_sampling(sampling) + self._prepare_data(generators) + self._plot() + self._save() + + def _prepare_data(self, generators: Dict[str, DataGenerator]): + avail_dict = {} + dataset_time_interval={} + for subset, generator in generators.items(): + avail_list = [] + for station in generator.stations: + station_data_X, _ = generator[station] + station_data_X = station_data_X.loc[{'window': 0, # select recent window frame + generator.target_dim: generator.variables[0]}] + avail_list.append(station_data_X.notnull()) + avail_dict[subset] = xr.concat(avail_list, dim='Stations').notnull().sum(dim='Stations') + dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray( + avail_dict[subset], dim_name='datetime', return_type='as_dict' + ) + avail_data_amount = xr.concat(avail_dict.values(), pd.Index(avail_dict.keys(), name='DataSet')) + full_time_index = self._make_full_time_index(avail_data_amount.coords['datetime'].values, freq=self.freq) + self.avail_data_amount = avail_data_amount.reindex({'datetime': full_time_index}, fill_value=0.) + self.dataset_time_interval = dataset_time_interval + + @staticmethod + def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'): + if isinstance(xarray, xr.DataArray): + first = xarray.coords[dim_name].values[0] + last = xarray.coords[dim_name].values[-1] + if return_type == 'as_tuple': + return first, last + elif return_type == 'as_dict': + return {'first': first, 'last': last} + else: + raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'") + else: + raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}") + + @staticmethod + def _make_full_time_index(irregular_time_index, freq): + full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq) + return full_time_index + + def _plot(self, *args): + # for dataset in + colors = self.get_dataset_colors() + fig, axes = plt.subplots(figsize=(10,3)) + for i, subset in enumerate(self.dataset_time_interval.keys()): + plot_dataset = self.avail_data_amount.sel({'DataSet': subset, + 'datetime': slice(self.dataset_time_interval[subset]['first'], + self.dataset_time_interval[subset]['last'] + ) + } + ) + + plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset) + plt.fill_between(plot_dataset.coords['datetime'].values, plot_dataset.values, color=colors[subset]) + + lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval)) + for lgd_line in lgd.get_lines(): + lgd_line.set_linewidth(4.0) + plt.gca().xaxis.set_major_locator(mdates.YearLocator()) + plt.title('') + plt.ylabel('Number of samples') + plt.tight_layout() + + + + + if __name__ == "__main__": stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] path = "../../testrun_network/forecasts" diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 9e74582d7f9f80881206b2a2a92d682dfcc75dc8..1bb230ada57d818e22ac01b583db9431d3b621cb 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -20,7 +20,8 @@ from src.helpers import TimeTracking from src.model_modules.linear_model import OrdinaryLeastSquaredModel from src.model_modules.model_class import AbstractModelClass from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \ - PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles + PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles, \ + PlotAvailabilityHistogram # from src.plotting.postprocessing_plotting import plot_conditional_quantiles from src.run_modules.run_environment import RunEnvironment @@ -220,6 +221,10 @@ class PostProcessing(RunEnvironment): if "PlotAvailability" in plot_list: avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} PlotAvailability(avail_data, plot_folder=self.plot_path) + if "PlotAvailabilityHistogram" in plot_list: + avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} + PlotAvailabilityHistogram(avail_data, plot_folder=self.plot_path) + def calculate_test_score(self): test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),