diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index 0b9e8ec56592901d9feba15eb50b6b21a0c53560..a64ce917393dda9605e81cbef1dfd60fc57beaa5 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -84,7 +84,7 @@ class ModelSetup(RunEnvironment): # load weights if no training shall be performed if not self._train_model and not self._create_new_model: - self.load_weights() + self.load_model() # create checkpoint self._set_callbacks() @@ -131,13 +131,13 @@ class ModelSetup(RunEnvironment): save_best_only=True, mode='auto') self.data_store.set("callbacks", callbacks, self.scope) - def load_weights(self): - """Try to load weights from existing model or skip if not possible.""" + def load_model(self): + """Try to load model from disk or skip if not possible.""" try: - self.model.load_weights(self.model_name) - logging.info(f"reload weights from model {self.model_name} ...") + self.model = keras.models.load_model(self.model_name) + logging.info(f"reload model {self.model_name} from disk ...") except OSError: - logging.info('no weights to reload...') + logging.info('no local model to load...') def build_model(self): """Build model using input and output shapes from data store.""" diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 0696c2e7b8daa75925cf16096e183de94c21fe85..64b323fa74b3770698d40364ee7defae88a01b4c 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -189,7 +189,7 @@ class Training(RunEnvironment): """ logging.debug(f"load best model: {name}") try: - self.model.load_weights(name) + self.model = keras.models.load_model(name) logging.info('reload weights...') except OSError: logging.info('no weights to reload...') diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py index f1b210e1c7429c96658238ac21d96b7843053da7..46a5ba92b78cd4711e8f48c8df8e3721cf5d52cc 100644 --- a/test/test_run_modules/test_training.py +++ b/test/test_run_modules/test_training.py @@ -308,9 +308,58 @@ class TestTraining: init_without_run.create_monitoring_plots(history, learning_rate) assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2 - def test_resume_training(self, ready_to_run): - with copy.copy(ready_to_run) as pre_run: - assert pre_run._run() is None # rune once to create model - ready_to_run.epochs = 4 # continue train up to epoch 4 - assert ready_to_run._run() is None + def test_resume_training(self, ready_to_run, path: str, model: keras.Model, model_path, + batch_path, data_collection): + with ready_to_run as run_obj: + assert run_obj._run() is None # rune once to create model + + # init new object + obj = object.__new__(Training) + super(Training, obj).__init__() + obj.model = model + obj.train_set = None + obj.val_set = None + obj.test_set = None + obj.batch_size = 256 + obj.epochs = 4 + + clbk = CallbackHandler() + hist = HistoryAdvanced() + epo_timing = EpoTimingCallback() + clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist") + lr = LearningRateDecay() + clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr") + clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing") + clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss', + save_best_only=True) + obj.callbacks = clbk + obj.lr_sc = lr + obj.hist = hist + obj.experiment_name = "TestExperiment" + obj.data_store.set("data_collection", data_collection, "general.train") + obj.data_store.set("data_collection", data_collection, "general.val") + obj.data_store.set("data_collection", data_collection, "general.test") + obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) + if not os.path.exists(path): + os.makedirs(path) + obj.data_store.set("experiment_path", path, "general") + os.makedirs(batch_path, exist_ok=True) + obj.data_store.set("batch_path", batch_path, "general") + os.makedirs(model_path, exist_ok=True) + obj.data_store.set("model_path", model_path, "general") + obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model") + obj.data_store.set("experiment_name", "TestExperiment", "general") + + path_plot = os.path.join(path, "plots") + os.makedirs(path_plot, exist_ok=True) + obj.data_store.set("plot_path", path_plot, "general") + obj._train_model = True + obj._create_new_model = False + + + assert obj._run() is None + assert 1 == 1 + assert 1 == 1 + +