From 01dc6fb2e6c26bbb03ee2c3d1827c77f24109743 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Fri, 26 Mar 2021 12:45:38 +0100
Subject: [PATCH] log more model information during model setup stage

---
 HPC_setup/requirements_HDFML_additionals.txt  |  1 +
 HPC_setup/requirements_JUWELS_additionals.txt |  1 +
 mlair/run_modules/model_setup.py              | 26 ++++++++++++-------
 requirements.txt                              |  1 +
 requirements_gpu.txt                          |  1 +
 5 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/HPC_setup/requirements_HDFML_additionals.txt b/HPC_setup/requirements_HDFML_additionals.txt
index 12e09ccd..7d6163a6 100644
--- a/HPC_setup/requirements_HDFML_additionals.txt
+++ b/HPC_setup/requirements_HDFML_additionals.txt
@@ -9,6 +9,7 @@ chardet==4.0.0
 coverage==5.4
 cycler==0.10.0
 dask==2021.2.0
+dill==0.3.3
 fsspec==0.8.5
 gast==0.4.0
 grpcio==1.35.0
diff --git a/HPC_setup/requirements_JUWELS_additionals.txt b/HPC_setup/requirements_JUWELS_additionals.txt
index 12e09ccd..7d6163a6 100644
--- a/HPC_setup/requirements_JUWELS_additionals.txt
+++ b/HPC_setup/requirements_JUWELS_additionals.txt
@@ -9,6 +9,7 @@ chardet==4.0.0
 coverage==5.4
 cycler==0.10.0
 dask==2021.2.0
+dill==0.3.3
 fsspec==0.8.5
 gast==0.4.0
 grpcio==1.35.0
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 5dd73d50..8fae430f 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -6,6 +6,7 @@ __date__ = '2019-12-02'
 import logging
 import os
 import re
+from dill.source import getsource
 
 import keras
 import pandas as pd
@@ -57,12 +58,12 @@ class ModelSetup(RunEnvironment):
         super().__init__()
         self.model = None
         exp_name = self.data_store.get("experiment_name")
-        path = self.data_store.get("model_path")
+        self.path = self.data_store.get("model_path")
         self.scope = "model"
-        self.path = os.path.join(path, f"{exp_name}_%s")
-        self.model_name = self.path % "%s.h5"
-        self.checkpoint_name = self.path % "model-best.h5"
-        self.callbacks_name = self.path % "model-best-callbacks-%s.pickle"
+        path = os.path.join(self.path, f"{exp_name}_%s")
+        self.model_name = path % "%s.h5"
+        self.checkpoint_name = path % "model-best.h5"
+        self.callbacks_name = path % "model-best-callbacks-%s.pickle"
         self._train_model = self.data_store.get("train_model")
         self._create_new_model = self.data_store.get("create_new_model")
         self._run()
@@ -167,6 +168,7 @@ class ModelSetup(RunEnvironment):
             keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
 
     def report_model(self):
+        # report model settings
         model_settings = self.model.get_settings()
         model_settings.update(self.model.compile_options)
         model_settings.update(self.model.optimizer.get_config())
@@ -179,17 +181,23 @@ class ModelSetup(RunEnvironment):
             if "<" in str(v):
                 v = self._clean_name(str(v))
             df.loc[k] = str(v)
+        df.loc["count params"] = str(self.model.count_params())
         df.sort_index(inplace=True)
         column_format = "ll"
         path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
         path_config.check_path_and_create(path)
-        df.to_latex(os.path.join(path, "model_settings.tex"), na_rep='---', column_format=column_format)
-        df.to_markdown(open(os.path.join(path, "model_settings.md"), mode="w", encoding='utf-8'),
-                       tablefmt="github")
+        for p in [path, self.path]:  # log to `latex_report` and `model`
+            df.to_latex(os.path.join(p, "model_settings.tex"), na_rep='---', column_format=column_format)
+            df.to_markdown(open(os.path.join(p, "model_settings.md"), mode="w", encoding='utf-8'), tablefmt="github")
+        # report model summary to file
+        with open(os.path.join(self.path, "model_summary.txt"), "w") as fh:
+            self.model.summary(print_fn=lambda x: fh.write(x + "\n"))
+        # print model code to file
+        with open(os.path.join(self.path, "model_code.txt"), "w") as fh:
+            fh.write(getsource(self.data_store.get("model_class")))
 
     @staticmethod
     def _clean_name(orig_name: str):
         mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
         mod_name = mod_name[1] if any(map(lambda x: x in mod_name[0], ["class", "function", "method"])) else mod_name[0]
         return mod_name[:-1] if mod_name[-1] == ">" else mod_name
-
diff --git a/requirements.txt b/requirements.txt
index b0a6e7f5..af742fde 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,6 +9,7 @@ chardet==4.0.0
 coverage==5.4
 cycler==0.10.0
 dask==2021.2.0
+dill==0.3.3
 fsspec==0.8.5
 gast==0.4.0
 grpcio==1.35.0
diff --git a/requirements_gpu.txt b/requirements_gpu.txt
index 35fe0d5e..7dd443a4 100644
--- a/requirements_gpu.txt
+++ b/requirements_gpu.txt
@@ -9,6 +9,7 @@ chardet==4.0.0
 coverage==5.4
 cycler==0.10.0
 dask==2021.2.0
+dill==0.3.3
 fsspec==0.8.5
 gast==0.4.0
 grpcio==1.35.0
-- 
GitLab