Skip to content
Snippets Groups Projects
Commit 4cb45809 authored by lukas leufen's avatar lukas leufen
Browse files

first refac on monthly summary plot, docs partly missing

parent 179b0b16
No related branches found
No related tags found
2 merge requests: !37 "include new development", !28 "Lukas issue32 refac restructure plot routines in modules"
Pipeline #28726 passed
......@@ -7,6 +7,7 @@ import math
import warnings
from src import helpers
from src.helpers import TimeTracking
from src.run_modules.run_environment import RunEnvironment
import numpy as np
import xarray as xr
......@@ -24,11 +25,15 @@ from typing import Dict, List
logging.getLogger('matplotlib').setLevel(logging.WARNING)
def plot_monthly_summary(stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None,
plot_folder: str = "."):
class PlotMonthlySummary(RunEnvironment):
"""
Show a monthly summary over all stations for each lead time ("ahead") as box and whiskers plot. The plot is saved
in data_path with name monthly_summary_box_plot.pdf and 500dpi resolution.
"""
def __init__(self, stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None,
plot_folder: str = "."):
"""
Set attributes and create the plot.
:param stations: all stations to plot
:param data_path: path, where the data is located
:param name: full name of the local files with a % as placeholder for the station name
......@@ -37,12 +42,23 @@ def plot_monthly_summary(stations: List, data_path: str, name: str, target_var:
the maximum lead time from data is used. (default None -> use maximum lead time from data).
:param plot_folder: path to save the plot (default: current directory)
"""
logging.debug("run plot_monthly_summary()")
super().__init__()
self._data_path = data_path
self._data_name = name
self._data = self._get_data(stations)
self._window_lead_time = self._get_window_lead_time(window_lead_time)
self._plot(target_var, plot_folder)
def _get_data(self, stations):
"""
pre-process data
:param stations:
:return:
"""
forecasts = None
for station in stations:
logging.debug(f"... preprocess station {station}")
file_name = os.path.join(data_path, name % station)
file_name = os.path.join(self._data_path, self._data_name % station)
data = xr.open_dataarray(file_name)
data_cnn = data.sel(type="CNN").squeeze()
......@@ -58,17 +74,19 @@ def plot_monthly_summary(stations: List, data_path: str, name: str, target_var:
data_concat = data_concat.clip(min=0)
forecasts = xr.concat([forecasts, data_concat], 'index') if forecasts is not None else data_concat
return forecasts
ahead_steps = len(forecasts.ahead)
def _get_window_lead_time(self, window_lead_time):
ahead_steps = len(self._data.ahead)
if window_lead_time is None:
window_lead_time = ahead_steps
window_lead_time = min(ahead_steps, window_lead_time)
return min(ahead_steps, window_lead_time)
forecasts = forecasts.to_dataset(name='values').to_dask_dataframe()
def _plot(self, target_var, plot_folder):
data = self._data.to_dataset(name='values').to_dask_dataframe()
logging.debug("... start plotting")
ax = sns.boxplot(x='index', y='values', hue='ahead', data=forecasts.compute(), whis=1.,
palette=[matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d",
window_lead_time).as_hex(),
color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex()
ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1., palette=color_palette,
flierprops={'marker': '.', 'markersize': 1}, showmeans=True,
meanprops={'markersize': 1, 'markeredgecolor': 'k'})
ax.set(xlabel='month', ylabel=f'{target_var}')
......
......@@ -15,8 +15,8 @@ from src.data_handling.data_distributor import Distributor
from src.data_handling.data_generator import DataGenerator
from src.model_modules.linear_model import OrdinaryLeastSquaredModel
from src import statistics
from src.plotting.postprocessing_plotting import plot_monthly_summary, plot_station_map, plot_conditional_quantiles, \
plot_climatological_skill_score, plot_competitive_skill_score
from src.plotting.postprocessing_plotting import plot_station_map, plot_conditional_quantiles, \
plot_climatological_skill_score, plot_competitive_skill_score, PlotMonthlySummary
from src.datastore import NameNotFoundInDataStore
......@@ -62,7 +62,7 @@ class PostProcessing(RunEnvironment):
plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN",
forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path)
plot_station_map(generators={'b': self.test_data}, plot_folder=self.plot_path)
plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var,
PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var,
plot_folder=self.plot_path)
#
plot_climatological_skill_score(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN")
......
......@@ -136,7 +136,7 @@ class SkillScores(RunEnvironment):
observation = data.sel(type=observation_name)
forecast = data.sel(type=forecast_name)
reference = data.sel(type=reference_name)
mse = statistics.mean_squared_error
mse = mean_squared_error
skill_score = 1 - mse(observation, forecast) / mse(observation, reference)
return skill_score.values
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment