diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 2a41aab81d7ed62b1b58af515d703a2281236645..e6d6de152e42d44f271ba986b6645d2cd36b68d0 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -1042,7 +1042,6 @@ class PlotSeparationOfScales(AbstractPlotClass): # pragma: no cover data = dh.get_X(as_numpy=False)[0] station = dh.id_class.station[0] data = data.sel(Stations=station) - # plt.subplots() data.plot(x=self.time_dim, y=self.window_dim, col=self.filter_dim, row=self.target_dim, robust=True) self.plot_name = f"{orig_plot_name}_{station}" self._save() @@ -1085,9 +1084,8 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover return data def prepare_data(self, data: xr.DataArray): - self._data_table = data.to_pandas() - if "persi" in self._data_table.columns: - self._data_table["persi"] = self._data_table.pop("persi") + data_table = data.to_pandas() + self._data_table = data_table[data_table.mean().sort_values().index] self._n_boots = self._data_table.shape[0] def _apply_root(self): @@ -1102,7 +1100,7 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover if orientation == "v": figsize, width = (size, 5), 0.4 elif orientation == "h": - figsize, width = (6, (1+.5*size)), 0.65 + figsize, width = (7, (1+.5*size)), 0.65 else: raise ValueError(f"orientation must be `v' or `h' but is: {orientation}") fig, ax = plt.subplots(figsize=figsize) @@ -1119,7 +1117,8 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover else: raise ValueError(f"orientation must be `v' or `h' but is: {orientation}") text = f"n={n_boots}" if self.block_length is None else f"{self.block_length}, n={n_boots}" - text_box = AnchoredText(text, frameon=True, loc=1, pad=0.5) + loc = "upper right" if orientation == "h" else "upper left" + text_box = AnchoredText(text, frameon=True, loc=loc, pad=0.5) plt.setp(text_box.patch, edgecolor='k', facecolor='w') ax.add_artist(text_box) plt.setp(ax.lines, color='k') diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 9d03f47172d80b2d06e3ea6f10f44b076883c9ef..7f2b3b59b17910ae2667e003a821fbadab755b85 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -427,7 +427,7 @@ class PostProcessing(RunEnvironment): :return: the model """ - try: + try: # is only available if a model was trained in training stage model = self.data_store.get("best_model") except NameNotFoundInDataStore: logging.info("No model was saved in data store. Try to load model from experiment path.") diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 8d82afb4c002c660e6fb966945b2e383007d5b70..a38837dce041295d37fae1ea86ef2a215d51dc89 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -14,6 +14,7 @@ import psutil import pandas as pd from mlair.data_handler import KerasIterator +from mlair.model_modules import AbstractModelClass from mlair.model_modules.keras_extensions import CallbackHandler from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate from mlair.run_modules.run_environment import RunEnvironment @@ -67,7 +68,7 @@ class Training(RunEnvironment): def __init__(self): """Set up and run training.""" super().__init__() - self.model: keras.Model = self.data_store.get("model", "model") + self.model: AbstractModelClass = self.data_store.get("model", "model") self.train_set: Union[KerasIterator, None] = None self.val_set: Union[KerasIterator, None] = None # self.test_set: Union[KerasIterator, None] = None