Skip to content
Snippets Groups Projects
Commit e66e969c authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'develop' into 'lukas_issue339_feat_filter-with-future-mix'

Develop

See merge request !369
parents 66ecc664 8578459c
Branches
Tags
4 merge requests!413update release branch,!412Resolve "release v2.0.0",!369Develop,!358Resolve "filter with future mix"
Pipeline #84926 failed
......@@ -1042,7 +1042,6 @@ class PlotSeparationOfScales(AbstractPlotClass): # pragma: no cover
data = dh.get_X(as_numpy=False)[0]
station = dh.id_class.station[0]
data = data.sel(Stations=station)
# plt.subplots()
data.plot(x=self.time_dim, y=self.window_dim, col=self.filter_dim, row=self.target_dim, robust=True)
self.plot_name = f"{orig_plot_name}_{station}"
self._save()
......@@ -1085,9 +1084,8 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
return data
def prepare_data(self, data: xr.DataArray):
self._data_table = data.to_pandas()
if "persi" in self._data_table.columns:
self._data_table["persi"] = self._data_table.pop("persi")
data_table = data.to_pandas()
self._data_table = data_table[data_table.mean().sort_values().index]
self._n_boots = self._data_table.shape[0]
def _apply_root(self):
......@@ -1102,7 +1100,7 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
if orientation == "v":
figsize, width = (size, 5), 0.4
elif orientation == "h":
figsize, width = (6, (1+.5*size)), 0.65
figsize, width = (7, (1+.5*size)), 0.65
else:
raise ValueError(f"orientation must be `v' or `h' but is: {orientation}")
fig, ax = plt.subplots(figsize=figsize)
......@@ -1119,7 +1117,8 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
else:
raise ValueError(f"orientation must be `v' or `h' but is: {orientation}")
text = f"n={n_boots}" if self.block_length is None else f"{self.block_length}, n={n_boots}"
text_box = AnchoredText(text, frameon=True, loc=1, pad=0.5)
loc = "upper right" if orientation == "h" else "upper left"
text_box = AnchoredText(text, frameon=True, loc=loc, pad=0.5)
plt.setp(text_box.patch, edgecolor='k', facecolor='w')
ax.add_artist(text_box)
plt.setp(ax.lines, color='k')
......
......@@ -427,7 +427,7 @@ class PostProcessing(RunEnvironment):
:return: the model
"""
try:
try: # is only available if a model was trained in training stage
model = self.data_store.get("best_model")
except NameNotFoundInDataStore:
logging.info("No model was saved in data store. Try to load model from experiment path.")
......
......@@ -14,6 +14,7 @@ import psutil
import pandas as pd
from mlair.data_handler import KerasIterator
from mlair.model_modules import AbstractModelClass
from mlair.model_modules.keras_extensions import CallbackHandler
from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
from mlair.run_modules.run_environment import RunEnvironment
......@@ -67,7 +68,7 @@ class Training(RunEnvironment):
def __init__(self):
"""Set up and run training."""
super().__init__()
self.model: keras.Model = self.data_store.get("model", "model")
self.model: AbstractModelClass = self.data_store.get("model", "model")
self.train_set: Union[KerasIterator, None] = None
self.val_set: Union[KerasIterator, None] = None
# self.test_set: Union[KerasIterator, None] = None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment