diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 6c05a5731056cd3926e319b6d549ca3f64ed3873..c089062428667a2356c16087f454ae4ca94269d7 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -5,12 +5,15 @@ __date__ = '2019-12-02'
 
 import logging
 import os
+import re
 
 import keras
+import pandas as pd
 import tensorflow as tf
 
 from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
 from mlair.run_modules.run_environment import RunEnvironment
+from mlair.configuration import path_config
 
 
 class ModelSetup(RunEnvironment):
@@ -88,6 +91,9 @@ class ModelSetup(RunEnvironment):
         # compile model
         self.compile_model()
 
+        # report settings
+        self.report_model()
+
     def _set_channels(self):
         """Set channels as number of variables of train generator."""
         channels = self.data_store.get("generator", "train")[0][0].shape[-1]
@@ -147,3 +153,46 @@ class ModelSetup(RunEnvironment):
         with tf.device("/cpu:0"):
             file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
             keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
+
+    def report_model(self):
+        """
+        Write the model settings to the experiment's latex_report folder as LaTeX and Markdown tables.
+
+        The settings come from ``self.model.get_settings()``; the tables are sorted by setting name and stored as
+        ``model_settings.tex`` and ``model_settings.md``.
+        """
+        model_settings = self.model.get_settings()
+        df = pd.DataFrame(columns=["model setting"])
+        for key, value in model_settings.items():
+            # string forms containing "<" (presumably object/class reprs) are reduced to a plain name
+            if "<" in str(value):
+                value = self._clean_name(str(value))
+            df.loc[key] = value
+        df.sort_index(inplace=True)
+        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+        path_config.check_path_and_create(path)
+        df.to_latex(os.path.join(path, "model_settings.tex"), na_rep='---', column_format="ll")
+        # open the markdown file in a context manager so the handle is flushed and closed deterministically
+        with open(os.path.join(path, "model_settings.md"), mode="w", encoding='utf-8') as f:
+            df.to_markdown(f, tablefmt="github")
+
+    @staticmethod
+    def _clean_name(orig_name: str):
+        """
+        Reduce a repr-like string such as ``<class 'pkg.Name'>`` to the bare name ``pkg.Name``.
+
+        A leading ``<`` and all single quotes are removed, leading kind words (``class``, ``function``, ``method``,
+        ``bound``, ``built-in``) are skipped, and a trailing ``>`` is cut from the first remaining token.
+
+        :param orig_name: string to clean
+        :return: first meaningful whitespace-separated token of the cleaned string
+        """
+        kind_words = ("class", "function", "method", "bound", "built-in")
+        tokens = re.sub(r"^<", "", orig_name).replace("'", "").split(" ")
+        # compare whole tokens; a substring test would treat a name like "classifier.Net" as the kind word "class"
+        while len(tokens) > 1 and tokens[0] in kind_words:
+            tokens.pop(0)
+        name = tokens[0]
+        # endswith is safe for an empty token, where indexing the last character would raise IndexError
+        return name[:-1] if name.endswith(">") else name
+
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 0c516c577201c62badd3d2b769a291612dce0c5a..23347a30b6e55c6903154128aab055d39045c965 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -15,6 +15,7 @@ from mlair.data_handling import Distributor
 from mlair.model_modules.keras_extensions import CallbackHandler
 from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from mlair.run_modules.run_environment import RunEnvironment
+from mlair.configuration import path_config
 
 
 class Training(RunEnvironment):
@@ -82,6 +83,7 @@ class Training(RunEnvironment):
         if self._trainable:
             self.train()
             self.save_model()
+            self.report_training()
         else:
             logging.info("No training has started, because trainable parameter was false.")
 
@@ -228,3 +230,25 @@ class Training(RunEnvironment):
         # plot learning rate
         if lr_sc:
             PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
+
+    def report_training(self):
+        """
+        Write the training settings to the experiment's latex_report folder as LaTeX and Markdown tables.
+
+        The tables are sorted by setting name and stored as ``training_settings.tex`` and ``training_settings.md``.
+        """
+        import pandas as pd  # pandas is imported locally here rather than at module level
+        data = {"mini batches": len(self.train_set),
+                "upsampling extremes": self.train_set.upsampling,
+                "shuffling": self.train_set.do_data_permutation,
+                "created new model": self._create_new_model,
+                "epochs": self.epochs,
+                "batch size": self.batch_size}
+        df = pd.DataFrame.from_dict(data, orient="index", columns=["training setting"])
+        df.sort_index(inplace=True)
+        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+        path_config.check_path_and_create(path)
+        df.to_latex(os.path.join(path, "training_settings.tex"), na_rep='---', column_format="ll")
+        # open the markdown file in a context manager so the handle is flushed and closed deterministically
+        with open(os.path.join(path, "training_settings.md"), mode="w", encoding='utf-8') as f:
+            df.to_markdown(f, tablefmt="github")
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index 1f218db5f46b256a4772bdd0521e248d84c54da2..b80570bb51ec5886f163842a3a40411148df3419 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -75,6 +75,7 @@ class TestTraining:
             os.makedirs(path_plot)
         obj.data_store.set("plot_path", path_plot, "general")
         obj._trainable = True
+        obj._create_new_model = False
         yield obj
         if os.path.exists(path):
             shutil.rmtree(path)