Skip to content
Snippets Groups Projects
Commit dea75845 authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'lukas_issue32_refac_restructure-plot-routines-in-modules' into 'develop'

Lukas issue32 refac restructure plot routines in modules

See merge request toar/machinelearningtools!28
parents 179b0b16 af678d0c
Branches
Tags
2 merge requests!37include new development,!28Lukas issue32 refac restructure plot routines in modules
Pipeline #28791 passed
...@@ -7,6 +7,7 @@ import math ...@@ -7,6 +7,7 @@ import math
import warnings import warnings
from src import helpers from src import helpers
from src.helpers import TimeTracking from src.helpers import TimeTracking
from src.run_modules.run_environment import RunEnvironment
import numpy as np import numpy as np
import xarray as xr import xarray as xr
...@@ -19,16 +20,20 @@ import cartopy.crs as ccrs ...@@ -19,16 +20,20 @@ import cartopy.crs as ccrs
import cartopy.feature as cfeature import cartopy.feature as cfeature
from matplotlib.backends.backend_pdf import PdfPages from matplotlib.backends.backend_pdf import PdfPages
from typing import Dict, List from typing import Dict, List, Tuple
logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.getLogger('matplotlib').setLevel(logging.WARNING)
def plot_monthly_summary(stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None, class PlotMonthlySummary(RunEnvironment):
plot_folder: str = "."):
""" """
Show a monthly summary over all stations for each lead time ("ahead") as box and whiskers plot. The plot is saved Show a monthly summary over all stations for each lead time ("ahead") as box and whiskers plot. The plot is saved
in data_path with name monthly_summary_box_plot.pdf and 500dpi resolution. in data_path with name monthly_summary_box_plot.pdf and 500dpi resolution.
"""
def __init__(self, stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None,
plot_folder: str = "."):
"""
Set attributes and create the plot.
:param stations: all stations to plot :param stations: all stations to plot
:param data_path: path, where the data is located :param data_path: path, where the data is located
:param name: full name of the local files with a % as placeholder for the station name :param name: full name of the local files with a % as placeholder for the station name
...@@ -37,12 +42,24 @@ def plot_monthly_summary(stations: List, data_path: str, name: str, target_var: ...@@ -37,12 +42,24 @@ def plot_monthly_summary(stations: List, data_path: str, name: str, target_var:
the maximum lead time from data is used. (default None -> use maximum lead time from data). the maximum lead time from data is used. (default None -> use maximum lead time from data).
:param plot_folder: path to save the plot (default: current directory) :param plot_folder: path to save the plot (default: current directory)
""" """
logging.debug("run plot_monthly_summary()") super().__init__()
self._data_path = data_path
self._data_name = name
self._data = self._prepare_data(stations)
self._window_lead_time = self._get_window_lead_time(window_lead_time)
self._plot(target_var, plot_folder)
def _prepare_data(self, stations: List) -> xr.DataArray:
"""
Pre-process data required to plot. For each station, load locally saved predictions, extract the CNN and orig
prediction and group them into monthly bins (no aggregation, only sorting them).
:param stations: all stations to plot
:return: The entire data set, flagged with the corresponding month.
"""
forecasts = None forecasts = None
for station in stations: for station in stations:
logging.debug(f"... preprocess station {station}") logging.debug(f"... preprocess station {station}")
file_name = os.path.join(data_path, name % station) file_name = os.path.join(self._data_path, self._data_name % station)
data = xr.open_dataarray(file_name) data = xr.open_dataarray(file_name)
data_cnn = data.sel(type="CNN").squeeze() data_cnn = data.sel(type="CNN").squeeze()
...@@ -58,47 +75,85 @@ def plot_monthly_summary(stations: List, data_path: str, name: str, target_var: ...@@ -58,47 +75,85 @@ def plot_monthly_summary(stations: List, data_path: str, name: str, target_var:
data_concat = data_concat.clip(min=0) data_concat = data_concat.clip(min=0)
forecasts = xr.concat([forecasts, data_concat], 'index') if forecasts is not None else data_concat forecasts = xr.concat([forecasts, data_concat], 'index') if forecasts is not None else data_concat
return forecasts
ahead_steps = len(forecasts.ahead) def _get_window_lead_time(self, window_lead_time: int):
"""
Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from
data itself by the number of ahead dimensions. If given, check if data supports the given length. If the number
of ahead dimensions in data is lower than the given lead time, data's lead time is used.
:param window_lead_time: lead time from arguments to validate
:return: validated lead time, comes either from given argument or from data itself
"""
ahead_steps = len(self._data.ahead)
if window_lead_time is None: if window_lead_time is None:
window_lead_time = ahead_steps window_lead_time = ahead_steps
window_lead_time = min(ahead_steps, window_lead_time) return min(ahead_steps, window_lead_time)
forecasts = forecasts.to_dataset(name='values').to_dask_dataframe() def _plot(self, target_var: str, plot_folder: str):
"""
Main plot function that creates a monthly grouped box plot over all stations but with separate boxes for each
lead time step.
:param target_var: display name of the target variable on plot's axis
:param plot_folder: path to save the plot
"""
data = self._data.to_dataset(name='values').to_dask_dataframe()
logging.debug("... start plotting") logging.debug("... start plotting")
ax = sns.boxplot(x='index', y='values', hue='ahead', data=forecasts.compute(), whis=1., color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex()
palette=[matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1., palette=color_palette,
window_lead_time).as_hex(),
flierprops={'marker': '.', 'markersize': 1}, showmeans=True, flierprops={'marker': '.', 'markersize': 1}, showmeans=True,
meanprops={'markersize': 1, 'markeredgecolor': 'k'}) meanprops={'markersize': 1, 'markeredgecolor': 'k'})
ax.set(xlabel='month', ylabel=f'{target_var}') ax.set(xlabel='month', ylabel=f'{target_var}')
plt.tight_layout() plt.tight_layout()
self._save(plot_folder)
@staticmethod
def _save(plot_folder):
"""
Standard save method to store plot locally. The name of this plot is static.
:param plot_folder: path to save the plot
"""
plot_name = os.path.join(os.path.abspath(plot_folder), 'monthly_summary_box_plot.pdf') plot_name = os.path.join(os.path.abspath(plot_folder), 'monthly_summary_box_plot.pdf')
logging.debug(f"... save plot to {plot_name}") logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=500) plt.savefig(plot_name, dpi=500)
plt.close('all') plt.close('all')
def plot_station_map(generators: Dict, plot_folder: str = "."): class PlotStationMap(RunEnvironment):
""" """
Plot geographical overview of all used stations. Different data sets can be colorised by its key in the input Plot geographical overview of all used stations as squares. Different data sets can be colorised by its key in the
dictionary generators. The key represents the color to plot on the map. Currently, there is only a white background, input dictionary generators. The key represents the color to plot on the map. Currently, there is only a white
but this can be adjusted by loading locally stored topography data (not implemented yet). The plot is saved under background, but this can be adjusted by loading locally stored topography data (not implemented yet). The plot is
plot_path with the name station_map.pdf saved under plot_path with the name station_map.pdf
"""
def __init__(self, generators: Dict, plot_folder: str = "."):
"""
Set attributes and create the plot.
:param generators: dictionary with the plot color of each data set as key and the generator containing all stations :param generators: dictionary with the plot color of each data set as key and the generator containing all stations
as value. as value.
:param plot_folder: path to save the plot (default: current directory) :param plot_folder: path to save the plot (default: current directory)
""" """
logging.debug("run station_map()") super().__init__()
fig = plt.figure(figsize=(10, 5)) self._ax = None
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) self._plot(generators, plot_folder)
ax.set_extent([0, 20, 42, 58], crs=ccrs.PlateCarree())
ax.add_feature(cfeature.COASTLINE.with_scale("10m"), edgecolor='black') def _draw_background(self):
ax.add_feature(cfeature.LAKES.with_scale("50m")) """
ax.add_feature(cfeature.OCEAN.with_scale("50m")) Draw coastline, lakes, ocean, rivers and country borders as background on the map.
ax.add_feature(cfeature.RIVERS.with_scale("10m")) """
ax.add_feature(cfeature.BORDERS.with_scale("10m"), facecolor='none', edgecolor='black') self._ax.add_feature(cfeature.COASTLINE.with_scale("10m"), edgecolor='black')
self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
self._ax.add_feature(cfeature.RIVERS.with_scale("10m"))
self._ax.add_feature(cfeature.BORDERS.with_scale("10m"), facecolor='none', edgecolor='black')
def _plot_stations(self, generators):
"""
The actual plot function. Loops over all keys in the generators dict and the stations each generator contains, and
plots a square at each station's position on the map, filled with the given color.
:param generators: dictionary with the plot color of each data set as key and the generator containing all
stations as value.
"""
if generators is not None: if generators is not None:
for color, gen in generators.items(): for color, gen in generators.items():
for k, v in enumerate(gen): for k, v in enumerate(gen):
...@@ -106,8 +161,28 @@ def plot_station_map(generators: Dict, plot_folder: str = "."): ...@@ -106,8 +161,28 @@ def plot_station_map(generators: Dict, plot_folder: str = "."):
# station_names = gen.get_data_generator(k).meta.loc[['station_id']] # station_names = gen.get_data_generator(k).meta.loc[['station_id']]
IDx, IDy = float(station_coords.loc['station_lon'].values), float( IDx, IDy = float(station_coords.loc['station_lon'].values), float(
station_coords.loc['station_lat'].values) station_coords.loc['station_lat'].values)
ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree()) self._ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree())
def _plot(self, generators: Dict, plot_folder: str):
"""
Main plot function to create the station map plot. Sets figure and calls all required sub-methods.
:param generators: dictionary with the plot color of each data set as key and the generator containing all
stations as value.
:param plot_folder: path to save the plot
"""
fig = plt.figure(figsize=(10, 5))
self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
self._ax.set_extent([0, 20, 42, 58], crs=ccrs.PlateCarree())
self._draw_background()
self._plot_stations(generators)
self._save(plot_folder)
@staticmethod
def _save(plot_folder):
"""
Standard save method to store plot locally. The name of this plot is static.
:param plot_folder: path to save the plot
"""
plot_name = os.path.join(os.path.abspath(plot_folder), 'station_map.pdf') plot_name = os.path.join(os.path.abspath(plot_folder), 'station_map.pdf')
logging.debug(f"... save plot to {plot_name}") logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=500) plt.savefig(plot_name, dpi=500)
...@@ -249,75 +324,155 @@ def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_w ...@@ -249,75 +324,155 @@ def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_w
logging.info(f"plot_conditional_quantiles() finished after {time}") logging.info(f"plot_conditional_quantiles() finished after {time}")
def plot_climatological_skill_score(data: Dict, plot_folder: str = ".", score_only: bool = True, class PlotClimatologicalSkillScore(RunEnvironment):
extra_name_tag: str = "", model_setup: str = ""):
""" """
Create plot of climatological skill score after Murphy (1988) as box plot over all stations. A forecast time step Create plot of climatological skill score after Murphy (1988) as box plot over all stations. A forecast time step
(called "ahead") is separately shown to highlight the differences for each prediction time step. Either each single (called "ahead") is separately shown to highlight the differences for each prediction time step. Either each single
term is plotted (score_only=False) or only the resulting scores CASE I to IV are displayed (score_only=True, term is plotted (score_only=False) or only the resulting scores CASE I to IV are displayed (score_only=True,
default). Y-axis is adjusted following the data and not hard coded. The plot is saved under plot_folder path with default). Y-axis is adjusted following the data and not hard coded. The plot is saved under plot_folder path with
name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi. name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi.
"""
def __init__(self, data: Dict, plot_folder: str = ".", score_only: bool = True, extra_name_tag: str = "",
model_setup: str = ""):
"""
Set attributes and create the plot.
:param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms. :param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms.
:param plot_folder: path to save the plot (default: current directory) :param plot_folder: path to save the plot (default: current directory)
:param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True) :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True)
:param extra_name_tag: additional tag that can be included in the plot name (default "") :param extra_name_tag: additional tag that can be included in the plot name (default "")
:param model_setup: architecture type (default "CNN") :param model_setup: architecture type to specify plot name (default "CNN")
"""
super().__init__()
self._labels = None
self._data = self._prepare_data(data, score_only)
self._plot(plot_folder, score_only, extra_name_tag, model_setup)
def _prepare_data(self, data: Dict, score_only: bool) -> pd.DataFrame:
"""
Shrink given data, if only scores are relevant. In any case, transform data to a plot friendly format. Also set
plot labels depending on the lead time dimensions.
:param data: dictionary with station names as keys and 2D xarrays as values
:param score_only: if true only scores of CASE I to IV are relevant
:return: pre-processed data set
""" """
logging.debug("run plot_climatological_skill_score()")
data = helpers.dict_to_xarray(data, "station") data = helpers.dict_to_xarray(data, "station")
labels = [str(i) + "d" for i in data.coords["ahead"].values] self._labels = [str(i) + "d" for i in data.coords["ahead"].values]
fig, ax = plt.subplots()
if score_only: if score_only:
data = data.loc[:, ["CASE I", "CASE II", "CASE III", "CASE IV"], :] data = data.loc[:, ["CASE I", "CASE II", "CASE III", "CASE IV"], :]
lab_add = '' return data.to_dataframe("data").reset_index(level=[0, 1, 2])
else:
def _label_add(self, score_only: bool):
"""
Adds the phrase "terms and " if score_only is disabled or empty string (if score_only=True).
:param score_only: if false all terms are relevant, otherwise only CASE I to IV
:return: additional label
"""
return "" if score_only else "terms and "
def _plot(self, plot_folder, score_only, extra_name_tag, model_setup):
"""
Main plot function to plot climatological skill score.
:param plot_folder: path to save the plot
:param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms
:param extra_name_tag: additional tag that can be included in the plot name
:param model_setup: architecture type to specify plot name
"""
fig, ax = plt.subplots()
if not score_only:
fig.set_size_inches(11.7, 8.27) fig.set_size_inches(11.7, 8.27)
lab_add = "terms and " sns.boxplot(x="terms", y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d",
data = data.to_dataframe("data").reset_index(level=[0, 1, 2]) showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."})
sns.boxplot(x="terms", y="data", hue="ahead", data=data, ax=ax, whis=1., palette="Blues_d", showmeans=True,
meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."})
ax.axhline(y=0, color="grey", linewidth=.5) ax.axhline(y=0, color="grey", linewidth=.5)
ax.set(ylabel=f"{lab_add}skill score", xlabel="", title="summary of all stations") ax.set(ylabel=f"{self._label_add(score_only)}skill score", xlabel="", title="summary of all stations")
handles, _ = ax.get_legend_handles_labels() handles, _ = ax.get_legend_handles_labels()
ax.legend(handles, labels) ax.legend(handles, self._labels)
plt.tight_layout() plt.tight_layout()
self._save(plot_folder, extra_name_tag, model_setup)
@staticmethod
def _save(plot_folder, extra_name_tag, model_setup):
"""
Standard save method to store plot locally. The name of this plot is dynamic. It includes the model setup like
'CNN' and can additionally be adjusted using an extra name tag.
:param plot_folder: path to save the plot
:param extra_name_tag: additional tag that can be included in the plot name
:param model_setup: architecture type to specify plot name
"""
plot_name = os.path.join(plot_folder, f"skill_score_clim_{extra_name_tag}{model_setup}.pdf") plot_name = os.path.join(plot_folder, f"skill_score_clim_{extra_name_tag}{model_setup}.pdf")
logging.debug(f"... save plot to {plot_name}") logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=500) plt.savefig(plot_name, dpi=500)
plt.close('all') plt.close('all')
def plot_competitive_skill_score(data: pd.DataFrame, plot_folder=".", model_setup="CNN"): class PlotCompetitiveSkillScore(RunEnvironment):
""" """
Create competitive skill score for the given model setup and the reference models ordinary least squared ("ols") and Create competitive skill score for the given model setup and the reference models ordinary least squared ("ols") and
the persistence forecast ("persi") for all lead times ("ahead"). The plot is saved under plot_folder with the name the persistence forecast ("persi") for all lead times ("ahead"). The plot is saved under plot_folder with the name
skill_score_competitive_{model_setup}.pdf and resolution of 500dpi. skill_score_competitive_{model_setup}.pdf and resolution of 500dpi.
"""
def __init__(self, data: pd.DataFrame, plot_folder=".", model_setup="CNN"):
"""
:param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre- :param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre-
calculated comparisons for cnn, persistence and ols. calculated comparisons for cnn, persistence and ols.
:param plot_folder: path to save the plot (default: current directory) :param plot_folder: path to save the plot (default: current directory)
:param model_setup: architecture type (default "CNN") :param model_setup: architecture type (default "CNN")
""" """
logging.debug("run plot_general_skill_score()") super().__init__()
self._labels = None
self._data = self._prepare_data(data)
self._plot(plot_folder, model_setup)
def _prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
"""
Reformat given data and create plot labels. Introduces the dimensions stations and comparison
:param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre-
calculated comparisons for cnn, persistence and ols.
:return: processed data
"""
data = pd.concat(data, axis=0) data = pd.concat(data, axis=0)
data = xr.DataArray(data, dims=["stations", "ahead"]).unstack("stations") data = xr.DataArray(data, dims=["stations", "ahead"]).unstack("stations")
data = data.rename({"stations_level_0": "stations", "stations_level_1": "comparison"}) data = data.rename({"stations_level_0": "stations", "stations_level_1": "comparison"})
data = data.to_dataframe("data").unstack(level=1).swaplevel() data = data.to_dataframe("data").unstack(level=1).swaplevel()
data.columns = data.columns.levels[1] data.columns = data.columns.levels[1]
self._labels = [str(i) + "d" for i in data.index.levels[1].values]
return data.stack(level=0).reset_index(level=2, drop=True).reset_index(name="data")
labels = [str(i) + "d" for i in data.index.levels[1].values] def _plot(self, plot_folder, model_setup):
data = data.stack(level=0).reset_index(level=2, drop=True).reset_index(name="data") """
Main plot function to plot skill scores of the comparisons cnn-persi, ols-persi and cnn-ols.
:param plot_folder: path to save the plot
:param model_setup: architecture type to specify plot name
"""
fig, ax = plt.subplots() fig, ax = plt.subplots()
sns.boxplot(x="comparison", y="data", hue="ahead", data=data, whis=1., ax=ax, palette="Blues_d", showmeans=True, sns.boxplot(x="comparison", y="data", hue="ahead", data=self._data, whis=1., ax=ax, palette="Blues_d",
meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."}, showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
order=["cnn-persi", "ols-persi", "cnn-ols"]) order=["cnn-persi", "ols-persi", "cnn-ols"])
ax.axhline(y=0, color="grey", linewidth=.5) ax.axhline(y=0, color="grey", linewidth=.5)
ax.set(ylabel="skill score", xlabel="competing models", title="summary of all stations",
ylim=(np.min([0, helpers.float_round(data.min()[2], 2) - 0.1]), helpers.float_round(data.max()[2], 2) + 0.1)) ax.set(ylabel="skill score", xlabel="competing models", title="summary of all stations", ylim=self._ylim())
handles, _ = ax.get_legend_handles_labels() handles, _ = ax.get_legend_handles_labels()
ax.legend(handles, labels) ax.legend(handles, self._labels)
plt.tight_layout() plt.tight_layout()
self._save(plot_folder, model_setup)
def _ylim(self) -> Tuple[float, float]:
"""
Calculate y-axis limits from data. Lower is the minimum of either 0 or data's minimum (reduced by small
subtrahend) and upper limit is data's maximum (increased by a small addend).
:return: tuple (lower, upper) with the y-axis limits
"""
lower = np.min([0, helpers.float_round(self._data.min()[2], 2) - 0.1])
upper = helpers.float_round(self._data.max()[2], 2) + 0.1
return lower, upper
@staticmethod
def _save(plot_folder, model_setup):
"""
Standard save method to store plot locally. The name of this plot is dynamic by including the model setup.
:param plot_folder: path to save the plot
:param model_setup: architecture type to specify plot name
"""
plot_name = os.path.join(plot_folder, f"skill_score_competitive_{model_setup}.pdf") plot_name = os.path.join(plot_folder, f"skill_score_competitive_{model_setup}.pdf")
logging.debug(f"... save plot to {plot_name}") logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=500) plt.savefig(plot_name, dpi=500)
......
...@@ -47,8 +47,11 @@ experiment_path ...@@ -47,8 +47,11 @@ experiment_path
└─── plots └─── plots
conditional_quantiles_cali-ref_plot.pdf conditional_quantiles_cali-ref_plot.pdf
conditional_quantiles_like-bas_plot.pdf conditional_quantiles_like-bas_plot.pdf
test_monthly_box.pdf monthly_summary_box_plot.pdf
test_map_plot.pdf skill_score_clim_all_terms_<architecture>.pdf
skill_score_clim_<architecture>.pdf
skill_score_competitive_<architecture>.pdf
station_map.pdf
<experiment_name>_history_learning_rate.pdf <experiment_name>_history_learning_rate.pdf
<experiment_name>_history_loss.pdf <experiment_name>_history_loss.pdf
<experiment_name>_history_main_loss.pdf <experiment_name>_history_main_loss.pdf
......
...@@ -15,8 +15,8 @@ from src.data_handling.data_distributor import Distributor ...@@ -15,8 +15,8 @@ from src.data_handling.data_distributor import Distributor
from src.data_handling.data_generator import DataGenerator from src.data_handling.data_generator import DataGenerator
from src.model_modules.linear_model import OrdinaryLeastSquaredModel from src.model_modules.linear_model import OrdinaryLeastSquaredModel
from src import statistics from src import statistics
from src.plotting.postprocessing_plotting import plot_monthly_summary, plot_station_map, plot_conditional_quantiles, \ from src.plotting.postprocessing_plotting import plot_conditional_quantiles
plot_climatological_skill_score, plot_competitive_skill_score from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, PlotCompetitiveSkillScore
from src.datastore import NameNotFoundInDataStore from src.datastore import NameNotFoundInDataStore
...@@ -61,14 +61,13 @@ class PostProcessing(RunEnvironment): ...@@ -61,14 +61,13 @@ class PostProcessing(RunEnvironment):
forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path) forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path)
plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN", plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN",
forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path) forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path)
plot_station_map(generators={'b': self.test_data}, plot_folder=self.plot_path) PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var, PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var,
plot_folder=self.plot_path) plot_folder=self.plot_path)
# PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN")
plot_climatological_skill_score(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN") PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
plot_climatological_skill_score(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
extra_name_tag="all_terms_", model_setup="CNN") extra_name_tag="all_terms_", model_setup="CNN")
plot_competitive_skill_score(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN") PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN")
def calculate_test_score(self): def calculate_test_score(self):
test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(), test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),
...@@ -228,14 +227,14 @@ class PostProcessing(RunEnvironment): ...@@ -228,14 +227,14 @@ class PostProcessing(RunEnvironment):
def calculate_skill_scores(self): def calculate_skill_scores(self):
path = self.data_store.get("forecast_path", "general") path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general") window_lead_time = self.data_store.get("window_lead_time", "general")
skill_score_general = {} skill_score_competitive = {}
skill_score_climatological = {} skill_score_climatological = {}
for station in self.test_data.stations: for station in self.test_data.stations:
file = os.path.join(path, f"forecasts_{station}_test.nc") file = os.path.join(path, f"forecasts_{station}_test.nc")
data = xr.open_dataarray(file) data = xr.open_dataarray(file)
skill_score = statistics.SkillScores(data) skill_score = statistics.SkillScores(data)
external_data = self._get_external_data(station) external_data = self._get_external_data(station)
skill_score_general[station] = skill_score.skill_scores(window_lead_time) skill_score_competitive[station] = skill_score.skill_scores(window_lead_time)
skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data, skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data,
window_lead_time) window_lead_time)
return skill_score_general, skill_score_climatological return skill_score_competitive, skill_score_climatological
...@@ -136,7 +136,7 @@ class SkillScores(RunEnvironment): ...@@ -136,7 +136,7 @@ class SkillScores(RunEnvironment):
observation = data.sel(type=observation_name) observation = data.sel(type=observation_name)
forecast = data.sel(type=forecast_name) forecast = data.sel(type=forecast_name)
reference = data.sel(type=reference_name) reference = data.sel(type=reference_name)
mse = statistics.mean_squared_error mse = mean_squared_error
skill_score = 1 - mse(observation, forecast) / mse(observation, reference) skill_score = 1 - mse(observation, forecast) / mse(observation, reference)
return skill_score.values return skill_score.values
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment