From e2fa941062bb578e00b511f45a7291a88a31ada3 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 3 Mar 2022 16:35:36 +0100
Subject: [PATCH] better model settings tracking

---
 mlair/model_modules/convolutional_networks.py |  2 ++
 mlair/run_modules/model_setup.py              | 14 ++++++++++----
 test/test_run_modules/test_model_setup.py     | 10 +++++++++-
 3 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/mlair/model_modules/convolutional_networks.py b/mlair/model_modules/convolutional_networks.py
index 486441db..7f686734 100644
--- a/mlair/model_modules/convolutional_networks.py
+++ b/mlair/model_modules/convolutional_networks.py
@@ -68,6 +68,7 @@ class CNNfromConfig(AbstractModelClass):
         self.activation_output_name = activation_output
         self.kwargs = kwargs
         self.optimizer = self._set_optimizer(optimizer, **kwargs)
+        self._layer_save = []
 
         # apply to model
         self.set_model()
@@ -84,6 +85,7 @@ class CNNfromConfig(AbstractModelClass):
             x_in = layer(**layer_kwargs)(x_in)
             if follow_up_layer is not None:
                 x_in = follow_up_layer()(x_in)
+            self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer})
 
         x_in = keras.layers.Dense(self._output_shape)(x_in)
         out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 2d6d8396..4e9f8fa4 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -172,6 +172,7 @@ class ModelSetup(RunEnvironment):
 
     def report_model(self):
         # report model settings
+        _f = self._clean_name
         model_settings = self.model.get_settings()
         model_settings.update(self.model.compile_options)
         model_settings.update(self.model.optimizer.get_config())
@@ -180,9 +181,12 @@ class ModelSetup(RunEnvironment):
             if v is None:
                 continue
             if isinstance(v, list):
-                v = ",".join(self._clean_name(str(u)) for u in v)
+                if isinstance(v[0], dict):
+                    v = ["{" + vi + "}" for vi in [",".join(f"{_f(str(uk))}:{_f(str(uv))}" for uk, uv in d.items()) for d in v]]
+                else:
+                    v = ",".join(_f(str(u)) for u in v)
             if "<" in str(v):
-                v = self._clean_name(str(v))
+                v = _f(str(v))
             df.loc[k] = str(v)
         df.loc["count params"] = str(self.model.count_params())
         df.sort_index(inplace=True)
@@ -203,5 +207,7 @@ class ModelSetup(RunEnvironment):
     def _clean_name(orig_name: str):
         mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
         mod_name = mod_name[1] if any(map(lambda x: x in mod_name[0], ["class", "function", "method"])) else mod_name
-        mod_name = mod_name[0] if len(mod_name) == 1 else mod_name
-        return mod_name[:-1] if mod_name[-1] == ">" else "".join(mod_name)
+        mod_name = mod_name[0].split(".")[-1] if any(
+            map(lambda x: x in mod_name[0], ["tensorflow", "keras"])) else mod_name
+        mod_name = mod_name[:-1] if mod_name[-1] == ">" else "".join(mod_name)
+        return mod_name.split(".")[-1] if any(map(lambda x: x in mod_name, ["tensorflow", "keras"])) else mod_name
diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py
index 7cefd0e5..60b37207 100644
--- a/test/test_run_modules/test_model_setup.py
+++ b/test/test_run_modules/test_model_setup.py
@@ -126,6 +126,14 @@ class TestModelSetup:
     def test_init(self):
         pass
 
+    def test_clean_name(self, setup):
+        in_str = "<tensorflow.python.keras.initializers.initializers_v2.HeNormal object at 0x7fecfa0da9b0>"
+        assert setup._clean_name(in_str) == "HeNormal"
+        in_str = "<class 'tensorflow.python.keras.layers.convolutional.Conv2D'>"
+        assert setup._clean_name(in_str) == "Conv2D"
+        in_str = "default"
+        assert setup._clean_name(in_str) == "default"
+
 
 class DummyData:
 
@@ -141,4 +149,4 @@ class DummyData:
     def get_Y(self, upsampling=False, as_numpy=True):
         Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5))  # samples, window
         Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 3))  # samples, window
-        return [Y1, Y2]
\ No newline at end of file
+        return [Y1, Y2]
-- 
GitLab