diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 66697e07d68a01ffa203798c11417b8440d54214..e3945a542d60b09dc9855bd28be87cdba729ed72 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -7,14 +7,11 @@
 import os
 
 import keras
 import tensorflow as tf
 
-from keras import losses
-from src.helpers import l_p_loss
-from src.model_modules.flatten import flatten_tail
-from src.model_modules.inception_model import InceptionModelBase
-from src.model_modules.keras_extensions import HistoryAdvanced, ModelCheckpointAdvanced
+from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
 # from src.model_modules.model_class import MyBranchedModel as MyModel
-from src.model_modules.model_class import MyLittleModel as MyModel
+# from src.model_modules.model_class import MyLittleModel as MyModel
+from src.model_modules.model_class import MyTowerModel as MyModel
 from src.run_modules.run_environment import RunEnvironment
 
@@ -52,7 +49,7 @@ class ModelSetup(RunEnvironment):
         self.load_weights()
 
         # create checkpoint
-        self._set_checkpoint()
+        self._set_callbacks()
 
         # compile model
         self.compile_model()
@@ -67,19 +64,20 @@
         self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"])
         self.data_store.set("model", self.model, self.scope)
 
-    def _set_checkpoint(self):
+    def _set_callbacks(self):
         """
-        Must be run after all callback functions that shall be tracked during training have been created (currently this
-        affects the learning rate decay and the advanced history [actually created in this method]).
+        Set all callbacks for the training phase. Add all callbacks with the .add_callback method. Finally, the
+        advanced model checkpoint is added.
         """
         lr = self.data_store.get("lr_decay", scope="general.model")
         hist = HistoryAdvanced()
         self.data_store.set("hist", hist, scope="general.model")
-        callbacks = [{"callback": lr, "path": self.callbacks_name % "lr"},
-                     {"callback": hist, "path": self.callbacks_name % "hist"}]
-        checkpoint = ModelCheckpointAdvanced(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
-                                             save_best_only=True, mode='auto', callbacks=callbacks)
-        self.data_store.set("checkpoint", checkpoint, self.scope)
+        callbacks = CallbackHandler()
+        callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
+        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
+        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
+                                          save_best_only=True, mode='auto')
+        self.data_store.set("callbacks", callbacks, self.scope)
 
     def load_weights(self):
         try:
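Reviewer note: the new `_set_callbacks` only relies on a small `CallbackHandler` surface. The following is a minimal, hypothetical sketch of that interface, inferred purely from the call sites in this patch; the actual class lives in `src/model_modules/keras_extensions.py` and may differ in detail.

```python
# Hypothetical sketch of the CallbackHandler surface this patch relies on.
# All names are inferred from call sites in this diff; the real implementation
# in src/model_modules/keras_extensions.py may differ.
import pickle

from src.model_modules.keras_extensions import ModelCheckpointAdvanced


class CallbackHandlerSketch:

    def __init__(self):
        self._callbacks = []  # entries: {"callback": obj, "path": str, "name": str}
        self._checkpoint = None

    def add_callback(self, callback, callback_path, name):
        # register a callback together with its pickle path and a lookup name
        self._callbacks.append({"callback": callback, "path": callback_path, "name": name})

    def create_model_checkpoint(self, **kwargs):
        # build the advanced checkpoint on top of all previously added callbacks
        # (mirrors the ModelCheckpointAdvanced(..., callbacks=...) call removed above)
        self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)

    def get_checkpoint(self):
        return self._checkpoint

    def get_callbacks(self, as_dict=True):
        # as_dict=False returns the flat callback list that keras' fit_generator expects
        if as_dict:
            return self._callbacks
        return [c["callback"] for c in self._callbacks] + [self._checkpoint]

    def get_callback_by_name(self, name):
        return next(c["callback"] for c in self._callbacks if c["name"] == name)

    def load_callbacks(self):
        # restore each callback from its pickle file when training is resumed
        for c in self._callbacks:
            with open(c["path"], "rb") as f:
                c["callback"] = pickle.load(f)
```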
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 0b11da8d8f9a23c51d787f00d00a74d7517ea3b1..7a522af0298bcabee62579f68bd29ed123cac7b0 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -9,7 +9,7 @@ import pickle
 
 import keras
 
 from src.data_handling.data_distributor import Distributor
-from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced
+from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced, CallbackHandler
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from src.run_modules.run_environment import RunEnvironment
@@ -24,9 +24,7 @@ class Training(RunEnvironment):
         self.test_set = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
-        self.checkpoint: ModelCheckpointAdvanced = self.data_store.get("checkpoint", "general.model")
-        self.lr_sc = self.data_store.get("lr_decay", "general.model")
-        self.hist = self.data_store.get("hist", "general.model")
+        self.callbacks: CallbackHandler = self.data_store.get("callbacks", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
         self._trainable = self.data_store.get("trainable", "general")
         self._create_new_model = self.data_store.get("create_new_model", "general")
@@ -87,38 +85,35 @@
         locally stored information and the corresponding model and proceed with the already started training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
-        if not os.path.exists(self.checkpoint.filepath) or self._create_new_model:
+        checkpoint = self.callbacks.get_checkpoint()
+        if not os.path.exists(checkpoint.filepath) or self._create_new_model:
             history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                                steps_per_epoch=len(self.train_set),
                                                epochs=self.epochs,
                                                verbose=2,
                                                validation_data=self.val_set.distribute_on_batches(),
                                                validation_steps=len(self.val_set),
-                                               callbacks=[self.lr_sc, self.hist, self.checkpoint])
+                                               callbacks=self.callbacks.get_callbacks(as_dict=False))
         else:
             logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
-            lr_filepath = self.checkpoint.callbacks[0]["path"]
-            hist_filepath = self.checkpoint.callbacks[1]["path"]
-            self.lr_sc = pickle.load(open(lr_filepath, "rb"))
-            self.hist = pickle.load(open(hist_filepath, "rb"))
-            self.model = keras.models.load_model(self.checkpoint.filepath)
-            initial_epoch = max(self.hist.epoch) + 1
-            callbacks = [{"callback": self.lr_sc, "path": lr_filepath},
-                         {"callback": self.hist, "path": hist_filepath}]
-            self.checkpoint.update_callbacks(callbacks)
-            self.checkpoint.update_best(self.hist)
+            self.callbacks.load_callbacks()
+            self.callbacks.update_checkpoint()
+            self.model = keras.models.load_model(checkpoint.filepath)
+            hist = self.callbacks.get_callback_by_name("hist")
+            initial_epoch = max(hist.epoch) + 1
             _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                          steps_per_epoch=len(self.train_set),
                                          epochs=self.epochs,
                                          verbose=2,
                                          validation_data=self.val_set.distribute_on_batches(),
                                          validation_steps=len(self.val_set),
-                                         callbacks=[self.lr_sc, self.hist, self.checkpoint],
+                                         callbacks=self.callbacks.get_callbacks(as_dict=False),
                                          initial_epoch=initial_epoch)
-            history = self.hist
-        self.save_callbacks_as_json(history)
-        self.load_best_model(self.checkpoint.filepath)
-        self.create_monitoring_plots(history, self.lr_sc)
+            history = hist
+        lr = self.callbacks.get_callback_by_name("lr")
+        self.save_callbacks_as_json(history, lr)
+        self.load_best_model(checkpoint.filepath)
+        self.create_monitoring_plots(history, lr)
 
     def save_model(self) -> None:
         """
@@ -141,7 +136,7 @@
         except OSError:
             logging.info('no weights to reload...')
 
-    def save_callbacks_as_json(self, history: keras.callbacks.History) -> None:
+    def save_callbacks_as_json(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
         """
         Save callbacks (history, learning rate) of training.
         * history.history -> history.json
@@ -153,7 +148,7 @@
         with open(os.path.join(path, "history.json"), "w") as f:
             json.dump(history.history, f)
         with open(os.path.join(path, "history_lr.json"), "w") as f:
-            json.dump(self.lr_sc.lr, f)
+            json.dump(lr_sc.lr, f)
 
     def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
         """
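Reviewer note: in the resume branch, `load_callbacks()` unpickles the stored callbacks, and `initial_epoch = max(hist.epoch) + 1` continues from the epoch after the last completed one (Keras' `History` keeps the list of finished epoch indices in `.epoch`). `update_checkpoint()` presumably re-wires the checkpoint to the reloaded objects, replacing the removed `update_callbacks`/`update_best` calls; a hypothetical sketch of that behaviour:

```python
# Hypothetical behaviour of CallbackHandler.update_checkpoint(), reconstructed
# from the two calls this diff removes (update_callbacks and update_best).
def update_checkpoint_sketch(handler):
    checkpoint = handler.get_checkpoint()
    # point the checkpoint at the freshly unpickled callback objects ...
    checkpoint.update_callbacks(handler.get_callbacks(as_dict=True))
    # ... and restore its internal `best` value from the reloaded history so
    # that save_best_only does not overwrite a better pre-restart model
    checkpoint.update_best(handler.get_callback_by_name("hist"))
```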
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
index 35c5f8ee7581856a9feee3abd0face73ee83952c..ade35a244601d138d22af6305e67b5aeae964680 100644
--- a/test/test_modules/test_model_setup.py
+++ b/test/test_modules/test_model_setup.py
@@ -56,11 +56,11 @@ class TestModelSetup:
     def current_scope_as_set(model_cls):
         return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))
 
-    def test_set_checkpoint(self, setup):
-        assert "general.modeltest" not in setup.data_store.search_name("checkpoint")
+    def test_set_callbacks(self, setup):
+        assert "general.modeltest" not in setup.data_store.search_name("callbacks")
         setup.checkpoint_name = "TestName"
-        setup._set_checkpoint()
-        assert "general.modeltest" in setup.data_store.search_name("checkpoint")
+        setup._set_callbacks()
+        assert "general.modeltest" in setup.data_store.search_name("callbacks")
 
     def test_get_model_settings(self, setup_with_model):
         with pytest.raises(EmptyScope):
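Reviewer note: in the test changes below, the bare `checkpoint` fixture is replaced by a `callbacks` fixture that returns the tuple `(handler, hist, lr)`, so tests keep direct handles on the history and learning-rate callbacks for assertions. A hypothetical consumer would look like this (`test_example` is not part of this patch):

```python
# Hypothetical example of consuming the tuple-returning callbacks fixture
# defined below.
def test_example(callbacks):
    clbk, hist, lr = callbacks
    assert clbk.get_callback_by_name("hist") is hist
    assert clbk.get_callback_by_name("lr") is lr
```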
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index ac040c3a286c25dc84853c26c8509278642a1495..31c673f05d055eb7c4ee76318711de030d97d480 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -14,7 +14,7 @@
 from src.data_handling.data_generator import DataGenerator
 from src.helpers import PyTestRegex
 from src.model_modules.flatten import flatten_tail
 from src.model_modules.inception_model import InceptionModelBase
-from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced
+from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler
 from src.run_modules.run_environment import RunEnvironment
 from src.run_modules.training import Training
@@ -39,7 +39,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
 class TestTraining:
 
     @pytest.fixture
-    def init_without_run(self, path: str, model: keras.Model, checkpoint: ModelCheckpoint):
+    def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler):
         obj = object.__new__(Training)
         super(Training, obj).__init__()
         obj.model = model
@@ -48,9 +48,10 @@
         obj.test_set = None
         obj.batch_size = 256
         obj.epochs = 2
-        obj.checkpoint = checkpoint
-        obj.lr_sc = LearningRateDecay()
-        obj.hist = HistoryAdvanced()
+        clbk, hist, lr = callbacks
+        obj.callbacks = clbk
+        obj.lr_sc = lr
+        obj.hist = hist
         obj.experiment_name = "TestExperiment"
         obj.data_store.set("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train")
         obj.data_store.set("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val")
@@ -70,12 +71,9 @@
 
     @pytest.fixture
     def learning_rate(self):
-        return {"lr": [0.01, 0.0094]}
-
-    @pytest.fixture
-    def init_with_lr(self, init_without_run, learning_rate):
-        init_without_run.lr_sc.lr = learning_rate
-        return init_without_run
+        lr = LearningRateDecay()
+        lr.lr = {"lr": [0.01, 0.0094]}
+        return lr
 
     @pytest.fixture
     def history(self):
@@ -105,8 +103,15 @@
         return my_test_model(keras.layers.PReLU, 7, 2, 0.1, False)
 
     @pytest.fixture
-    def checkpoint(self, path):
-        return ModelCheckpoint(os.path.join(path, "model_checkpoint"), monitor='val_loss', save_best_only=True)
+    def callbacks(self, path):
+        clbk = CallbackHandler()
+        hist = HistoryAdvanced()
+        clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
+        lr = LearningRateDecay()
+        clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
+        clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
+                                     save_best_only=True)
+        return clbk, hist, lr
 
     @pytest.fixture
     def ready_to_train(self, generator: DataGenerator, init_without_run: Training):
@@ -125,7 +130,7 @@
         return obj
 
     @pytest.fixture
-    def ready_to_init(self, generator, model, checkpoint, path):
+    def ready_to_init(self, generator, model, callbacks, path):
         os.makedirs(path)
         obj = RunEnvironment()
         obj.data_store.set("generator", generator, "general.train")
@@ -136,14 +141,14 @@
         obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model")
         obj.data_store.set("batch_size", 256, "general.model")
         obj.data_store.set("epochs", 2, "general.model")
-        obj.data_store.set("checkpoint", checkpoint, "general.model")
-        obj.data_store.set("lr_decay", LearningRateDecay(), "general.model")
-        obj.data_store.set("hist", HistoryAdvanced(), "general.model")
+        clbk, hist, lr = callbacks
+        obj.data_store.set("callbacks", clbk, "general.model")
+        obj.data_store.set("lr_decay", lr, "general.model")
+        obj.data_store.set("hist", hist, "general.model")
         obj.data_store.set("experiment_name", "TestExperiment", "general")
         obj.data_store.set("experiment_path", path, "general")
         obj.data_store.set("trainable", True, "general")
-        obj.data_store.set("create_new_model"
-                           "", True, "general")
+        obj.data_store.set("create_new_model", True, "general")
         path_plot = os.path.join(path, "plots")
         os.makedirs(path_plot)
         obj.data_store.set("plot_path", path_plot, "general")
@@ -197,25 +202,25 @@
         assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
         assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
 
-    def test_save_callbacks_history_created(self, init_without_run, history, path):
-        init_without_run.save_callbacks_as_json(history)
+    def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         assert "history.json" in os.listdir(path)
 
-    def test_save_callbacks_lr_created(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks_as_json(history)
+    def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         assert "history_lr.json" in os.listdir(path)
 
-    def test_save_callbacks_inspect_history(self, init_without_run, history, path):
-        init_without_run.save_callbacks_as_json(history)
+    def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         with open(os.path.join(path, "history.json")) as jfile:
             hist = json.load(jfile)
             assert hist == history.history
 
-    def test_save_callbacks_inspect_lr(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks_as_json(history)
+    def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         with open(os.path.join(path, "history_lr.json")) as jfile:
             lr = json.load(jfile)
-            assert lr == init_with_lr.lr_sc.lr
+            assert lr == learning_rate.lr
 
     def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0