Skip to content
Snippets Groups Projects
Commit 1027e284 authored by lukas leufen's avatar lukas leufen
Browse files

first time series plot implementation, but zoom-in needs to be implemented

parent e3b98733
Branches
Tags
2 merge requests: !37 "include new development", !35 "Lukas issue046 feat time series plot"
Pipeline #29254 passed
......@@ -477,3 +477,57 @@ class PlotCompetitiveSkillScore(RunEnvironment):
logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=500)
plt.close()
class PlotTimeSeries(RunEnvironment):
    """
    Plot the stored forecasts of several stations as time series, one subplot per station.

    For each station the forecast file `name % station` inside `data_path` is opened and the entries "CNN" and
    "orig" of the `type` coordinate are selected. The "orig" series is drawn in green, the "CNN" forecast of each
    lead time (coordinate `ahead`) is drawn in its own shade of blue. The plot is created and saved on instantiation.
    """

    def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = "."):
        """
        Set up the plot and create it immediately.

        :param stations: stations to plot, each one gets its own subplot
        :param data_path: folder that contains the forecast files
        :param name: file name pattern with a single `%s` placeholder that is filled with the station name
        :param window_lead_time: number of lead times to plot; if None, all lead times found in the data are used
        :param plot_folder: path to save the plot
        """
        super().__init__()
        self._data_path = data_path
        self._data_name = name
        self._stations = stations
        self._window_lead_time = self._get_window_lead_time(window_lead_time)
        self._plot(plot_folder)

    def _get_window_lead_time(self, window_lead_time: int):
        """
        Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from
        data itself by the number of ahead dimensions. If given, check if data supports the given length. If the number
        of ahead dimensions in data is lower than the given lead time, data's lead time is used.

        Only the data of the first station is inspected.

        :param window_lead_time: lead time from arguments to validate
        :return: validated lead time, comes either from given argument or from data itself
        """
        ahead_steps = len(self._load_data(self._stations[0]).ahead)
        if window_lead_time is None:
            window_lead_time = ahead_steps
        return min(ahead_steps, window_lead_time)

    def _load_data(self, station):
        """
        Load the forecast file of a single station and keep only the types "CNN" and "orig".

        :param station: station name that is inserted into the file name pattern
        :return: data array restricted to the types "CNN" and "orig", loaded into memory
        """
        logging.debug(f"... preprocess station {station}")
        file_name = os.path.join(self._data_path, self._data_name % station)
        # load inside the context manager so the file handle is released again for every station
        with xr.open_dataarray(file_name) as data:
            return data.sel(type=["CNN", "orig"]).load()

    def _plot(self, plot_folder):
        """
        Draw one subplot per station and save the figure.

        :param plot_folder: path to save the plot
        """
        # squeeze=False keeps `axes` two-dimensional, so indexing also works for a single station
        _, axes = plt.subplots(len(self._stations), sharex="all", squeeze=False)
        # first entry is the colour of the "orig" series, followed by one shade of blue per plotted lead time
        color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex()
        for pos, station in enumerate(self._stations):
            ax = axes[pos, 0]
            data = self._load_data(station)
            # "orig" series; plotting "CNN" here would be hidden completely by the first forecast line below
            ax.plot(data.index + np.timedelta64(1, "D"), data.sel(type="orig", ahead=1).values, color=color_palette[0])
            # restrict to the validated lead time and pick colours by position, so that the palette cannot be
            # indexed out of range when the data holds more ahead steps than requested
            lead_times = data.coords["ahead"].values[:self._window_lead_time]
            for step, ahead in enumerate(lead_times, start=1):
                plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze()
                # shift each forecast by its lead time (in days) along the time axis
                ax.plot(plot_data.index + np.timedelta64(int(ahead), "D"), plot_data.values, color=color_palette[step])
        self._save(plot_folder)

    @staticmethod
    def _save(plot_folder):
        """
        Standard save method to store plot locally. The name of this plot is static.

        :param plot_folder: path to save the plot
        """
        plot_name = os.path.join(os.path.abspath(plot_folder), 'test_timeseries_plot.pdf')
        logging.debug(f"... save plot to {plot_name}")
        plt.savefig(plot_name, dpi=500)
        plt.close('all')
......@@ -16,7 +16,8 @@ from src.data_handling.data_generator import DataGenerator
from src.model_modules.linear_model import OrdinaryLeastSquaredModel
from src import statistics
from src.plotting.postprocessing_plotting import plot_conditional_quantiles
from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, PlotCompetitiveSkillScore
from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
PlotCompetitiveSkillScore, PlotTimeSeries
from src.datastore import NameNotFoundInDataStore
from src.helpers import TimeTracking
......@@ -42,10 +43,10 @@ class PostProcessing(RunEnvironment):
logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
"skip make_prediction() whenever it is possible to save time.")
with TimeTracking():
preds_for_all_stations = self.make_prediction()
self.make_prediction()
logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
"skip make_prediction() whenever it is possible to save time.")
self.skill_scores = self.calculate_skill_scores()
# self.skill_scores = self.calculate_skill_scores()
self.plot()
def _load_model(self):
......@@ -64,17 +65,18 @@ class PostProcessing(RunEnvironment):
path = self.data_store.get("forecast_path", "general")
target_var = self.data_store.get("target_var", "general")
plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="orig",
forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path)
plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN",
forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path)
PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var,
plot_folder=self.plot_path)
PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN")
PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
extra_name_tag="all_terms_", model_setup="CNN")
PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN")
# plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="orig",
# forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path)
# plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN",
# forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path)
# PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
# PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var,
# plot_folder=self.plot_path)
# PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN")
# PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
# extra_name_tag="all_terms_", model_setup="CNN")
# PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN")
PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path)
def calculate_test_score(self):
test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),
......@@ -93,12 +95,11 @@ class PostProcessing(RunEnvironment):
def make_prediction(self, freq="1D"):
logging.debug("start make_prediction")
nn_prediction_all_stations = []
for i, v in enumerate(self.test_data):
for i, _ in enumerate(self.test_data):
data = self.test_data.get_data_generator(i)
nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data, count=3)
input_data = self.test_data[i][0]
input_data = data.get_transposed_history()
# get scaling parameters
mean, std, transformation_method = data.get_transformation_information(variable='o3')
......@@ -129,10 +130,6 @@ class PostProcessing(RunEnvironment):
file = os.path.join(path, f"forecasts_{data.station[0]}_test.nc")
all_predictions.to_netcdf(file)
# save nn forecast to return variable
nn_prediction_all_stations.append(nn_prediction)
return nn_prediction_all_stations
@staticmethod
def _create_orig_forecast(data, _, mean, std, transformation_method):
    """
    Apply the inverse transformation to a copy of the labels of `data`.

    :param data: object providing the labels as attribute `label`
    :param _: unused placeholder argument
    :param mean: passed on to `statistics.apply_inverse_transformation`
    :param std: passed on to `statistics.apply_inverse_transformation`
    :param transformation_method: passed on to `statistics.apply_inverse_transformation`
    :return: result of `statistics.apply_inverse_transformation` for the copied labels
    """
    # work on a copy, the labels stored in `data` are not handed over directly
    labels = data.label.copy()
    orig_forecast = statistics.apply_inverse_transformation(labels, mean, std, transformation_method)
    return orig_forecast
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment