From 01dc3e2bffa60d443f23ebd8102dc54f5fac8524 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Fri, 26 Feb 2021 16:37:51 +0100 Subject: [PATCH] ensure loss to be a list when logging --- mlair/run_modules/post_processing.py | 2 +- mlair/run_modules/training.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 807f32bb..6f78a03d 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -390,7 +390,7 @@ class PostProcessing(RunEnvironment): use_multiprocessing=True, verbose=0, steps=1) path = self.data_store.get("model_path") with open(os.path.join(path, "test_scores.txt"), "a") as f: - for index, item in enumerate(test_score): + for index, item in enumerate(to_list(test_score)): logging.info(f"{self.model.metrics_names[index]} (test), {item}") f.write(f"{self.model.metrics_names[index]}, {item}\n") diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index bbb3fabf..d4badfe2 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -16,6 +16,7 @@ from mlair.model_modules.keras_extensions import CallbackHandler from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate from mlair.run_modules.run_environment import RunEnvironment from mlair.configuration import path_config +from mlair.helpers import to_list class Training(RunEnvironment): @@ -249,5 +250,5 @@ class Training(RunEnvironment): tablefmt="github") val_score = self.model.evaluate_generator(generator=self.val_set, use_multiprocessing=True, verbose=0, steps=1) - for index, item in enumerate(val_score): + for index, item in enumerate(to_list(val_score)): logging.info(f"{self.model.metrics_names[index]} (val), {item}") -- GitLab