Skip to content
Snippets Groups Projects
Commit 72429383 authored by lukas leufen's avatar lukas leufen
Browse files

refac time series plot and add sampling rate

parent d07cf2ca
No related branches found
No related tags found
2 merge requests!37include new development,!36include using of hourly data
Pipeline #29312 passed
......@@ -481,14 +481,23 @@ class PlotCompetitiveSkillScore(RunEnvironment):
class PlotTimeSeries(RunEnvironment):
def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = "."):
def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = ".",
sampling="daily"):
super().__init__()
self._data_path = data_path
self._data_name = name
self._stations = stations
self._window_lead_time = self._get_window_lead_time(window_lead_time)
self._sampling = self._get_sampling(sampling)
self._plot(plot_folder)
@staticmethod
def _get_sampling(sampling):
if sampling == "daily":
return "D"
elif sampling == "hourly":
return "h"
def _get_window_lead_time(self, window_lead_time: int):
"""
Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from
......@@ -509,31 +518,66 @@ class PlotTimeSeries(RunEnvironment):
return data.sel(type=["CNN", "orig"])
def _plot(self, plot_folder):
pdf_pages = self._save_pdf_pages(plot_folder)
pdf_pages = self._create_pdf_pages(plot_folder)
start, end = self._get_time_range(self._load_data(self._stations[0]))
color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex()
for pos, station in enumerate(self._stations):
data = self._load_data(station)
f, axes = plt.subplots(end - start + 1, sharey=True, figsize=(40, 20))
fig, axes, factor = self._create_subplots(start, end)
nan_list = []
for i in range(end - start + 1):
data_year = data.sel(index=f"{start + i}")
orig_data = data_year.sel(type="orig", ahead=1).values
axes[i].plot(data_year.index + np.timedelta64(1, "D"), orig_data, color=color_palette[0], label="orig")
for ahead in data.coords["ahead"].values:
plot_data = data_year.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze()
axes[i].plot(plot_data.index + np.timedelta64(int(ahead), "D"), plot_data.values, color=color_palette[ahead], label=f"{ahead}d")
if np.isnan(orig_data).all():
nan_list.append(i)
for i_year in range(end - start + 1):
data_year = data.sel(index=f"{start + i_year}")
for i_half_of_year in range(factor):
pos = 2 * i_year + i_half_of_year
plot_data = self._create_plot_data(data_year, factor, i_half_of_year)
self._plot_orig(axes[pos], plot_data)
self._plot_ahead(axes[pos], plot_data)
if np.isnan(plot_data.values).all():
nan_list.append(pos)
self._clean_up_axes(nan_list, axes, fig)
self._save_page(station, pdf_pages)
pdf_pages.close()
plt.close('all')
@staticmethod
def _clean_up_axes(nan_list, axes, fig):
for i in reversed(nan_list):
f.delaxes(axes[i])
fig.delaxes(axes[i])
@staticmethod
def _save_page(station, pdf_pages):
plt.suptitle(station)
plt.legend()
plt.tight_layout()
pdf_pages.savefig(dpi=500)
pdf_pages.close()
plt.close('all')
@staticmethod
def _create_plot_data(data, factor, running_index):
if factor > 1:
if running_index == 0:
data = data.where(data["index.month"] < 7)
else:
data = data.where(data["index.month"] >= 7)
return data
def _create_subplots(self, start, end):
factor = 1
if self._sampling == "h":
factor = 2
f, ax = plt.subplots((end - start + 1) * factor, sharey=True, figsize=(50, 30))
return f, ax, factor
def _plot_ahead(self, ax, data):
color = sns.color_palette("Blues_d", self._window_lead_time).as_hex()
for ahead in data.coords["ahead"].values:
plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze()
index = plot_data.index + np.timedelta64(int(ahead), self._sampling)
label = f"{ahead}{self._sampling}"
ax.plot(index, plot_data.values, color=color[ahead-1], label=label)
def _plot_orig(self, ax, data):
orig_data = data.sel(type="orig", ahead=1)
index = data.index + np.timedelta64(1, self._sampling)
ax.plot(index, orig_data.values, color=matplotlib.colors.cnames["green"], label="orig")
@staticmethod
def _get_time_range(data):
......@@ -542,7 +586,7 @@ class PlotTimeSeries(RunEnvironment):
return f(data, min), f(data, max)
@staticmethod
def _save_pdf_pages(plot_folder):
def _create_pdf_pages(plot_folder):
"""
Standard save method to store plot locally. The name of this plot is static.
:param plot_folder: path to save the plot
......
......@@ -35,6 +35,7 @@ class PostProcessing(RunEnvironment):
self.train_val_data: DataGenerator = self.data_store.get("generator", "general.train_val")
self.plot_path: str = self.data_store.get("plot_path", "general")
self.target_var = self.data_store.get("target_var", "general")
self._sampling = self.data_store.get("sampling", "general")
self.skill_scores = None
self._run()
......@@ -76,7 +77,7 @@ class PostProcessing(RunEnvironment):
PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
extra_name_tag="all_terms_", model_setup="CNN")
PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN")
PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path)
PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path, sampling=self._sampling)
def calculate_test_score(self):
test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),
......@@ -93,7 +94,7 @@ class PostProcessing(RunEnvironment):
def train_ols_model(self):
self.ols_model = OrdinaryLeastSquaredModel(self.train_data)
def make_prediction(self, freq="1D"):
def make_prediction(self):
logging.debug("start make_prediction")
for i, _ in enumerate(self.test_data):
data = self.test_data.get_data_generator(i)
......@@ -118,7 +119,7 @@ class PostProcessing(RunEnvironment):
orig_pred = self._create_orig_forecast(data, None, mean, std, transformation_method)
# merge all predictions
full_index = self.create_fullindex(data.data.indexes['datetime'], freq)
full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency())
all_predictions = self.create_forecast_arrays(full_index, list(data.label.indexes['window']),
CNN=nn_prediction,
persi=persistence_prediction,
......@@ -130,6 +131,10 @@ class PostProcessing(RunEnvironment):
file = os.path.join(path, f"forecasts_{data.station[0]}_test.nc")
all_predictions.to_netcdf(file)
def _get_frequency(self):
getter = {"daily": "1D", "hourly": "1H"}
return getter.get(self._sampling, None)
@staticmethod
def _create_orig_forecast(data, _, mean, std, transformation_method):
return statistics.apply_inverse_transformation(data.label.copy(), mean, std, transformation_method)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment