From e48041830059d848ea95e627796529e888689bfe Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Mon, 20 Jan 2020 17:38:04 +0100 Subject: [PATCH] corrected plot metric functionality to be able to plot the (total) loss without "main" too --- src/plotting/training_monitoring.py | 16 ++++++++++++++-- src/run_modules/model_setup.py | 4 ++-- src/run_modules/training.py | 7 ++++++- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py index dee36166..617ff313 100644 --- a/src/plotting/training_monitoring.py +++ b/src/plotting/training_monitoring.py @@ -25,7 +25,7 @@ class PlotModelHistory: metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the absolute path for the plot. """ - def __init__(self, filename: str, history: history_object, plot_metric: str = "loss"): + def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = True): """ Sets attributes and create plot :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a @@ -35,10 +35,22 @@ class PlotModelHistory: if isinstance(history, keras.callbacks.History): history = history.history self._data = pd.DataFrame.from_dict(history) - self._plot_metric = plot_metric + self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch) self._additional_columns = self._filter_columns(history) self._plot(filename) + @staticmethod + def _get_plot_metric(history, plot_metric, main_branch): + if plot_metric.lower() == "mse": + plot_metric = "mean_squared_error" + elif plot_metric.lower() == "mae": + plot_metric = "mean_absolute_error" + available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k if main_branch else True)] + print(available_keys) + available_keys.sort(key=len) + print(available_keys) + return available_keys[0] + def _filter_columns(self, history: Dict) -> List[str]: """ Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>' are diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index 0f3ff6d4..a7722018 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -15,8 +15,8 @@ from src.run_modules.run_environment import RunEnvironment from src.helpers import l_p_loss, LearningRateDecay from src.model_modules.inception_model import InceptionModelBase from src.model_modules.flatten import flatten_tail -# from src.model_modules.model_class import MyBranchedModel as MyModel -from src.model_modules.model_class import MyLittleModel as MyModel +from src.model_modules.model_class import MyBranchedModel as MyModel +# from src.model_modules.model_class import MyLittleModel as MyModel class ModelSetup(RunEnvironment): diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 272609a3..d15060da 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -134,5 +134,10 @@ class Training(RunEnvironment): """ path = self.data_store.get("plot_path", "general") name = self.data_store.get("experiment_name", "general") - PlotModelHistory(filename=os.path.join(path, f"{name}_history_loss_val_loss.pdf"), history=history) + filename = os.path.join(path, f"{name}_history_loss.pdf") + PlotModelHistory(filename=filename, history=history, main_branch=False) + filename = os.path.join(path, f"{name}_history_main_loss.pdf") + PlotModelHistory(filename=filename, history=history) + filename = os.path.join(path, f"{name}_history_main_mse.pdf") + PlotModelHistory(filename=filename, history=history, plot_metric="mse") PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc) -- GitLab