From 53966bcc342c164dc366e661fcb991a080edf5c8 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 23 Aug 2022 16:21:43 +0200
Subject: [PATCH] improved speed of count training samples, only calculated if required, /close #421

---
 mlair/run_modules/model_setup.py | 26 ++++++++++----------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index b51a3f9c..efeff062 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -74,9 +74,6 @@ class ModelSetup(RunEnvironment):
         # set channels depending on inputs
         self._set_shapes()
 
-        # set number of training samples (total)
-        self._set_num_of_training_samples()
-
         # build model graph using settings from my_model_settings()
         self.build_model()
 
@@ -109,15 +106,11 @@ class ModelSetup(RunEnvironment):
     def _set_num_of_training_samples(self):
         """ Set number of training samples - needed for example for Bayesian NNs"""
         samples = 0
-        for s in self.data_store.get("data_collection", "train"):
-            if isinstance(s.get_Y(), list):
-                s_sam = s.get_Y()[0].shape[0]
-            elif isinstance(s.get_Y(), tuple):
-                s_sam = s.get_Y().shape[0]
-            else:
-                s_sam = np.nan
-            samples += s_sam
-        self.num_of_training_samples = samples
+        upsampling = self.data_store.create_args_dict(["upsampling"], "train")
+        for data in self.data_store.get("data_collection", "train"):
+            length = data.__len__(**upsampling)
+            samples += length
+        return samples
 
     def compile_model(self):
         """
@@ -179,10 +172,11 @@ class ModelSetup(RunEnvironment):
         model = self.data_store.get("model_class")
         args_list = model.requirements()
         if "num_of_training_samples" in args_list:
-            self.data_store.set("num_of_training_samples", self.num_of_training_samples, scope=self.scope)
-            logging.info(f"Store number of training samples ({self.num_of_training_samples}) in data_store: "
-                         f"self.data_store.set('num_of_training_samples', {self.num_of_training_samples}, scope='{self.scope}')")
-
+            num_of_training_samples = self._set_num_of_training_samples()
+            self.data_store.set("num_of_training_samples", num_of_training_samples, scope=self.scope)
+            logging.info(f"Store number of training samples ({num_of_training_samples}) in data_store: "
+                         f"self.data_store.set('num_of_training_samples', {num_of_training_samples}, scope="
+                         f"'{self.scope}')")
         args = self.data_store.create_args_dict(args_list, self.scope)
         self.model = model(**args)
         self.get_model_settings()
-- 
GitLab
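
Note (not part of the patch): the change above makes the sample count lazy, so it is only computed when a model class lists "num_of_training_samples" in its requirements, and it derives the count from each data handler's __len__ (optionally including upsampled extremes) instead of materializing get_Y(). The following minimal, self-contained sketch illustrates that pattern; DummyDataHandler, count_training_samples, and build_model are hypothetical stand-ins, not the MLAir API.

# Minimal sketch of the lazy counting pattern introduced by the patch.
# All names here are illustrative; only the control flow mirrors the change:
# the count comes from cheap __len__ calls and is skipped entirely when the
# model does not require it.

class DummyDataHandler:
    """Stand-in for a data handler that knows its own length."""

    def __init__(self, n_samples: int, n_extremes: int = 0):
        self._n = n_samples
        self._n_extremes = n_extremes  # extra samples added by upsampling

    def __len__(self, upsampling: bool = False) -> int:
        # Cheap length lookup; no target arrays are loaded or concatenated.
        return self._n + (self._n_extremes if upsampling else 0)


def count_training_samples(collection, upsampling: bool = False) -> int:
    """Sum handler lengths, analogous to the patched _set_num_of_training_samples."""
    # len() cannot forward keyword arguments, hence the explicit __len__ call.
    return sum(d.__len__(upsampling=upsampling) for d in collection)


def build_model(model_requirements, collection, upsampling: bool = False) -> dict:
    """Only pay for the count if the model declares it as a requirement."""
    args = {}
    if "num_of_training_samples" in model_requirements:
        args["num_of_training_samples"] = count_training_samples(collection, upsampling)
    return args


if __name__ == "__main__":
    train = [DummyDataHandler(100, n_extremes=20), DummyDataHandler(50)]
    # Bayesian-style model that needs the count:
    print(build_model(["num_of_training_samples"], train, upsampling=True))
    # -> {'num_of_training_samples': 170}
    # Plain model: the count is never computed.
    print(build_model([], train))
    # -> {}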