Skip to content
Snippets Groups Projects

Resolve "seasonal mse stack plot"

Merged Ghost User requested to merge lukas_issue422_feat_seasonal-mse-stack-plot into develop
2 files
+ 98
1
Compare changes
  • Side-by-side
  • Inline

Files

@@ -1319,6 +1319,93 @@ class PlotTimeEvolutionMetric(AbstractPlotClass):
@@ -1319,6 +1319,93 @@ class PlotTimeEvolutionMetric(AbstractPlotClass):
self._save()
self._save()
 
@TimeTrackingWrapper
 
class PlotSeasonalMSEStack(AbstractPlotClass):
 
 
def __init__(self, data_path: str, plot_folder: str = ".", boot_dim="boots", ahead_dim="ahead",
 
sampling: str = "daily", error_measure: str = "MSE", error_unit: str = "ppb$^2$"):
 
"""Set attributes and create plot."""
 
super().__init__(plot_folder, "seasonal_mse_stack_plot")
 
self.plot_name_orig = "seasonal_mse_stack_plot"
 
self._data_path = data_path
 
self.season_dim = "season"
 
self.error_unit = error_unit
 
self.error_measure = error_measure
 
self._data = self._prepare_data(boot_dim, data_path)
 
for orientation in ["horizontal", "vertical"]:
 
for split_ahead in [True, False]:
 
self._plot(ahead_dim, split_ahead, sampling, orientation)
 
self._save(bbox_inches="tight")
 
 
def _prepare_data(self, boot_dim, data_path):
 
season_dim = self.season_dim
 
data = {}
 
for season in ["total", "DJF", "MAM", "JJA", "SON"]:
 
if season == "total":
 
file_name = "uncertainty_estimate_raw_results.nc"
 
else:
 
file_name = f"uncertainty_estimate_raw_results_{season}.nc"
 
with xr.open_dataarray(os.path.join(data_path, file_name)) as d:
 
data[season] = d
 
mean = {}
 
for season in data.keys():
 
mean[season] = data[season].mean(boot_dim)
 
xr_data = xr.Dataset(mean).to_array(season_dim)
 
factor = xr_data.sel({season_dim: "total"}) / xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]}).sum(
 
season_dim)
 
return xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]}) * factor
 
 
@staticmethod
 
def _get_target_sampling(sampling, pos):
 
sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
 
sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
 
return sampling, sampling_letter
 
 
def _plot(self, dim, split_ahead=True, sampling="daily", orientation="vertical"):
 
_, sampling_letter = self._get_target_sampling(sampling, 1)
 
if split_ahead is True:
 
self.plot_name = self.plot_name_orig + "_total_" + orientation
 
data = self._data.mean(dim)
 
if orientation == "vertical":
 
fig, ax = plt.subplots(1, 1)
 
data.to_pandas().T.plot.bar(ax=ax, stacked=True, cmap="Dark2", legend=False)
 
ax.xaxis.label.set_visible(False)
 
ax.set_ylabel(f"{self.error_measure} (in {self.error_unit})")
 
else:
 
m = data.to_pandas().T.shape[0]
 
fig, ax = plt.subplots(1, 1, figsize=(6, m))
 
data.to_pandas().T.plot.barh(ax=ax, stacked=True, cmap="Dark2", legend=False)
 
ax.yaxis.label.set_visible(False)
 
ax.set_xlabel(f"{self.error_measure} (in {self.error_unit})")
 
fig.legend(*ax.get_legend_handles_labels(), loc="upper center", ncol=4)
 
fig.tight_layout(rect=[0, 0, 1, 0.9])
 
else:
 
self.plot_name = self.plot_name_orig + "_" + orientation
 
data = self._data
 
n = len(data.coords[dim])
 
if orientation == "vertical":
 
fig, ax = plt.subplots(1, n, sharey=True)
 
for i, sel in enumerate(data.coords[dim].values):
 
data.sel({dim: sel}).to_pandas().T.plot.bar(ax=ax[i], stacked=True, cmap="Dark2", legend=False)
 
label = str(sel) + sampling_letter
 
ax[i].set_title(label)
 
ax[i].xaxis.label.set_visible(False)
 
ax[0].set_ylabel(f"{self.error_measure} (in {self.error_unit})")
 
fig.legend(*ax[0].get_legend_handles_labels(), loc="upper center", ncol=4)
 
fig.tight_layout(rect=[0, 0, 1, 0.9])
 
else:
 
m = data.max(self.season_dim).shape
 
fig, ax = plt.subplots(n, 1, sharex=True, figsize=(6, np.prod(m) * 0.6))
 
for i, sel in enumerate(data.coords[dim].values):
 
data.sel({dim: sel}).to_pandas().T.plot.barh(ax=ax[i], stacked=True, cmap="Dark2", legend=False)
 
label = str(sel) + sampling_letter
 
ax[i].set_title(label)
 
ax[i].yaxis.label.set_visible(False)
 
ax[-1].set_xlabel(f"{self.error_measure} (in {self.error_unit})")
 
fig.legend(*ax[0].get_legend_handles_labels(), loc="upper center", ncol=4)
 
fig.tight_layout(rect=[0, 0, 1, 0.95])
 
 
if __name__ == "__main__":
if __name__ == "__main__":
stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
path = "../../testrun_network/forecasts"
path = "../../testrun_network/forecasts"
Loading