Skip to content
Snippets Groups Projects
Commit e2fa9410 authored by leufen1's avatar leufen1
Browse files

better model settings tracking

parent 48ddb38d
No related branches found
No related tags found
5 merge requests!430update recent developments,!413update release branch,!412Resolve "release v2.0.0",!406Lukas issue368 feat prepare cnn class for filter benchmarking,!403Resolve "prepare CNN class for filter benchmarking"
Pipeline #93914 passed
......@@ -68,6 +68,7 @@ class CNNfromConfig(AbstractModelClass):
self.activation_output_name = activation_output
self.kwargs = kwargs
self.optimizer = self._set_optimizer(optimizer, **kwargs)
self._layer_save = []
# apply to model
self.set_model()
......@@ -84,6 +85,7 @@ class CNNfromConfig(AbstractModelClass):
x_in = layer(**layer_kwargs)(x_in)
if follow_up_layer is not None:
x_in = follow_up_layer()(x_in)
self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer})
x_in = keras.layers.Dense(self._output_shape)(x_in)
out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
......
......@@ -172,6 +172,7 @@ class ModelSetup(RunEnvironment):
def report_model(self):
# report model settings
_f = self._clean_name
model_settings = self.model.get_settings()
model_settings.update(self.model.compile_options)
model_settings.update(self.model.optimizer.get_config())
......@@ -180,9 +181,12 @@ class ModelSetup(RunEnvironment):
if v is None:
continue
if isinstance(v, list):
v = ",".join(self._clean_name(str(u)) for u in v)
if isinstance(v[0], dict):
v = ["{" + vi + "}" for vi in [",".join(f"{_f(str(uk))}:{_f(str(uv))}" for uk, uv in d.items()) for d in v]]
else:
v = ",".join(_f(str(u)) for u in v)
if "<" in str(v):
v = self._clean_name(str(v))
v = _f(str(v))
df.loc[k] = str(v)
df.loc["count params"] = str(self.model.count_params())
df.sort_index(inplace=True)
......@@ -203,5 +207,7 @@ class ModelSetup(RunEnvironment):
def _clean_name(orig_name: str):
mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
mod_name = mod_name[1] if any(map(lambda x: x in mod_name[0], ["class", "function", "method"])) else mod_name
mod_name = mod_name[0] if len(mod_name) == 1 else mod_name
return mod_name[:-1] if mod_name[-1] == ">" else "".join(mod_name)
mod_name = mod_name[0].split(".")[-1] if any(
map(lambda x: x in mod_name[0], ["tensorflow", "keras"])) else mod_name
mod_name = mod_name[:-1] if mod_name[-1] == ">" else "".join(mod_name)
return mod_name.split(".")[-1] if any(map(lambda x: x in mod_name, ["tensorflow", "keras"])) else mod_name
......@@ -126,6 +126,14 @@ class TestModelSetup:
def test_init(self):
pass
def test_clean_name(self, setup):
in_str = "<tensorflow.python.keras.initializers.initializers_v2.HeNormal object at 0x7fecfa0da9b0>"
assert setup._clean_name(in_str) == "HeNormal"
in_str = "<class 'tensorflow.python.keras.layers.convolutional.Conv2D'>"
assert setup._clean_name(in_str) == "Conv2D"
in_str = "default"
assert setup._clean_name(in_str) == "default"
class DummyData:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment