Skip to content
Snippets Groups Projects
Commit c3be681e authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'lukas_issue421_feat_calc-number-of-samples-only-if-needed-by-model' into 'develop'

Resolve "Calc number of samples only if needed by model"

See merge request !476
parents 0aa25bbb d63183b9
No related branches found
No related tags found
3 merge requests!500Develop,!499Resolve "release v2.3.0",!476Resolve "Calc number of samples only if needed by model"
Pipeline #110252 passed
......@@ -74,9 +74,6 @@ class ModelSetup(RunEnvironment):
# set channels depending on inputs
self._set_shapes()
# set number of training samples (total)
self._set_num_of_training_samples()
# build model graph using settings from my_model_settings()
self.build_model()
......@@ -109,15 +106,11 @@ class ModelSetup(RunEnvironment):
def _set_num_of_training_samples(self):
""" Set number of training samples - needed for example for Bayesian NNs"""
samples = 0
for s in self.data_store.get("data_collection", "train"):
if isinstance(s.get_Y(), list):
s_sam = s.get_Y()[0].shape[0]
elif isinstance(s.get_Y(), tuple):
s_sam = s.get_Y().shape[0]
else:
s_sam = np.nan
samples += s_sam
self.num_of_training_samples = samples
upsampling = self.data_store.create_args_dict(["upsampling"], "train")
for data in self.data_store.get("data_collection", "train"):
length = data.__len__(**upsampling)
samples += length
return samples
def compile_model(self):
"""
......@@ -179,10 +172,11 @@ class ModelSetup(RunEnvironment):
model = self.data_store.get("model_class")
args_list = model.requirements()
if "num_of_training_samples" in args_list:
self.data_store.set("num_of_training_samples", self.num_of_training_samples, scope=self.scope)
logging.info(f"Store number of training samples ({self.num_of_training_samples}) in data_store: "
f"self.data_store.set('num_of_training_samples', {self.num_of_training_samples}, scope='{self.scope}')")
num_of_training_samples = self._set_num_of_training_samples()
self.data_store.set("num_of_training_samples", num_of_training_samples, scope=self.scope)
logging.info(f"Store number of training samples ({num_of_training_samples}) in data_store: "
f"self.data_store.set('num_of_training_samples', {num_of_training_samples}, scope="
f"'{self.scope}')")
args = self.data_store.create_args_dict(args_list, self.scope)
self.model = model(**args)
self.get_model_settings()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment