Skip to content
Snippets Groups Projects
Commit 53966bcc authored by leufen1's avatar leufen1
Browse files

improved speed of count training samples, only calculated if required, /close #421

parent c1dfc1b3
No related branches found
No related tags found
3 merge requests!500Develop,!499Resolve "release v2.3.0",!476Resolve "Calc number of samples only if needed by model"
Pipeline #109755 failed
...@@ -74,9 +74,6 @@ class ModelSetup(RunEnvironment): ...@@ -74,9 +74,6 @@ class ModelSetup(RunEnvironment):
# set channels depending on inputs # set channels depending on inputs
self._set_shapes() self._set_shapes()
# set number of training samples (total)
self._set_num_of_training_samples()
# build model graph using settings from my_model_settings() # build model graph using settings from my_model_settings()
self.build_model() self.build_model()
...@@ -109,15 +106,11 @@ class ModelSetup(RunEnvironment): ...@@ -109,15 +106,11 @@ class ModelSetup(RunEnvironment):
def _set_num_of_training_samples(self): def _set_num_of_training_samples(self):
""" Set number of training samples - needed for example for Bayesian NNs""" """ Set number of training samples - needed for example for Bayesian NNs"""
samples = 0 samples = 0
for s in self.data_store.get("data_collection", "train"): upsampling = self.data_store.create_args_dict(["upsampling"], "train")
if isinstance(s.get_Y(), list): for data in self.data_store.get("data_collection", "train"):
s_sam = s.get_Y()[0].shape[0] length = data.__len__(**upsampling)
elif isinstance(s.get_Y(), tuple): samples += length
s_sam = s.get_Y().shape[0] return samples
else:
s_sam = np.nan
samples += s_sam
self.num_of_training_samples = samples
def compile_model(self): def compile_model(self):
""" """
...@@ -179,10 +172,11 @@ class ModelSetup(RunEnvironment): ...@@ -179,10 +172,11 @@ class ModelSetup(RunEnvironment):
model = self.data_store.get("model_class") model = self.data_store.get("model_class")
args_list = model.requirements() args_list = model.requirements()
if "num_of_training_samples" in args_list: if "num_of_training_samples" in args_list:
self.data_store.set("num_of_training_samples", self.num_of_training_samples, scope=self.scope) num_of_training_samples = self._set_num_of_training_samples()
logging.info(f"Store number of training samples ({self.num_of_training_samples}) in data_store: " self.data_store.set("num_of_training_samples", num_of_training_samples, scope=self.scope)
f"self.data_store.set('num_of_training_samples', {self.num_of_training_samples}, scope='{self.scope}')") logging.info(f"Store number of training samples ({num_of_training_samples}) in data_store: "
f"self.data_store.set('num_of_training_samples', {num_of_training_samples}, scope="
f"'{self.scope}')")
args = self.data_store.create_args_dict(args_list, self.scope) args = self.data_store.create_args_dict(args_list, self.scope)
self.model = model(**args) self.model = model(**args)
self.get_model_settings() self.get_model_settings()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment