Skip to content
Snippets Groups Projects
Commit b90b36f4 authored by lukas leufen's avatar lukas leufen
Browse files

added availability plot

parent dfdea269
No related branches found
No related tags found
3 merge requests: !90 WIP: new release update, !89 Resolve "release branch / CI on gpu", !87 Resolve "Data summary plot for usage of data (time-series)"
Pipeline #34144 passed with warnings
...@@ -19,7 +19,7 @@ from matplotlib.backends.backend_pdf import PdfPages ...@@ -19,7 +19,7 @@ from matplotlib.backends.backend_pdf import PdfPages
from src import helpers from src import helpers
from src.helpers import TimeTracking, TimeTrackingWrapper from src.helpers import TimeTracking, TimeTrackingWrapper
from src.run_modules.run_environment import RunEnvironment from src.data_handling.data_generator import DataGenerator
logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.getLogger('matplotlib').setLevel(logging.WARNING)
...@@ -621,3 +621,67 @@ class PlotTimeSeries: ...@@ -621,3 +621,67 @@ class PlotTimeSeries:
plot_name = os.path.join(os.path.abspath(plot_folder), 'timeseries_plot.pdf') plot_name = os.path.join(os.path.abspath(plot_folder), 'timeseries_plot.pdf')
logging.debug(f"... save plot to {plot_name}") logging.debug(f"... save plot to {plot_name}")
return matplotlib.backends.backend_pdf.PdfPages(plot_name) return matplotlib.backends.backend_pdf.PdfPages(plot_name)
@TimeTrackingWrapper
class PlotAvailability(AbstractPlotClass):
    """
    Draw one horizontal bar row per station showing when label data is available.

    For every subset (e.g. "train", "val", "test") and every station of the
    subset's generator, the labels are resampled to the requested frequency and
    consecutive time steps with non-null labels are merged into spans. Each span
    is drawn as a bar coloured by subset.

    The whole workflow (prepare, plot, save) runs in the constructor.

    :param generators: mapping from subset name to the data generator of that subset
    :param plot_folder: folder handed to the parent class (presumably the save location, TODO confirm)
    :param sampling: "daily" or "hourly"; any other value yields ``None`` as resample frequency
    """

    def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily"):
        super().__init__(plot_folder, "data_availability")
        self.sampling = self._get_sampling(sampling)
        plot_dict = self._prepare_data(generators)
        self._plot(plot_dict)
        self._save()

    @staticmethod
    def _get_sampling(sampling):
        """
        Translate a sampling name into the frequency string used for resampling.

        :param sampling: "daily" or "hourly"
        :return: "D" for daily, "h" for hourly, ``None`` for anything else
        """
        return {"daily": "D", "hourly": "h"}.get(sampling)

    def _prepare_data(self, generators: Dict[str, DataGenerator]):
        """
        Collect the availability spans for every station and subset.

        :param generators: mapping from subset name to the data generator of that subset
        :return: nested dict ``{station: {subset: [(span_start, number_of_steps), ...]}}`` where each
            tuple describes one uninterrupted run of available labels
        """
        plt_dict = {}
        for subset, generator in generators.items():
            for station in generator.stations:
                station_data = generator.get_data_generator(station)
                labels = station_data.get_transposed_label().resample(datetime=self.sampling, skipna=True).mean()
                labels_bool = labels.sel(window=1).notnull()
                # The group id increases by one whenever availability flips, so every run of
                # identical values (available / missing) ends up with its own id.
                group = (labels_bool != labels_bool.shift(datetime=1)).cumsum()
                plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values},
                                         index=labels.datetime.values)
                # Keep only the runs that start with (and therefore consist of) available data.
                # `.iloc[0]` is used because the frame is indexed by datetime, so a plain `[0]`
                # would depend on pandas' deprecated positional fallback for integer keys.
                # NOTE(review): the span width is a count of resampled steps, not a time delta;
                # confirm this is what the axis expects for sampling other than daily.
                spans = [(block.index[0], block.shape[0])
                         for _, block in plot_data.groupby("group")
                         if block["avail"].iloc[0]]
                plt_dict.setdefault(station, {})[subset] = spans
        return plt_dict

    def _plot(self, plt_dict):
        """
        Draw the availability bars onto a new figure.

        Stations are placed in reverse-sorted order from bottom to top, one row each. Subsets
        without an entry for a station are skipped.

        :param plt_dict: nested dict as returned by `_prepare_data`
        """
        # Colours named in the original source as: orange (train), skyblue (val), blueishgreen (test).
        colors = {"train": "#e69f00", "val": "#56b4e9", "test": "#009e73"}
        height = 0.8  # should be <= 1
        yticklabels = []
        number_of_stations = len(plt_dict)
        fig, ax = plt.subplots(figsize=(10, number_of_stations / 3))
        for pos, (station, subsets) in enumerate(sorted(plt_dict.items(), reverse=True), start=1):
            for subset, color in colors.items():
                plt_data = subsets.get(subset)
                if plt_data is None:
                    continue
                ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white")
            yticklabels.append(station)

        ax.set_ylim([height, number_of_stations + 1])
        # Rows start at y = 1, 2, ...; shift the ticks by half the bar height to centre the labels.
        ax.set_yticks(np.arange(number_of_stations) + 1 + height / 2)
        ax.set_yticklabels(yticklabels)
        # The former debug dump to "test0.png" in the working directory was removed here;
        # writing the figure is left to `_save()`, which the constructor calls afterwards.
        plt.tight_layout()
...@@ -20,7 +20,7 @@ from src.helpers import TimeTracking ...@@ -20,7 +20,7 @@ from src.helpers import TimeTracking
from src.model_modules.linear_model import OrdinaryLeastSquaredModel from src.model_modules.linear_model import OrdinaryLeastSquaredModel
from src.model_modules.model_class import AbstractModelClass from src.model_modules.model_class import AbstractModelClass
from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \ from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability
from src.plotting.postprocessing_plotting import plot_conditional_quantiles from src.plotting.postprocessing_plotting import plot_conditional_quantiles
from src.run_modules.run_environment import RunEnvironment from src.run_modules.run_environment import RunEnvironment
...@@ -37,6 +37,7 @@ class PostProcessing(RunEnvironment): ...@@ -37,6 +37,7 @@ class PostProcessing(RunEnvironment):
self.test_data: DataGenerator = self.data_store.get("generator", "test") self.test_data: DataGenerator = self.data_store.get("generator", "test")
self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size) self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size)
self.train_data: DataGenerator = self.data_store.get("generator", "train") self.train_data: DataGenerator = self.data_store.get("generator", "train")
self.val_data: DataGenerator = self.data_store.get("generator", "val")
self.train_val_data: DataGenerator = self.data_store.get("generator", "train_val") self.train_val_data: DataGenerator = self.data_store.get("generator", "train_val")
self.plot_path: str = self.data_store.get("plot_path") self.plot_path: str = self.data_store.get("plot_path")
self.target_var = self.data_store.get("target_var") self.target_var = self.data_store.get("target_var")
...@@ -213,6 +214,8 @@ class PostProcessing(RunEnvironment): ...@@ -213,6 +214,8 @@ class PostProcessing(RunEnvironment):
if "PlotTimeSeries" in plot_list: if "PlotTimeSeries" in plot_list:
PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path, PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path,
sampling=self._sampling) sampling=self._sampling)
avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data}
PlotAvailability(avail_data, plot_folder=self.plot_path)
def calculate_test_score(self): def calculate_test_score(self):
test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(), test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment