From ee7569ab774ec8a698f0f29f78856677e2d08a61 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 14 Jul 2020 15:59:23 +0200
Subject: [PATCH] training reports now too

---
 mlair/run_modules/model_setup.py |  4 +++-
 mlair/run_modules/training.py    | 19 +++++++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 83efac73..c0890624 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -13,6 +13,7 @@ import tensorflow as tf
 
 from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
 from mlair.run_modules.run_environment import RunEnvironment
+from mlair.configuration import path_config
 
 
 class ModelSetup(RunEnvironment):
@@ -162,7 +163,8 @@ class ModelSetup(RunEnvironment):
             df.loc[k] = v
         df.sort_index(inplace=True)
         column_format = "ll"
-        path = self.data_store.get("model_path")
+        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+        path_config.check_path_and_create(path)
         df.to_latex(os.path.join(path, "model_settings.tex"), na_rep='---', column_format=column_format)
         df.to_markdown(open(os.path.join(path, "model_settings.md"), mode="w", encoding='utf-8'),
                        tablefmt="github")
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 0c516c57..23347a30 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -15,6 +15,7 @@ from mlair.data_handling import Distributor
 from mlair.model_modules.keras_extensions import CallbackHandler
 from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from mlair.run_modules.run_environment import RunEnvironment
+from mlair.configuration import path_config
 
 
 class Training(RunEnvironment):
@@ -82,6 +83,7 @@ class Training(RunEnvironment):
         if self._trainable:
             self.train()
             self.save_model()
+            self.report_training()
         else:
             logging.info("No training has started, because trainable parameter was false.")
 
@@ -228,3 +230,20 @@ class Training(RunEnvironment):
         # plot learning rate
         if lr_sc:
             PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
+
+    def report_training(self):
+        data = {"mini batches": len(self.train_set),
+                "upsampling extremes": self.train_set.upsampling,
+                "shuffling": self.train_set.do_data_permutation,
+                "created new model": self._create_new_model,
+                "epochs": self.epochs,
+                "batch size": self.batch_size}
+        import pandas as pd
+        df = pd.DataFrame.from_dict(data, orient="index", columns=["training setting"])
+        df.sort_index(inplace=True)
+        column_format = "ll"
+        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+        path_config.check_path_and_create(path)
+        df.to_latex(os.path.join(path, "training_settings.tex"), na_rep='---', column_format=column_format)
+        df.to_markdown(open(os.path.join(path, "training_settings.md"), mode="w", encoding='utf-8'),
+                       tablefmt="github")
\ No newline at end of file
-- 
GitLab