Skip to content
Snippets Groups Projects

Resolve "ML param reporting"

Merged Ghost User requested to merge lukas_issue145_feat_ml-param-reporting into develop
3 files
+ 48
0
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -5,12 +5,15 @@ __date__ = '2019-12-02'
@@ -5,12 +5,15 @@ __date__ = '2019-12-02'
import logging
import logging
import os
import os
 
import re
import keras
import keras
 
import pandas as pd
import tensorflow as tf
import tensorflow as tf
from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
from mlair.run_modules.run_environment import RunEnvironment
from mlair.run_modules.run_environment import RunEnvironment
 
from mlair.configuration import path_config
class ModelSetup(RunEnvironment):
class ModelSetup(RunEnvironment):
@@ -88,6 +91,9 @@ class ModelSetup(RunEnvironment):
@@ -88,6 +91,9 @@ class ModelSetup(RunEnvironment):
# compile model
# compile model
self.compile_model()
self.compile_model()
 
# report settings
 
self.report_model()
 
def _set_channels(self):
def _set_channels(self):
"""Set channels as number of variables of train generator."""
"""Set channels as number of variables of train generator."""
channels = self.data_store.get("generator", "train")[0][0].shape[-1]
channels = self.data_store.get("generator", "train")[0][0].shape[-1]
@@ -147,3 +153,25 @@ class ModelSetup(RunEnvironment):
@@ -147,3 +153,25 @@ class ModelSetup(RunEnvironment):
with tf.device("/cpu:0"):
with tf.device("/cpu:0"):
file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
 
 
def report_model(self):
 
model_settings = self.model.get_settings()
 
df = pd.DataFrame(columns=["model setting"])
 
for k,v in model_settings.items():
 
if "<" in str(v):
 
v = self._clean_name(str(v))
 
df.loc[k] = v
 
df.sort_index(inplace=True)
 
column_format = "ll"
 
path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
 
path_config.check_path_and_create(path)
 
df.to_latex(os.path.join(path, "model_settings.tex"), na_rep='---', column_format=column_format)
 
df.to_markdown(open(os.path.join(path, "model_settings.md"), mode="w", encoding='utf-8'),
 
tablefmt="github")
 
 
@staticmethod
 
def _clean_name(orig_name: str):
 
mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
 
mod_name = mod_name[1] if "class" in mod_name[0] else mod_name[0]
 
return mod_name[:-1] if mod_name[-1] == ">" else mod_name
 
Loading