diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py
index dee36166abacf275c79213305ac15e1918e1957c..617ff3135734056e51746ae2924a123a3eb34f8f 100644
--- a/src/plotting/training_monitoring.py
+++ b/src/plotting/training_monitoring.py
@@ -25,7 +25,7 @@ class PlotModelHistory:
     metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the
     absolute path for the plot.
     """
-    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss"):
+    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = True):
         """
         Sets attributes and create plot
         :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a
@@ -35,10 +35,24 @@
         if isinstance(history, keras.callbacks.History):
             history = history.history
         self._data = pd.DataFrame.from_dict(history)
-        self._plot_metric = plot_metric
+        self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch)
         self._additional_columns = self._filter_columns(history)
         self._plot(filename)
 
+    @staticmethod
+    def _get_plot_metric(history, plot_metric, main_branch):
+        """
+        Translate the metric shorthands mse and mae into the full keras metric names and return the shortest history
+        key containing this metric (restricted to the main branch if main_branch is True).
+        """
+        if plot_metric.lower() == "mse":
+            plot_metric = "mean_squared_error"
+        elif plot_metric.lower() == "mae":
+            plot_metric = "mean_absolute_error"
+        available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k if main_branch else True)]
+        available_keys.sort(key=len)
+        return available_keys[0]
+
     def _filter_columns(self, history: Dict) -> List[str]:
         """
         Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>' are
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 0f3ff6d436b8a65528626f5f80508af222a1e68f..a7722018c52275b390a10199cb30b7b936ed37a3 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -15,8 +15,8 @@ from src.run_modules.run_environment import RunEnvironment
 from src.helpers import l_p_loss, LearningRateDecay
 from src.model_modules.inception_model import InceptionModelBase
 from src.model_modules.flatten import flatten_tail
-# from src.model_modules.model_class import MyBranchedModel as MyModel
-from src.model_modules.model_class import MyLittleModel as MyModel
+from src.model_modules.model_class import MyBranchedModel as MyModel
+# from src.model_modules.model_class import MyLittleModel as MyModel
 
 
 class ModelSetup(RunEnvironment):
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 272609a31a3e3c91d6857ed841d5dd2783c66f35..d15060daeb2596d048ff44f5f9948c0002069762 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -134,5 +134,10 @@ class Training(RunEnvironment):
         """
         path = self.data_store.get("plot_path", "general")
         name = self.data_store.get("experiment_name", "general")
-        PlotModelHistory(filename=os.path.join(path, f"{name}_history_loss_val_loss.pdf"), history=history)
+        filename = os.path.join(path, f"{name}_history_loss.pdf")
+        PlotModelHistory(filename=filename, history=history, main_branch=False)
+        filename = os.path.join(path, f"{name}_history_main_loss.pdf")
+        PlotModelHistory(filename=filename, history=history)
+        filename = os.path.join(path, f"{name}_history_main_mse.pdf")
+        PlotModelHistory(filename=filename, history=history, plot_metric="mse")
         PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
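
For reference, the key-selection logic introduced with _get_plot_metric can be exercised outside keras with a plain
history dict. The standalone Python sketch below is not part of the patch; the function name resolve_plot_metric and
the example history keys (outputs named "main" and "minor_1") are illustrative assumptions, since the real keys depend
on the model's output names.

# Minimal sketch of the metric-key resolution used by PlotModelHistory._get_plot_metric.
from typing import Dict, List


def resolve_plot_metric(history: Dict[str, List[float]], plot_metric: str = "loss",
                        main_branch: bool = True) -> str:
    """Map shorthands (mse, mae) to full keras metric names and return the shortest matching history key."""
    if plot_metric.lower() == "mse":
        plot_metric = "mean_squared_error"
    elif plot_metric.lower() == "mae":
        plot_metric = "mean_absolute_error"
    # keep only keys containing the metric; optionally restrict to the main output branch
    available_keys = [k for k in history if plot_metric in k and ("main" in k if main_branch else True)]
    available_keys.sort(key=len)  # shortest key first, e.g. "main_loss" before "val_main_loss"
    return available_keys[0]


if __name__ == "__main__":
    # hypothetical history of a branched model with outputs "main" and "minor_1"
    history = {"loss": [1.0], "val_loss": [1.2],
               "main_loss": [0.8], "val_main_loss": [0.9],
               "main_mean_squared_error": [0.5], "val_main_mean_squared_error": [0.6],
               "minor_1_loss": [0.7], "val_minor_1_loss": [0.8]}
    print(resolve_plot_metric(history, "loss", main_branch=False))  # -> "loss"
    print(resolve_plot_metric(history, "loss", main_branch=True))   # -> "main_loss"
    print(resolve_plot_metric(history, "mse", main_branch=True))    # -> "main_mean_squared_error"

Sorting by length makes the shortest matching key win (e.g. "main_loss" over "val_main_loss"), which becomes the
primary plot metric, while _filter_columns then collects the remaining columns named like that metric.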