Skip to content
Snippets Groups Projects
Commit 270c9027 authored by lukas leufen's avatar lukas leufen
Browse files

create model settings reporting

parent 50ec16b2
No related branches found
No related tags found
4 merge requests!125Release v0.10.0,!124Update Master to new version v0.10.0,!122Resolve "ML param reporting",!119Resolve "Include advanced data handling in workflow"
Pipeline #41033 passed
...@@ -5,8 +5,10 @@ __date__ = '2019-12-02' ...@@ -5,8 +5,10 @@ __date__ = '2019-12-02'
import logging import logging
import os import os
import re
import keras import keras
import pandas as pd
import tensorflow as tf import tensorflow as tf
from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
...@@ -88,6 +90,9 @@ class ModelSetup(RunEnvironment): ...@@ -88,6 +90,9 @@ class ModelSetup(RunEnvironment):
# compile model # compile model
self.compile_model() self.compile_model()
# report settings
self.report_model()
def _set_channels(self): def _set_channels(self):
"""Set channels as number of variables of train generator.""" """Set channels as number of variables of train generator."""
channels = self.data_store.get("generator", "train")[0][0].shape[-1] channels = self.data_store.get("generator", "train")[0][0].shape[-1]
...@@ -147,3 +152,24 @@ class ModelSetup(RunEnvironment): ...@@ -147,3 +152,24 @@ class ModelSetup(RunEnvironment):
with tf.device("/cpu:0"): with tf.device("/cpu:0"):
file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf" file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
def report_model(self):
model_settings = self.model.get_settings()
df = pd.DataFrame(columns=["model setting"])
for k,v in model_settings.items():
if "<" in str(v):
v = self._clean_name(str(v))
df.loc[k] = v
df.sort_index(inplace=True)
column_format = "ll"
path = self.data_store.get("model_path")
df.to_latex(os.path.join(path, "model_settings.tex"), na_rep='---', column_format=column_format)
df.to_markdown(open(os.path.join(path, "model_settings.md"), mode="w", encoding='utf-8'),
tablefmt="github")
@staticmethod
def _clean_name(orig_name: str):
mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
mod_name = mod_name[1] if "class" in mod_name[0] else mod_name[0]
return mod_name[:-1] if mod_name[-1] == ">" else mod_name
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment