Skip to content
Snippets Groups Projects
Commit 01dc6fb2 authored by leufen1's avatar leufen1
Browse files

log more model information during model setup stage

parent 56f3657c
Branches
Tags
5 merge requests!319add all changes of dev into release v1.4.0 branch,!318Resolve "release v1.4.0",!300include cnn class,!271Resolve "create CNN model class",!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #64202 passed
...@@ -9,6 +9,7 @@ chardet==4.0.0 ...@@ -9,6 +9,7 @@ chardet==4.0.0
coverage==5.4 coverage==5.4
cycler==0.10.0 cycler==0.10.0
dask==2021.2.0 dask==2021.2.0
dill==0.3.3
fsspec==0.8.5 fsspec==0.8.5
gast==0.4.0 gast==0.4.0
grpcio==1.35.0 grpcio==1.35.0
......
...@@ -9,6 +9,7 @@ chardet==4.0.0 ...@@ -9,6 +9,7 @@ chardet==4.0.0
coverage==5.4 coverage==5.4
cycler==0.10.0 cycler==0.10.0
dask==2021.2.0 dask==2021.2.0
dill==0.3.3
fsspec==0.8.5 fsspec==0.8.5
gast==0.4.0 gast==0.4.0
grpcio==1.35.0 grpcio==1.35.0
......
...@@ -6,6 +6,7 @@ __date__ = '2019-12-02' ...@@ -6,6 +6,7 @@ __date__ = '2019-12-02'
import logging import logging
import os import os
import re import re
from dill.source import getsource
import keras import keras
import pandas as pd import pandas as pd
...@@ -57,12 +58,12 @@ class ModelSetup(RunEnvironment): ...@@ -57,12 +58,12 @@ class ModelSetup(RunEnvironment):
super().__init__() super().__init__()
self.model = None self.model = None
exp_name = self.data_store.get("experiment_name") exp_name = self.data_store.get("experiment_name")
path = self.data_store.get("model_path") self.path = self.data_store.get("model_path")
self.scope = "model" self.scope = "model"
self.path = os.path.join(path, f"{exp_name}_%s") path = os.path.join(self.path, f"{exp_name}_%s")
self.model_name = self.path % "%s.h5" self.model_name = path % "%s.h5"
self.checkpoint_name = self.path % "model-best.h5" self.checkpoint_name = path % "model-best.h5"
self.callbacks_name = self.path % "model-best-callbacks-%s.pickle" self.callbacks_name = path % "model-best-callbacks-%s.pickle"
self._train_model = self.data_store.get("train_model") self._train_model = self.data_store.get("train_model")
self._create_new_model = self.data_store.get("create_new_model") self._create_new_model = self.data_store.get("create_new_model")
self._run() self._run()
...@@ -167,6 +168,7 @@ class ModelSetup(RunEnvironment): ...@@ -167,6 +168,7 @@ class ModelSetup(RunEnvironment):
keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
def report_model(self): def report_model(self):
# report model settings
model_settings = self.model.get_settings() model_settings = self.model.get_settings()
model_settings.update(self.model.compile_options) model_settings.update(self.model.compile_options)
model_settings.update(self.model.optimizer.get_config()) model_settings.update(self.model.optimizer.get_config())
...@@ -179,17 +181,23 @@ class ModelSetup(RunEnvironment): ...@@ -179,17 +181,23 @@ class ModelSetup(RunEnvironment):
if "<" in str(v): if "<" in str(v):
v = self._clean_name(str(v)) v = self._clean_name(str(v))
df.loc[k] = str(v) df.loc[k] = str(v)
df.loc["count params"] = str(self.model.count_params())
df.sort_index(inplace=True) df.sort_index(inplace=True)
column_format = "ll" column_format = "ll"
path = os.path.join(self.data_store.get("experiment_path"), "latex_report") path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
path_config.check_path_and_create(path) path_config.check_path_and_create(path)
df.to_latex(os.path.join(path, "model_settings.tex"), na_rep='---', column_format=column_format) for p in [path, self.path]: # log to `latex_report` and `model`
df.to_markdown(open(os.path.join(path, "model_settings.md"), mode="w", encoding='utf-8'), df.to_latex(os.path.join(p, "model_settings.tex"), na_rep='---', column_format=column_format)
tablefmt="github") df.to_markdown(open(os.path.join(p, "model_settings.md"), mode="w", encoding='utf-8'), tablefmt="github")
# report model summary to file
with open(os.path.join(self.path, "model_summary.txt"), "w") as fh:
self.model.summary(print_fn=lambda x: fh.write(x + "\n"))
# print model code to file
with open(os.path.join(self.path, "model_code.txt"), "w") as fh:
fh.write(getsource(self.data_store.get("model_class")))
@staticmethod @staticmethod
def _clean_name(orig_name: str): def _clean_name(orig_name: str):
mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ") mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
mod_name = mod_name[1] if any(map(lambda x: x in mod_name[0], ["class", "function", "method"])) else mod_name[0] mod_name = mod_name[1] if any(map(lambda x: x in mod_name[0], ["class", "function", "method"])) else mod_name[0]
return mod_name[:-1] if mod_name[-1] == ">" else mod_name return mod_name[:-1] if mod_name[-1] == ">" else mod_name
...@@ -9,6 +9,7 @@ chardet==4.0.0 ...@@ -9,6 +9,7 @@ chardet==4.0.0
coverage==5.4 coverage==5.4
cycler==0.10.0 cycler==0.10.0
dask==2021.2.0 dask==2021.2.0
dill==0.3.3
fsspec==0.8.5 fsspec==0.8.5
gast==0.4.0 gast==0.4.0
grpcio==1.35.0 grpcio==1.35.0
......
...@@ -9,6 +9,7 @@ chardet==4.0.0 ...@@ -9,6 +9,7 @@ chardet==4.0.0
coverage==5.4 coverage==5.4
cycler==0.10.0 cycler==0.10.0
dask==2021.2.0 dask==2021.2.0
dill==0.3.3
fsspec==0.8.5 fsspec==0.8.5
gast==0.4.0 gast==0.4.0
grpcio==1.35.0 grpcio==1.35.0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment