Commit 23315ee0 authored by lukas leufen 👻

Merge branch 'release_v2.1.0' into 'master'

Resolve "release v2.1.0"

Closes #328, #380, #381, #385, and #391

See merge request !431
parents 064bee3e 49c26e0e
Pipeline #102149 passed with stages in 13 minutes and 59 seconds
# Changelog
All notable changes to this project will be documented in this file.
## v2.1.0 - 2022-06-07 - new evaluation metrics and improved training
### general:
* new evaluation metrics, IOA and MNMB
* advanced train options for early stopping
* reduced execution time by refactoring
### new features:
* uncertainty estimation of MSE is now applied for each season separately (#374)
* added different configurations of early stopping to use either last trained or best epoch (#378)
* train monitoring plots now add a star for best epoch when using early stopping (#367)
* new evaluation metric index of agreement, IOA (#376)
* new evaluation metric modified normalised mean bias, MNMB (#380)
* new plot available that shows temporal evolution of MSE for each station (#381)
### technical:
* reduced loading of forecast path from data store (#328)
* bug fix for an uncaught error during transformation (#385)
* bug fix for the data handler with climate and FIR filter that always calculated the transformation with the FIR filter (#387)
* reduced duration of latex report creation at the end of preprocessing (#388)
* enhanced speed of the make prediction step in postprocessing (#389)
* fix to always create version badge from version and not from tag name (#382)
## v2.0.0 - 2022-04-08 - tf2 usage, new model classes, and improved uncertainty estimate
### general:
......
#!/bin/bash
VERSION="$(git describe --tags $(git rev-list --tags --max-count=1))"
VERSION="$(git describe master)"
COLOR="blue"
BADGE_NAME="version"
......
......@@ -34,7 +34,7 @@ HPC systems, see [here](#special-instructions-for-installation-on-jülich-hpc-sy
* Installation of **MLAir**:
* Either clone MLAir from the [gitlab repository](https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair.git)
and use it without installation (besides the requirements)
* or download the distribution file ([current version](https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.0.0-py3-none-any.whl))
* or download the distribution file ([current version](https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.1.0-py3-none-any.whl))
and install it via `pip install <dist_file>.whl`. In this case, you can simply import MLAir in any Python script
inside your virtual environment using `import mlair`.
......
......@@ -27,7 +27,7 @@ Installation of MLAir
* Install all requirements from `requirements.txt <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/requirements.txt>`_
preferably in a virtual environment
* Either clone MLAir from the `gitlab repository <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair.git>`_
* or download the distribution file (`current version <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.0.0-py3-none-any.whl>`_)
* or download the distribution file (`current version <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.1.0-py3-none-any.whl>`_)
and install it via :py:`pip install <dist_file>.whl`. In this case, you can simply
import MLAir in any Python script inside your virtual environment using :py:`import mlair`.
......
__version_info__ = {
'major': 2,
'minor': 0,
'minor': 1,
'micro': 0,
}
......
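For orientation, a version string is typically assembled from this dict. A minimal sketch of a plain dot-join (the actual assembly code is not part of this diff and is an assumption):

__version__ = ".".join(str(v) for v in __version_info__.values())  # -> "2.1.0" (assumed join, not shown in the diff)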
......@@ -2,6 +2,9 @@ __author__ = "Lukas Leufen"
__date__ = '2020-06-25'
import numpy as np
DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
......@@ -24,6 +27,8 @@ DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY = False
DEFAULT_PERMUTE_DATA = False
DEFAULT_BATCH_SIZE = int(256 * 2)
DEFAULT_EPOCHS = 20
DEFAULT_EARLY_STOPPING_EPOCHS = np.inf
DEFAULT_RESTORE_BEST_MODEL_WEIGHTS = True
DEFAULT_TARGET_VAR = "o3"
DEFAULT_TARGET_DIM = "variables"
DEFAULT_WINDOW_LEAD_TIME = 3
......
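The two new defaults above wire the early-stopping options (#378) into the experiment setup. A minimal usage sketch, assuming the run interface accepts the lowercased parameter names derived from the DEFAULT_* constants (the wiring itself is not shown in this diff):

import mlair

# hedged sketch: parameter names are assumed from the DEFAULT_* naming convention above
mlair.run(epochs=100,
          early_stopping_epochs=10,           # stop after 10 epochs without improvement
          restore_best_model_weights=True)    # restore the best epoch instead of keeping the last one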
......@@ -316,10 +316,17 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
@classmethod
def _split_chem_and_meteo_variables(cls, **kwargs):
"""
Select all used variables and split them into categories chem and other.
Chemical variables are indicated by `cls.data_handler_climate_fir.chem_vars`. To indicate used variables, this
method uses 1) parameter `variables`, 2) keys from `statistics_per_var`, 3) keys from
`cls.data_handler_climate_fir.DEFAULT_VAR_ALL_DICT`. Option 3) is also applied if 1) or 2) are given but None.
"""
if "variables" in kwargs:
variables = kwargs.get("variables")
elif "statistics_per_var" in kwargs:
variables = kwargs.get("statistics_per_var")
variables = kwargs.get("statistics_per_var").keys()
else:
variables = None
if variables is None:
......@@ -348,14 +355,7 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
cls.prepare_build(sp_keys, chem_vars, cls.chem_indicator)
sp_chem_unfiltered = cls.data_handler_unfiltered(station, **sp_keys)
if len(meteo_vars) > 0:
if cls.data_handler_fir_pos is None:
if "extend_length_opts" in kwargs:
if isinstance(kwargs["extend_length_opts"], dict) and cls.meteo_indicator not in kwargs["extend_length_opts"].keys():
cls.data_handler_fir_pos = 0 # use faster fir version without climate estimate
else:
cls.data_handler_fir_pos = 1 # use slower fir version with climate estimate
else:
cls.data_handler_fir_pos = 0 # use faster fir version without climate estimate
cls.set_data_handler_fir_pos(**kwargs)
sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_fir[cls.data_handler_fir_pos].requirements() if k in kwargs}
sp_keys = cls.build_update_transformation(sp_keys, dh_type="filtered_meteo")
cls.prepare_build(sp_keys, meteo_vars, cls.meteo_indicator)
......@@ -369,8 +369,36 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
return cls(sp_chem, sp_meteo, sp_chem_unfiltered, sp_meteo_unfiltered, chem_vars, meteo_vars, **dp_args)
@classmethod
def set_data_handler_fir_pos(cls, **kwargs):
"""
Set position of fir data handler to use either faster FIR version or slower climate FIR.
This method will set data handler indicator to 0 if either no parameter "extend_length_opts" is given or the
parameter is of type dict but has no entry for the meteo_indicator. In all other cases, indicator is set to 1.
"""
p_name = "extend_length_opts"
if cls.data_handler_fir_pos is None:
if p_name in kwargs:
if isinstance(kwargs[p_name], dict) and cls.meteo_indicator not in kwargs[p_name].keys():
cls.data_handler_fir_pos = 0 # use faster fir version without climate estimate
else:
cls.data_handler_fir_pos = 1 # use slower fir version with climate estimate
else:
cls.data_handler_fir_pos = 0 # use faster fir version without climate estimate
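# Hypothetical calls (not part of the commit), assuming cls.meteo_indicator == "meteo":
# set_data_handler_fir_pos()                                -> pos 0, fast FIR without climate estimate
# set_data_handler_fir_pos(extend_length_opts={"chem": 7})  -> pos 0, dict has no "meteo" entry
# set_data_handler_fir_pos(extend_length_opts={"meteo": 7}) -> pos 1, slow FIR with climate estimate
# set_data_handler_fir_pos(extend_length_opts=7)            -> pos 1, non-dict value applies to all variables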
@classmethod
def prepare_build(cls, kwargs, var_list, var_type):
"""
Prepares for build of class.
`variables` parameter is updated by `var_list`, which should only include variables of a specific type (e.g.
only chemical variables) indicated by `var_type`. Furthermore, this method cleans the `kwargs` dictionary as
follows: For all parameters provided as dict to separate between chem and meteo options (dict must have keys
from `cls.chem_indicator` and/or `cls.meteo_indicator`), this parameter is removed from kwargs and its value
related to `var_type` added again. In case there is no value for given `var_type`, the parameter is not added
at all (as this parameter is assumed to affect only other types of variables).
"""
kwargs.update({"variables": var_list})
for k in list(kwargs.keys()):
v = kwargs[k]
......@@ -382,17 +410,6 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
except KeyError:
pass
@staticmethod
def adjust_window_opts(key: str, parameter_name: str, kwargs: dict):
try:
if parameter_name in kwargs:
window_opt = kwargs.pop(parameter_name)
if isinstance(window_opt, dict):
window_opt = window_opt[key]
kwargs[parameter_name] = window_opt
except KeyError:
pass
def _create_collection(self):
collection = super()._create_collection()
if self.id_class_other is not None:
......@@ -419,9 +436,10 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
# meteo transformation
if len(meteo_vars) > 0:
cls.set_data_handler_fir_pos(**kwargs)
kwargs_meteo = copy.deepcopy(kwargs)
cls.prepare_build(kwargs_meteo, meteo_vars, cls.meteo_indicator)
dh_transformation = (cls.data_handler_fir[cls.data_handler_fir_pos or 0], cls.data_handler_unfiltered)
dh_transformation = (cls.data_handler_fir[cls.data_handler_fir_pos], cls.data_handler_unfiltered)
transformation_meteo = super().transformation(set_stations, tmp_path=tmp_path,
dh_transformation=dh_transformation, **kwargs_meteo)
......
......@@ -125,8 +125,9 @@ class DefaultDataHandler(AbstractDataHandler):
def get_data(self, upsampling=False, as_numpy=True):
self._load()
X = self.get_X(upsampling, as_numpy)
Y = self.get_Y(upsampling, as_numpy)
as_numpy_X, as_numpy_Y = as_numpy if isinstance(as_numpy, tuple) else (as_numpy, as_numpy)
X = self.get_X(upsampling, as_numpy_X)
Y = self.get_Y(upsampling, as_numpy_Y)
self._reset_data()
return X, Y
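The tuple form of `as_numpy` allows different return types for inputs and targets. A hedged usage sketch with a hypothetical `data_handler` instance:

X, Y = data_handler.get_data(upsampling=False, as_numpy=(True, False))  # X as numpy arrays, Y kept as xarray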
......@@ -378,7 +379,7 @@ def f_proc(data_handler, station, return_strategy="", tmp_path=None, **sp_keys):
assert return_strategy in ["result", "reference"]
try:
res = data_handler(station, **sp_keys)
except (AttributeError, EmptyQueryResult, KeyError, ValueError) as e:
except (AttributeError, EmptyQueryResult, KeyError, ValueError, IndexError) as e:
logging.info(f"remove station {station} because it raised an error: {e}")
res = None
if return_strategy == "result":
......
......@@ -144,8 +144,8 @@ class KerasIterator(keras.utils.Sequence):
mod_rank = self._get_model_rank()
for data in self._collection:
logging.debug(f"prepare batches for {str(data)}")
X = data.get_X(upsampling=self.upsampling)
Y = [data.get_Y(upsampling=self.upsampling)[0] for _ in range(mod_rank)]
X, _Y = data.get_data(upsampling=self.upsampling)
Y = [_Y[0] for _ in range(mod_rank)]
if self.upsampling:
X, Y = self._permute_data(X, Y)
if remaining is not None:
......
......@@ -122,6 +122,21 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce
return round_type(number * multiplier) / multiplier
def relative_round(x: float, sig: int) -> float:
"""
Round small numbers according to given "significance".
Example: relative_round(0.03112, 2) -> 0.031, relative_round(0.03112, 1) -> 0.03
:params x: number to round
:params sig: "significance" to determine number of decimals
:return: rounded number
"""
assert sig >= 1
return round(x, sig-int(np.floor(np.log10(abs(x))))-1)
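# Worked examples (not part of the commit), derived from the formula above; note that x == 0 would raise,
# since log10(0) is undefined:
# relative_round(0.03112, 2)   -> 0.031
# relative_round(123.456, 2)   -> 120.0  (negative decimals round left of the decimal point)
# relative_round(-0.00042, 1)  -> -0.0004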
def remove_items(obj: Union[List, Dict, Tuple], items: Any):
"""
Remove item(s) from either list, tuple or dictionary.
......
......@@ -11,6 +11,8 @@ import pandas as pd
from typing import Union, Tuple, Dict, List
import itertools
from collections import OrderedDict
from mlair.helpers import to_list
Data = Union[xr.DataArray, pd.DataFrame]
......@@ -211,13 +213,42 @@ def mean_absolute_error(a, b, dim=None):
return np.abs(a - b).mean(dim)
def index_of_agreement(a, b, dim=None):
"""Calculate index of agreement (IOA) where a is the forecast and b the reference (e.g. observation)."""
num = (np.square(b - a)).sum(dim)
b_mean = (b * np.ones(1)).mean(dim)
den = (np.square(np.abs(b - b_mean) + np.abs(a - b_mean))).sum(dim)
frac = num / den
# guard against 0/0 division for exactly equal arrays (IOA is 1 in that case)
if isinstance(frac, (int, float)):
frac = 0 if num == 0 else frac
else:
frac[num == 0] = 0
return 1 - frac
def modified_normalized_mean_bias(a, b, dim=None):
"""Calculate modified normalized mean bias (MNMB) where a is the forecast and b the reference (e.g. observation)."""
N = np.count_nonzero(a) if len(a.shape) == 1 else a.notnull().sum(dim)
return 2 * ((a - b) / (a + b)).sum(dim) / N
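# Hypothetical worked example (not part of the commit), with forecast a and observation b along dim "index":
# a = [1.0, 2.0, 3.0], b = [1.0, 2.0, 4.0]
# index_of_agreement(a, b, dim="index")             -> 1 - 1/13 ≈ 0.923
# modified_normalized_mean_bias(a, b, dim="index")  -> 2 * (-1/7) / 3 ≈ -0.095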
def calculate_error_metrics(a, b, dim):
"""Calculate MSE, RMSE, and MAE. Additionally return number of used values for calculation."""
"""Calculate MSE, RMSE, MAE, IOA, and MNMB. Additionally, return number of used values for calculation.
:param a: forecast data to calculate metrics for
:param b: reference (e.g. observation)
:param dim: dimension to calculate metrics along
:returns: dict with results for all metrics indicated by lowercase metric short name
"""
mse = mean_squared_error(a, b, dim)
rmse = np.sqrt(mse)
mae = mean_absolute_error(a, b, dim)
ioa = index_of_agreement(a, b, dim)
mnmb = modified_normalized_mean_bias(a, b, dim)
n = (a - b).notnull().sum(dim)
return {"mse": mse, "rmse": rmse, "mae": mae, "n": n}
return {"mse": mse, "rmse": rmse, "mae": mae, "ioa": ioa, "mnmb": mnmb, "n": n}
def mann_whitney_u_test(data: pd.DataFrame, reference_col_name: str, **kwargs):
......@@ -540,7 +571,6 @@ def create_single_bootstrap_realization(data: xr.DataArray, dim_name_time: str)
:param dim_name_time: name of time dimension
:return: bootstrapped realization of data
"""
num_of_blocks = data.coords[dim_name_time].shape[0]
boot_idx = np.random.choice(num_of_blocks, size=num_of_blocks, replace=True)
return data.isel({dim_name_time: boot_idx})
......@@ -556,7 +586,7 @@ def calculate_average(data: xr.DataArray, **kwargs) -> xr.DataArray:
def create_n_bootstrap_realizations(data: xr.DataArray, dim_name_time: str, dim_name_model: str, n_boots: int = 1000,
dim_name_boots: str = 'boots') -> xr.DataArray:
dim_name_boots: str = 'boots', seasons: List = None) -> Dict[str, xr.DataArray]:
"""
Create n bootstrap realizations and calculate averages across realizations
......@@ -565,26 +595,23 @@ def create_n_bootstrap_realizations(data: xr.DataArray, dim_name_time: str, dim_
:param dim_name_model: name of model dimension
:param n_boots: number of bootstrap realizations
:param dim_name_boots: name of bootstrap dimension
:param seasons: calculate errors for the given seasons in addition (default None)
:return: dict of bootstrapped averages, keyed by season name and by "" for the full period
"""
seasons = [] if seasons is None else to_list(seasons)  # ensure seasons is an empty list if None
res_dims = [dim_name_boots]
dims = list(data.dims)
other_dims = [v for v in dims if v in set(dims).difference([dim_name_time])]
coords = {dim_name_boots: range(n_boots), **{dim_name: data.coords[dim_name] for dim_name in other_dims}}
if len(dims) > 1:
res_dims = res_dims + other_dims
res = xr.DataArray(np.nan, dims=res_dims, coords=coords)
realizations = {k: xr.DataArray(np.nan, dims=res_dims, coords=coords) for k in seasons + [""]}
for boot in range(n_boots):
res[boot] = (calculate_average(
create_single_bootstrap_realization(data, dim_name_time=dim_name_time),
dim=dim_name_time, skipna=True))
return res
shuffled = create_single_bootstrap_realization(data, dim_name_time=dim_name_time)
realizations[""][boot] = calculate_average(shuffled, dim=dim_name_time, skipna=True)
for season in seasons:
assert season in ["DJF", "MAM", "JJA", "SON"]
sel = shuffled[dim_name_time].dt.season == season
realizations[season][boot] = calculate_average(shuffled.sel({dim_name_time: sel}),
dim=dim_name_time, skipna=True)
return realizations
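A hedged usage sketch of the seasonal extension (#374); `errors` is a hypothetical DataArray with time dimension "index":

realizations = create_n_bootstrap_realizations(errors, dim_name_time="index", dim_name_model="type",
                                               n_boots=1000, seasons=["DJF", "JJA"])
realizations[""]     # bootstrapped averages over the full period
realizations["DJF"]  # bootstrapped averages restricted to winter months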
......@@ -105,7 +105,10 @@ def get_all_args(*args, remove=None, add=None):
return res
def check_nested_equality(obj1, obj2):
def check_nested_equality(obj1, obj2, precision=None):
"""Check for equality in nested structures. Use precision to indicate number of decimals to check for consistency"""
assert precision is None or isinstance(precision, int)
try:
print(f"check type {type(obj1)} and {type(obj2)}")
......@@ -116,22 +119,38 @@ def check_nested_equality(obj1, obj2):
assert len(obj1) == len(obj2)
for pos in range(len(obj1)):
print(f"check pos {obj1[pos]} and {obj2[pos]}")
assert check_nested_equality(obj1[pos], obj2[pos]) is True
assert check_nested_equality(obj1[pos], obj2[pos], precision) is True
elif isinstance(obj1, dict):
print(f"check keys {obj1.keys()} and {obj2.keys()}")
assert sorted(obj1.keys()) == sorted(obj2.keys())
for k in obj1.keys():
print(f"check pos {obj1[k]} and {obj2[k]}")
assert check_nested_equality(obj1[k], obj2[k]) is True
assert check_nested_equality(obj1[k], obj2[k], precision) is True
elif isinstance(obj1, xr.DataArray):
print(f"check xr {obj1} and {obj2}")
assert xr.testing.assert_equal(obj1, obj2) is None
if precision is None:
print(f"check xr {obj1} and {obj2}")
assert xr.testing.assert_equal(obj1, obj2) is None
else:
print(f"check xr {obj1} and {obj2} with precision {precision}")
assert xr.testing.assert_allclose(obj1, obj2, atol=10**(-precision)) is None
elif isinstance(obj1, np.ndarray):
print(f"check np {obj1} and {obj2}")
assert np.testing.assert_array_equal(obj1, obj2) is None
if precision is None:
print(f"check np {obj1} and {obj2}")
assert np.testing.assert_array_equal(obj1, obj2) is None
else:
print(f"check np {obj1} and {obj2} with precision {precision}")
assert np.testing.assert_array_almost_equal(obj1, obj2, decimal=precision) is None
else:
print(f"check equal {obj1} and {obj2}")
assert obj1 == obj2
if isinstance(obj1, (int, float)) and isinstance(obj2, (int, float)):
if precision is None:
print(f"check number equal {obj1} and {obj2}")
assert np.testing.assert_equal(obj1, obj2) is None
else:
print(f"check number equal {obj1} and {obj2} with precision {precision}")
assert np.testing.assert_almost_equal(obj1, obj2, decimal=precision) is None
else:
print(f"check equal {obj1} and {obj2}")
assert obj1 == obj2
except AssertionError:
return False
return True
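A quick hypothetical check of the new `precision` argument:

check_nested_equality({"a": [1.0001, 2.0]}, {"a": [1.0002, 2.0]})               # -> False (exact comparison)
check_nested_equality({"a": [1.0001, 2.0]}, {"a": [1.0002, 2.0]}, precision=3)  # -> True (within 3 decimals)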
......@@ -163,6 +163,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
def __init__(self, *args, **kwargs):
"""Initialise ModelCheckpointAdvanced and set callbacks attribute."""
self.callbacks = kwargs.pop("callbacks")
self.epoch_best = None
self.restore_best_weights = kwargs.pop("restore_best_weights", True)
super().__init__(*args, **kwargs)
def update_best(self, hist):
......@@ -176,7 +178,19 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
:param hist: The History object from the previous (interrupted) training.
"""
self.best = hist.history.get(self.monitor)[-1]
if self.restore_best_weights:
f = np.min if self.monitor_op.__name__ == "less" else np.max  # best value is the minimum for loss-like monitors
f_loc = lambda x: np.where(x == f(x))[0][-1]  # index of the last occurrence of the best value
_d = hist.history.get(self.monitor)
loc = f_loc(_d)
assert f(_d) == _d[loc]
self.epoch_best = loc
self.best = _d[loc]
logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}")
else:
_d = hist.history.get(self.monitor)[-1]
self.best = _d
logging.info(f"Set only best result ({self.monitor}={self.best}) without best epoch")
def update_callbacks(self, callbacks):
"""
......@@ -197,6 +211,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
if self.save_best_only:
current = logs.get(self.monitor)
if current == self.best:
if self.restore_best_weights:
self.epoch_best = epoch
if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
with open(file_path, "wb") as f:
......
......@@ -25,6 +25,7 @@ from mlair.helpers import TimeTrackingWrapper
from mlair.plotting.abstract_plot_class import AbstractPlotClass
from mlair.helpers.statistics import mann_whitney_u_test, represent_p_values_as_asteriks
logging.getLogger('matplotlib').setLevel(logging.WARNING)
......@@ -1095,7 +1096,7 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
def __init__(self, data: xr.DataArray, plot_folder: str = ".", model_type_dim: str = "type",
error_measure: str = "mse", error_unit: str = None, dim_name_boots: str = 'boots',
block_length: str = None, model_name: str = "NN", model_indicator: str = "nn",
ahead_dim: str = "ahead", sampling: Union[str, Tuple[str]] = ""):
ahead_dim: str = "ahead", sampling: Union[str, Tuple[str]] = "", season_annotation: str = None):
super().__init__(plot_folder, "sample_uncertainty_from_bootstrap")
self.default_plot_name = self.plot_name
self.model_type_dim = model_type_dim
......@@ -1105,6 +1106,7 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
self.error_unit = error_unit
self.block_length = block_length
self.model_name = model_name
_season = season_annotation or ""
self.sampling = {"daily": "d", "hourly": "H"}.get(sampling[1] if isinstance(sampling, tuple) else sampling, "")
data = self.rename_model_indicator(data, model_name, model_indicator)
self.prepare_data(data)
......@@ -1114,12 +1116,12 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
# plot raw metric (mse)
for orientation, utest, agg_type in variants:
self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type)
self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type, season=_season)
# plot root of metric (rmse)
self._apply_root()
for orientation, utest, agg_type in variants:
self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type, tag="_sqrt")
self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type, tag="_sqrt", season=_season)
self._data_table = None
self._n_boots = None
......@@ -1148,9 +1150,10 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
self.error_measure = f"root {self.error_measure}"
self.error_unit = self.error_unit.replace("$^2$", "")
def _plot(self, orientation: str = "v", apply_u_test: bool = False, agg_type="single", tag=""):
def _plot(self, orientation: str = "v", apply_u_test: bool = False, agg_type="single", tag="", season=""):
self.plot_name = self.default_plot_name + {"v": "_vertical", "h": "_horizontal"}[orientation] + \
{True: "_u_test", False: ""}[apply_u_test] + "_" + agg_type + tag
{True: "_u_test", False: ""}[apply_u_test] + "_" + agg_type + tag + \
{"": ""}.get(season, f"_{season}")
if apply_u_test is True and agg_type == "multi":
return # not implemented
data_table = self._data_table
......@@ -1198,10 +1201,13 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
ax.set_xlabel(f"{self.error_measure} (in {self.error_unit})")
xlims = list(ax.get_xlim())
ax.set_xlim([xlims[0], xlims[1] * 1.015])
else:
raise ValueError(f"orientation must be `v' or `h' but is: {orientation}")
text = f"n={n_boots}" if self.block_length is None else f"{self.block_length}, n={n_boots}"
text = f"n={n_boots}"
if self.block_length is not None:
text = f"{self.block_length}, {text}"
if len(season) > 0:
text = f"{season}, {text}"
loc = "lower left"
text_box = AnchoredText(text, frameon=True, loc=loc, pad=0.5, bbox_to_anchor=(0., 1.0),
bbox_transform=ax.transAxes)
......@@ -1234,6 +1240,85 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
return ax
@TimeTrackingWrapper
class PlotTimeEvolutionMetric(AbstractPlotClass):
def __init__(self, data: xr.DataArray, ahead_dim="ahead", model_type_dim="type", plot_folder=".",
error_measure: str = "mse", error_unit: str = None, model_name: str = "NN",
model_indicator: str = "nn", time_dim="index"):
super().__init__(plot_folder, "time_evolution_mse")
self.title = error_measure + (f" (in {error_unit})" if error_unit is not None else "")
plot_name = self.plot_name
vmin = int(data.quantile(0.05))
vmax = int(data.quantile(0.95))
data = self._prepare_data(data, time_dim, model_type_dim, model_indicator, model_name)
for t in data[model_type_dim]:
# note: could be expanded to create plot per ahead step
plot_data = data.sel({model_type_dim: t}).mean(ahead_dim).to_pandas()
years = plot_data.columns.strftime("%Y").to_list()
months = plot_data.columns.strftime("%b").to_list()
plot_data.columns = plot_data.columns.strftime("%b %Y")
self.plot_name = f"{plot_name}_{t.values}"
self._plot(plot_data, years, months, vmin, vmax, str(t.values))
@staticmethod
def _find_nan_edge(data, time_dim):
coll = []
for i in data:
if bool(i) is False:
break
else:
coll.append(i[time_dim].values)
return coll
def _prepare_data(self, data, time_dim, model_type_dim, model_indicator, model_name):
# remove nans at begin and end
nan_locs = data.isnull().all(helpers.remove_items(data.dims, time_dim))
nans_at_end = self._find_nan_edge(reversed(nan_locs), time_dim)
nans_at_begin = self._find_nan_edge(nan_locs, time_dim)
data = data.drop(nans_at_begin + nans_at_end, time_dim)
# rename nn model
data[model_type_dim] = [v if v != model_indicator else model_name for v in data[model_type_dim].data.tolist()]
return data
@staticmethod
def _set_ticks(ax, years, months):
from matplotlib.ticker import IndexLocator
ax.xaxis.set_major_locator(IndexLocator(1, 0.5))
locs = ax.get_xticks(minor=False).tolist()[:len(months)]
ax.set_xticks(locs, minor=True)
ax.set_xticklabels([m[0] for m in months], minor=True, rotation=0)
locs_major = []
labels_major = []