Skip to content
Snippets Groups Projects
Commit 4b82d37b authored by leufen1's avatar leufen1
Browse files

test if pipeline finishes now successfully

parent 91ff847d
No related branches found
No related tags found
5 merge requests!413update release branch,!412Resolve "release v2.0.0",!361name of pdf starts now with feature_importance, there is now also another...,!350Resolve "upgrade code to TensorFlow V2",!335Resolve "upgrade code to TensorFlow V2"
Pipeline #82202 failed
...@@ -84,7 +84,7 @@ class ModelSetup(RunEnvironment): ...@@ -84,7 +84,7 @@ class ModelSetup(RunEnvironment):
# load weights if no training shall be performed # load weights if no training shall be performed
if not self._train_model and not self._create_new_model: if not self._train_model and not self._create_new_model:
self.load_weights() self.load_model()
# create checkpoint # create checkpoint
self._set_callbacks() self._set_callbacks()
...@@ -131,13 +131,13 @@ class ModelSetup(RunEnvironment): ...@@ -131,13 +131,13 @@ class ModelSetup(RunEnvironment):
save_best_only=True, mode='auto') save_best_only=True, mode='auto')
self.data_store.set("callbacks", callbacks, self.scope) self.data_store.set("callbacks", callbacks, self.scope)
def load_weights(self): def load_model(self):
"""Try to load weights from existing model or skip if not possible.""" """Try to load model from disk or skip if not possible."""
try: try:
self.model.load_weights(self.model_name) self.model = keras.models.load_model(self.model_name)
logging.info(f"reload weights from model {self.model_name} ...") logging.info(f"reload model {self.model_name} from disk ...")
except OSError: except OSError:
logging.info('no weights to reload...') logging.info('no local model to load...')
def build_model(self): def build_model(self):
"""Build model using input and output shapes from data store.""" """Build model using input and output shapes from data store."""
......
...@@ -189,7 +189,7 @@ class Training(RunEnvironment): ...@@ -189,7 +189,7 @@ class Training(RunEnvironment):
""" """
logging.debug(f"load best model: {name}") logging.debug(f"load best model: {name}")
try: try:
self.model.load_weights(name) self.model = keras.models.load_model(name)
logging.info('reload weights...') logging.info('reload weights...')
except OSError: except OSError:
logging.info('no weights to reload...') logging.info('no weights to reload...')
......
...@@ -308,9 +308,58 @@ class TestTraining: ...@@ -308,9 +308,58 @@ class TestTraining:
init_without_run.create_monitoring_plots(history, learning_rate) init_without_run.create_monitoring_plots(history, learning_rate)
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2 assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
def test_resume_training(self, ready_to_run): def test_resume_training(self, ready_to_run, path: str, model: keras.Model, model_path,
with copy.copy(ready_to_run) as pre_run: batch_path, data_collection):
assert pre_run._run() is None # rune once to create model with ready_to_run as run_obj:
ready_to_run.epochs = 4 # continue train up to epoch 4 assert run_obj._run() is None # rune once to create model
assert ready_to_run._run() is None
# init new object
obj = object.__new__(Training)
super(Training, obj).__init__()
obj.model = model
obj.train_set = None
obj.val_set = None
obj.test_set = None
obj.batch_size = 256
obj.epochs = 4
clbk = CallbackHandler()
hist = HistoryAdvanced()
epo_timing = EpoTimingCallback()
clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
lr = LearningRateDecay()
clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
save_best_only=True)
obj.callbacks = clbk
obj.lr_sc = lr
obj.hist = hist
obj.experiment_name = "TestExperiment"
obj.data_store.set("data_collection", data_collection, "general.train")
obj.data_store.set("data_collection", data_collection, "general.val")
obj.data_store.set("data_collection", data_collection, "general.test")
obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
if not os.path.exists(path):
os.makedirs(path)
obj.data_store.set("experiment_path", path, "general")
os.makedirs(batch_path, exist_ok=True)
obj.data_store.set("batch_path", batch_path, "general")
os.makedirs(model_path, exist_ok=True)
obj.data_store.set("model_path", model_path, "general")
obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
obj.data_store.set("experiment_name", "TestExperiment", "general")
path_plot = os.path.join(path, "plots")
os.makedirs(path_plot, exist_ok=True)
obj.data_store.set("plot_path", path_plot, "general")
obj._train_model = True
obj._create_new_model = False
assert obj._run() is None
assert 1 == 1
assert 1 == 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment