From 53966bcc342c164dc366e661fcb991a080edf5c8 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 23 Aug 2022 16:21:43 +0200
Subject: [PATCH] improved speed of counting training samples, now only calculated
 if required, /close #421

---
 mlair/run_modules/model_setup.py | 26 ++++++++++----------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index b51a3f9c..efeff062 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -74,9 +74,6 @@ class ModelSetup(RunEnvironment):
         # set channels depending on inputs
         self._set_shapes()
 
-        # set number of training samples (total)
-        self._set_num_of_training_samples()
-
         # build model graph using settings from my_model_settings()
         self.build_model()
 
@@ -109,15 +106,11 @@ class ModelSetup(RunEnvironment):
     def _set_num_of_training_samples(self):
         """ Set number of training samples - needed for example for Bayesian NNs"""
         samples = 0
-        for s in self.data_store.get("data_collection", "train"):
-            if isinstance(s.get_Y(), list):
-                s_sam = s.get_Y()[0].shape[0]
-            elif isinstance(s.get_Y(), tuple):
-                s_sam = s.get_Y().shape[0]
-            else:
-                s_sam = np.nan
-            samples += s_sam
-        self.num_of_training_samples = samples
+        upsampling = self.data_store.create_args_dict(["upsampling"], "train")
+        for data in self.data_store.get("data_collection", "train"):
+            length = data.__len__(**upsampling)
+            samples += length
+        return samples
 
     def compile_model(self):
         """
@@ -179,10 +172,11 @@ class ModelSetup(RunEnvironment):
         model = self.data_store.get("model_class")
         args_list = model.requirements()
         if "num_of_training_samples" in args_list:
-            self.data_store.set("num_of_training_samples", self.num_of_training_samples, scope=self.scope)
-            logging.info(f"Store number of training samples ({self.num_of_training_samples}) in data_store: "
-                         f"self.data_store.set('num_of_training_samples', {self.num_of_training_samples}, scope='{self.scope}')")
-
+            num_of_training_samples = self._set_num_of_training_samples()
+            self.data_store.set("num_of_training_samples", num_of_training_samples, scope=self.scope)
+            logging.info(f"Store number of training samples ({num_of_training_samples}) in data_store: "
+                         f"self.data_store.set('num_of_training_samples', {num_of_training_samples}, scope="
+                         f"'{self.scope}')")
         args = self.data_store.create_args_dict(args_list, self.scope)
         self.model = model(**args)
         self.get_model_settings()
-- 
GitLab