Skip to content
Snippets Groups Projects
Commit cba186aa authored by leufen1's avatar leufen1
Browse files

use table methods in training too,

parent e54ef541
No related branches found
No related tags found
6 merge requests!319add all changes of dev into release v1.4.0 branch,!318Resolve "release v1.4.0",!283Merge latest develop into falcos issue,!264Merge develop into felix_issue287_tech-wrf-datahandler-should-inherit-from-singlestationdatahandler,!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler",!257Resolve "store performance measures"
...@@ -10,13 +10,15 @@ from typing import Union ...@@ -10,13 +10,15 @@ from typing import Union
import keras import keras
from keras.callbacks import Callback, History from keras.callbacks import Callback, History
import psutil
import pandas as pd
from mlair.data_handler import KerasIterator from mlair.data_handler import KerasIterator
from mlair.model_modules.keras_extensions import CallbackHandler from mlair.model_modules.keras_extensions import CallbackHandler
from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
from mlair.run_modules.run_environment import RunEnvironment from mlair.run_modules.run_environment import RunEnvironment
from mlair.configuration import path_config from mlair.configuration import path_config
from mlair.helpers import to_list from mlair.helpers import to_list, tables
class Training(RunEnvironment): class Training(RunEnvironment):
...@@ -141,7 +143,8 @@ class Training(RunEnvironment): ...@@ -141,7 +143,8 @@ class Training(RunEnvironment):
verbose=2, verbose=2,
validation_data=self.val_set, validation_data=self.val_set,
validation_steps=len(self.val_set), validation_steps=len(self.val_set),
callbacks=self.callbacks.get_callbacks(as_dict=False)) callbacks=self.callbacks.get_callbacks(as_dict=False),
workers=psutil.cpu_count(logical=False))
else: else:
logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.") logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
self.callbacks.load_callbacks() self.callbacks.load_callbacks()
...@@ -156,7 +159,8 @@ class Training(RunEnvironment): ...@@ -156,7 +159,8 @@ class Training(RunEnvironment):
validation_data=self.val_set, validation_data=self.val_set,
validation_steps=len(self.val_set), validation_steps=len(self.val_set),
callbacks=self.callbacks.get_callbacks(as_dict=False), callbacks=self.callbacks.get_callbacks(as_dict=False),
initial_epoch=initial_epoch) initial_epoch=initial_epoch,
workers=psutil.cpu_count(logical=False))
history = hist history = hist
try: try:
lr = self.callbacks.get_callback_by_name("lr") lr = self.callbacks.get_callback_by_name("lr")
...@@ -233,22 +237,26 @@ class Training(RunEnvironment): ...@@ -233,22 +237,26 @@ class Training(RunEnvironment):
PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc) PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
def report_training(self): def report_training(self):
# create training summary
data = {"mini batches": len(self.train_set), data = {"mini batches": len(self.train_set),
"upsampling extremes": self.train_set.upsampling, "upsampling extremes": self.train_set.upsampling,
"shuffling": self.train_set.shuffle, "shuffling": self.train_set.shuffle,
"created new model": self._create_new_model, "created new model": self._create_new_model,
"epochs": self.epochs, "epochs": self.epochs,
"batch size": self.batch_size} "batch size": self.batch_size}
import pandas as pd
df = pd.DataFrame.from_dict(data, orient="index", columns=["training setting"]) df = pd.DataFrame.from_dict(data, orient="index", columns=["training setting"])
df.sort_index(inplace=True) df.sort_index(inplace=True)
column_format = "ll"
path = os.path.join(self.data_store.get("experiment_path"), "latex_report") path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
path_config.check_path_and_create(path) path_config.check_path_and_create(path)
df.to_latex(os.path.join(path, "training_settings.tex"), na_rep='---', column_format=column_format)
df.to_markdown(open(os.path.join(path, "training_settings.md"), mode="w", encoding='utf-8'),
tablefmt="github")
val_score = self.model.evaluate_generator(generator=self.val_set, use_multiprocessing=True, verbose=0, steps=1) # store as .tex and .md
for index, item in enumerate(to_list(val_score)): tables.save_to_tex(path, "training_settings.tex", column_format="ll", df=df)
logging.info(f"{self.model.metrics_names[index]} (val), {item}") tables.save_to_md(path, "training_settings.md", df=df)
# calculate val scores
val_score = self.model.evaluate_generator(generator=self.val_set, use_multiprocessing=True, verbose=0)
path = self.data_store.get("model_path")
with open(os.path.join(path, "val_scores.txt"), "a") as f:
for index, item in enumerate(to_list(val_score)):
logging.info(f"{self.model.metrics_names[index]} (val), {item}")
f.write(f"{self.model.metrics_names[index]}, {item}\n")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment