Skip to content
Snippets Groups Projects

Resolve "release v2.3.0"

Merged Ghost User requested to merge release_v2.3.0 into master
1 file
+ 10
16
Compare changes
  • Side-by-side
  • Inline
@@ -74,9 +74,6 @@ class ModelSetup(RunEnvironment):
@@ -74,9 +74,6 @@ class ModelSetup(RunEnvironment):
# set channels depending on inputs
# set channels depending on inputs
self._set_shapes()
self._set_shapes()
# set number of training samples (total)
self._set_num_of_training_samples()
# build model graph using settings from my_model_settings()
# build model graph using settings from my_model_settings()
self.build_model()
self.build_model()
@@ -109,15 +106,11 @@ class ModelSetup(RunEnvironment):
@@ -109,15 +106,11 @@ class ModelSetup(RunEnvironment):
def _set_num_of_training_samples(self):
def _set_num_of_training_samples(self):
""" Set number of training samples - needed for example for Bayesian NNs"""
""" Set number of training samples - needed for example for Bayesian NNs"""
samples = 0
samples = 0
for s in self.data_store.get("data_collection", "train"):
upsampling = self.data_store.create_args_dict(["upsampling"], "train")
if isinstance(s.get_Y(), list):
for data in self.data_store.get("data_collection", "train"):
s_sam = s.get_Y()[0].shape[0]
length = data.__len__(**upsampling)
elif isinstance(s.get_Y(), tuple):
samples += length
s_sam = s.get_Y().shape[0]
return samples
else:
s_sam = np.nan
samples += s_sam
self.num_of_training_samples = samples
def compile_model(self):
def compile_model(self):
"""
"""
@@ -179,10 +172,11 @@ class ModelSetup(RunEnvironment):
@@ -179,10 +172,11 @@ class ModelSetup(RunEnvironment):
model = self.data_store.get("model_class")
model = self.data_store.get("model_class")
args_list = model.requirements()
args_list = model.requirements()
if "num_of_training_samples" in args_list:
if "num_of_training_samples" in args_list:
self.data_store.set("num_of_training_samples", self.num_of_training_samples, scope=self.scope)
num_of_training_samples = self._set_num_of_training_samples()
logging.info(f"Store number of training samples ({self.num_of_training_samples}) in data_store: "
self.data_store.set("num_of_training_samples", num_of_training_samples, scope=self.scope)
f"self.data_store.set('num_of_training_samples', {self.num_of_training_samples}, scope='{self.scope}')")
logging.info(f"Store number of training samples ({num_of_training_samples}) in data_store: "
f"self.data_store.set('num_of_training_samples', {num_of_training_samples}, scope="
 
f"'{self.scope}')")
args = self.data_store.create_args_dict(args_list, self.scope)
args = self.data_store.create_args_dict(args_list, self.scope)
self.model = model(**args)
self.model = model(**args)
self.get_model_settings()
self.get_model_settings()
Loading