Skip to content
Snippets Groups Projects
Commit 270c9027 authored by lukas leufen's avatar lukas leufen
Browse files

create model settings reporting

parent 50ec16b2
Branches
Tags
4 merge requests!125Release v0.10.0,!124Update Master to new version v0.10.0,!122Resolve "ML param reporting",!119Resolve "Include advanced data handling in workflow"
Pipeline #41033 passed
......@@ -5,8 +5,10 @@ __date__ = '2019-12-02'
import logging
import os
import re
import keras
import pandas as pd
import tensorflow as tf
from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
......@@ -88,6 +90,9 @@ class ModelSetup(RunEnvironment):
# compile model
self.compile_model()
# report settings
self.report_model()
def _set_channels(self):
"""Set channels as number of variables of train generator."""
channels = self.data_store.get("generator", "train")[0][0].shape[-1]
......@@ -147,3 +152,24 @@ class ModelSetup(RunEnvironment):
with tf.device("/cpu:0"):
file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
def report_model(self):
model_settings = self.model.get_settings()
df = pd.DataFrame(columns=["model setting"])
for k,v in model_settings.items():
if "<" in str(v):
v = self._clean_name(str(v))
df.loc[k] = v
df.sort_index(inplace=True)
column_format = "ll"
path = self.data_store.get("model_path")
df.to_latex(os.path.join(path, "model_settings.tex"), na_rep='---', column_format=column_format)
df.to_markdown(open(os.path.join(path, "model_settings.md"), mode="w", encoding='utf-8'),
tablefmt="github")
@staticmethod
def _clean_name(orig_name: str):
mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
mod_name = mod_name[1] if "class" in mod_name[0] else mod_name[0]
return mod_name[:-1] if mod_name[-1] == ">" else mod_name
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment