Skip to content
Snippets Groups Projects
Commit e4804183 authored by lukas leufen's avatar lukas leufen
Browse files

corrected plot metric functionality to be able to plot the (total) loss without "main" too

parent 012a05cb
No related branches found
No related tags found
2 merge requests!37include new development,!26Lukas issue034 feat plot mse history
Pipeline #28463 passed
...@@ -25,7 +25,7 @@ class PlotModelHistory: ...@@ -25,7 +25,7 @@ class PlotModelHistory:
metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the absolute metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the absolute
path for the plot. path for the plot.
""" """
def __init__(self, filename: str, history: history_object, plot_metric: str = "loss"): def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = True):
""" """
Sets attributes and create plot Sets attributes and create plot
:param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a
...@@ -35,10 +35,22 @@ class PlotModelHistory: ...@@ -35,10 +35,22 @@ class PlotModelHistory:
if isinstance(history, keras.callbacks.History): if isinstance(history, keras.callbacks.History):
history = history.history history = history.history
self._data = pd.DataFrame.from_dict(history) self._data = pd.DataFrame.from_dict(history)
self._plot_metric = plot_metric self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch)
self._additional_columns = self._filter_columns(history) self._additional_columns = self._filter_columns(history)
self._plot(filename) self._plot(filename)
@staticmethod
def _get_plot_metric(history, plot_metric, main_branch):
if plot_metric.lower() == "mse":
plot_metric = "mean_squared_error"
elif plot_metric.lower() == "mae":
plot_metric = "mean_absolute_error"
available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k if main_branch else True)]
print(available_keys)
available_keys.sort(key=len)
print(available_keys)
return available_keys[0]
def _filter_columns(self, history: Dict) -> List[str]: def _filter_columns(self, history: Dict) -> List[str]:
""" """
Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>' are Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>' are
......
...@@ -15,8 +15,8 @@ from src.run_modules.run_environment import RunEnvironment ...@@ -15,8 +15,8 @@ from src.run_modules.run_environment import RunEnvironment
from src.helpers import l_p_loss, LearningRateDecay from src.helpers import l_p_loss, LearningRateDecay
from src.model_modules.inception_model import InceptionModelBase from src.model_modules.inception_model import InceptionModelBase
from src.model_modules.flatten import flatten_tail from src.model_modules.flatten import flatten_tail
# from src.model_modules.model_class import MyBranchedModel as MyModel from src.model_modules.model_class import MyBranchedModel as MyModel
from src.model_modules.model_class import MyLittleModel as MyModel # from src.model_modules.model_class import MyLittleModel as MyModel
class ModelSetup(RunEnvironment): class ModelSetup(RunEnvironment):
......
...@@ -134,5 +134,10 @@ class Training(RunEnvironment): ...@@ -134,5 +134,10 @@ class Training(RunEnvironment):
""" """
path = self.data_store.get("plot_path", "general") path = self.data_store.get("plot_path", "general")
name = self.data_store.get("experiment_name", "general") name = self.data_store.get("experiment_name", "general")
PlotModelHistory(filename=os.path.join(path, f"{name}_history_loss_val_loss.pdf"), history=history) filename = os.path.join(path, f"{name}_history_loss.pdf")
PlotModelHistory(filename=filename, history=history, main_branch=False)
filename = os.path.join(path, f"{name}_history_main_loss.pdf")
PlotModelHistory(filename=filename, history=history)
filename = os.path.join(path, f"{name}_history_main_mse.pdf")
PlotModelHistory(filename=filename, history=history, plot_metric="mse")
PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc) PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment