diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index 6c05a5731056cd3926e319b6d549ca3f64ed3873..83efac734c079602ab6003a7b690bea895233a93 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -5,8 +5,10 @@ __date__ = '2019-12-02' import logging import os +import re import keras +import pandas as pd import tensorflow as tf from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler @@ -88,6 +90,9 @@ class ModelSetup(RunEnvironment): # compile model self.compile_model() + # report settings + self.report_model() + def _set_channels(self): """Set channels as number of variables of train generator.""" channels = self.data_store.get("generator", "train")[0][0].shape[-1] @@ -147,3 +152,24 @@ class ModelSetup(RunEnvironment): with tf.device("/cpu:0"): file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf" keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) + + def report_model(self): + model_settings = self.model.get_settings() + df = pd.DataFrame(columns=["model setting"]) + for k,v in model_settings.items(): + if "<" in str(v): + v = self._clean_name(str(v)) + df.loc[k] = v + df.sort_index(inplace=True) + column_format = "ll" + path = self.data_store.get("model_path") + df.to_latex(os.path.join(path, "model_settings.tex"), na_rep='---', column_format=column_format) + df.to_markdown(open(os.path.join(path, "model_settings.md"), mode="w", encoding='utf-8'), + tablefmt="github") + + @staticmethod + def _clean_name(orig_name: str): + mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ") + mod_name = mod_name[1] if "class" in mod_name[0] else mod_name[0] + return mod_name[:-1] if mod_name[-1] == ">" else mod_name +