diff --git a/run.py b/run.py index c06bf6480aa17918a1e8131defd3b3a5233ffb44..9f38fdca9c51cbed332725ce8e120e1493551b93 100644 --- a/run.py +++ b/run.py @@ -24,7 +24,7 @@ def main(parser_args): Training() - # PostProcessing() + PostProcessing() if __name__ == "__main__": diff --git a/src/run_modules/training.py b/src/run_modules/training.py index d1962605dc0a0eb3f6d0b0104f80e73a05134afc..195ae28a395e1c888a10a6c2a4a05f963a84d0bc 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -86,31 +86,28 @@ class Training(RunEnvironment): verbose=2, validation_data=self.val_set.distribute_on_batches(), validation_steps=len(self.val_set), - # callbacks=self.callbacks) callbacks=[self.lr_sc, self.hist, self.checkpoint]) else: lr_filepath = self.checkpoint.callbacks[0]["path"] hist_filepath = self.checkpoint.callbacks[1]["path"] - lr_callbacks = pickle.load(open(lr_filepath, "rb")) - hist_callbacks = pickle.load(open(hist_filepath, "rb")) - self.lr_sc = lr_callbacks - self.hist = hist_callbacks + self.lr_sc = pickle.load(open(lr_filepath, "rb")) + self.hist = pickle.load(open(hist_filepath, "rb")) self.model = keras.models.load_model(self.checkpoint.filepath) - initial_epoch = max(hist_callbacks.epoch) + 1 + initial_epoch = max(self.hist.epoch) + 1 callbacks = [{"callback": self.lr_sc, "path": lr_filepath}, {"callback": self.hist, "path": hist_filepath}] self.checkpoint.update_callbacks(callbacks) - self.checkpoint.update_best(hist_callbacks) + self.checkpoint.update_best(self.hist) _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(), - steps_per_epoch=len(self.train_set), - epochs=self.epochs, - verbose=2, - validation_data=self.val_set.distribute_on_batches(), - validation_steps=len(self.val_set), - callbacks=[self.lr_sc, self.hist, self.checkpoint], - initial_epoch=initial_epoch) + steps_per_epoch=len(self.train_set), + epochs=self.epochs, + verbose=2, + validation_data=self.val_set.distribute_on_batches(), + validation_steps=len(self.val_set), + callbacks=[self.lr_sc, self.hist, self.checkpoint], + initial_epoch=initial_epoch) history = self.hist - self.save_callbacks(history) + self.save_callbacks_as_json(history) self.load_best_model(self.checkpoint.filepath) self.create_monitoring_plots(history, self.lr_sc) @@ -137,7 +134,7 @@ class Training(RunEnvironment): except OSError: logging.info('no weights to reload...') - def save_callbacks(self, history: keras.callbacks.History) -> None: + def save_callbacks_as_json(self, history: keras.callbacks.History) -> None: """ Save callbacks (history, learning rate) of training. * history.history -> history.json @@ -150,7 +147,6 @@ class Training(RunEnvironment): json.dump(history.history, f) with open(os.path.join(path, "history_lr.json"), "w") as f: json.dump(self.lr_sc.lr, f) - # json.dump(self.callbacks["learning_rate"].lr, f) def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None: """ diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index 4631fe5ae4a7a93dcba05dcb33baa9437e4cbd20..580e7925e5bd48e6fd6aa2226a88e5af617ea3cb 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -189,21 +189,21 @@ class TestTraining: assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload...")) def test_save_callbacks_history_created(self, init_without_run, history, path): - init_without_run.save_callbacks(history) + init_without_run.save_callbacks_as_json(history) assert "history.json" in os.listdir(path) def test_save_callbacks_lr_created(self, init_with_lr, history, path): - init_with_lr.save_callbacks(history) + init_with_lr.save_callbacks_as_json(history) assert "history_lr.json" in os.listdir(path) def test_save_callbacks_inspect_history(self, init_without_run, history, path): - init_without_run.save_callbacks(history) + init_without_run.save_callbacks_as_json(history) with open(os.path.join(path, "history.json")) as jfile: hist = json.load(jfile) assert hist == history.history def test_save_callbacks_inspect_lr(self, init_with_lr, history, path): - init_with_lr.save_callbacks(history) + init_with_lr.save_callbacks_as_json(history) with open(os.path.join(path, "history_lr.json")) as jfile: lr = json.load(jfile) assert lr == init_with_lr.lr_sc.lr