diff --git a/run.py b/run.py
index 71244fb9d15f594ac3ffbce60341d5c8dcb15f03..c06bf6480aa17918a1e8131defd3b3a5233ffb44 100644
--- a/run.py
+++ b/run.py
@@ -24,14 +24,14 @@ def main(parser_args):
 
         Training()
 
-        PostProcessing()
+        # PostProcessing()
 
 
 if __name__ == "__main__":
 
     formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
-    # logging.basicConfig(format=formatter, level=logging.INFO)
-    logging.basicConfig(format=formatter, level=logging.DEBUG)
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    # logging.basicConfig(format=formatter, level=logging.DEBUG)
 
     parser = argparse.ArgumentParser()
     parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
diff --git a/src/helpers.py b/src/helpers.py
index 172a8dd3cf04a15e9069347dac7f06c6d2d8ed60..f119f140a159d4e050c77fde563651484a57079d 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -57,6 +57,8 @@ class LearningRateDecay(keras.callbacks.History):
         self.base_lr = self.check_param(base_lr, 'base_lr')
         self.drop = self.check_param(drop, 'drop')
         self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
+        self.epoch = []
+        self.history = {}
 
     @staticmethod
     def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
@@ -80,6 +82,9 @@
             raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
                              f"{name}={value}")
 
+    def on_train_begin(self, logs=None):
+        pass
+
     def on_epoch_begin(self, epoch: int, logs=None):
         """
         Lower learning rate every epochs_drop epochs by factor drop.
diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 6b0fe236ff8ee726c34a721a6be0ed8be91f2bb8..5e9931d70d33a16cd4478427db7402a2b38dc9c0 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -110,7 +110,7 @@ class MyLittleModel(AbstractModelClass):
         self.initial_lr = 1e-2
         self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
         self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
-        self.epochs = 2
+        self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
 
@@ -190,7 +190,7 @@ class MyBranchedModel(AbstractModelClass):
         self.initial_lr = 1e-2
         self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
         self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
-        self.epochs = 2
+        self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
 
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index a7722018c52275b390a10199cb30b7b936ed37a3..a4d89f65d679f10b3bd05191b708761925e2cd4d 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -4,19 +4,20 @@ __date__ = '2019-12-02'
 
 import keras
 from keras import losses
-from keras.callbacks import ModelCheckpoint
+from keras.callbacks import ModelCheckpoint, History
 from keras.regularizers import l2
 from keras.optimizers import SGD
 import tensorflow as tf
 import logging
 import os
+import pickle
 
 from src.run_modules.run_environment import RunEnvironment
 from src.helpers import l_p_loss, LearningRateDecay
 from src.model_modules.inception_model import InceptionModelBase
 from src.model_modules.flatten import flatten_tail
-from src.model_modules.model_class import MyBranchedModel as MyModel
-# from src.model_modules.model_class import MyLittleModel as MyModel
+# from src.model_modules.model_class import MyBranchedModel as MyModel
+from src.model_modules.model_class import MyLittleModel as MyModel
 
 
 class ModelSetup(RunEnvironment):
@@ -30,13 +31,11 @@ class ModelSetup(RunEnvironment):
         exp_name = self.data_store.get("experiment_name", "general")
         self.scope = "general.model"
         self.checkpoint_name = os.path.join(path, f"{exp_name}_model-best.h5")
+        self.callbacks_name = os.path.join(path, f"{exp_name}_model-best-callbacks-%s.pickle")
         self._run()
 
     def _run(self):
 
-        # create checkpoint
-        self._set_checkpoint()
-
         # set channels depending on inputs
         self._set_channels()
 
@@ -50,6 +49,9 @@
         if self.data_store.get("trainable", self.scope) is False:
             self.load_weights()
 
+        # create checkpoint
+        self._set_checkpoint()
+
         # compile model
         self.compile_model()
 
@@ -64,7 +66,17 @@
         self.data_store.set("model", self.model, self.scope)
 
     def _set_checkpoint(self):
-        checkpoint = ModelCheckpoint(self.checkpoint_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
+        lr = self.data_store.get("lr_decay", scope="general.model")
+        # checkpoint = ModelCheckpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
+        # checkpoint = ModelCheckpointAdvanced(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
+        #                                      save_best_only=True, mode='auto', callbacks_to_save=lr,
+        #                                      callbacks_filepath=self.callbacks_name)
+        hist = HistoryAdvanced()
+        self.data_store.set("hist", hist, scope="general.model")
+        callbacks = [{"callback": lr, "path": self.callbacks_name % "lr"},
+                     {"callback": hist, "path": self.callbacks_name % "hist"}]
+        checkpoint = ModelCheckpointAdvanced2(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
+                                              save_best_only=True, mode='auto', callbacks=callbacks)
         self.data_store.set("checkpoint", checkpoint, self.scope)
 
     def load_weights(self):
@@ -92,6 +104,61 @@ class ModelSetup(RunEnvironment):
         keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
 
 
+class HistoryAdvanced(History):
+
+    def __init__(self, old_epoch=None, old_history=None):
+        self.epoch = old_epoch or []
+        self.history = old_history or {}
+        super().__init__()
+
+    def on_train_begin(self, logs=None):
+        pass
+
+
+class ModelCheckpointAdvanced(ModelCheckpoint):
+
+    def __init__(self, *args, **kwargs):
+        self.callbacks_to_save = kwargs.pop("callbacks_to_save")
+        self.callbacks_filepath = kwargs.pop("callbacks_filepath")
+        super().__init__(*args, **kwargs)
+
+    def on_epoch_end(self, epoch, logs=None):
+        super().on_epoch_end(epoch, logs)
+
+        file_path = self.callbacks_filepath
+        if self.epochs_since_last_save == 0 and epoch != 0:
+            if self.save_best_only:
+                current = logs.get(self.monitor)
+                if current == self.best:
+                    with open(file_path, "wb") as f:
+                        pickle.dump(self.callbacks_to_save, f)
+            else:
+                with open(file_path, "wb") as f:
+                    pickle.dump(self.callbacks_to_save, f)
+
+
+class ModelCheckpointAdvanced2(ModelCheckpoint):
+
+    def __init__(self, *args, **kwargs):
+        self.callbacks = kwargs.pop("callbacks")
+        super().__init__(*args, **kwargs)
+
+    def on_epoch_end(self, epoch, logs=None):
+        super().on_epoch_end(epoch, logs)
+
+        for callback in self.callbacks:
+            file_path = callback["path"]
+            if self.epochs_since_last_save == 0 and epoch != 0:
+                if self.save_best_only:
+                    current = logs.get(self.monitor)
+                    if current == self.best:
+                        with open(file_path, "wb") as f:
+                            pickle.dump(callback["callback"], f)
+                else:
+                    with open(file_path, "wb") as f:
+                        pickle.dump(callback["callback"], f)
+
+
 def my_loss():
     loss = l_p_loss(4)
     keras_loss = losses.mean_squared_error
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 96936ce124e05251af483e758401d833a44531f4..e9c7487b4be94c8dfad6c64a88bd9b0815656064 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -5,11 +5,13 @@ import logging
 import os
 import json
 import keras
+import pickle
 
 from src.run_modules.run_environment import RunEnvironment
 from src.data_handling.data_distributor import Distributor
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from src.helpers import LearningRateDecay
+from src.run_modules.model_setup import ModelCheckpointAdvanced2
 
 
 class Training(RunEnvironment):
@@ -22,8 +24,10 @@ class Training(RunEnvironment):
         self.test_set = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
-        self.checkpoint = self.data_store.get("checkpoint", "general.model")
+        self.checkpoint: ModelCheckpointAdvanced2 = self.data_store.get("checkpoint", "general.model")
+        # self.callbacks = self.data_store.get("callbacks", "general.model")
         self.lr_sc = self.data_store.get("lr_decay", "general.model")
+        self.hist = self.data_store.get("hist", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
         self._run()
 
@@ -76,13 +80,36 @@ class Training(RunEnvironment):
         model from training is saved for class variable model.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
-        history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
-                                           steps_per_epoch=len(self.train_set),
-                                           epochs=self.epochs,
-                                           verbose=2,
-                                           validation_data=self.val_set.distribute_on_batches(),
-                                           validation_steps=len(self.val_set),
-                                           callbacks=[self.checkpoint, self.lr_sc])
+        if not os.path.exists(self.checkpoint.filepath):
+            history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+                                               steps_per_epoch=len(self.train_set),
+                                               epochs=self.epochs,
+                                               verbose=2,
+                                               validation_data=self.val_set.distribute_on_batches(),
+                                               validation_steps=len(self.val_set),
+                                               # callbacks=self.callbacks)
+                                               callbacks=[self.checkpoint, self.lr_sc, self.hist])
+        else:
+            lr_filepath = self.checkpoint.callbacks[0]["path"]  # TODO: stopped here. why does training start 1 epoch too early or doesn't it?
+            hist_filepath = self.checkpoint.callbacks[1]["path"]
+            lr_callbacks = pickle.load(open(lr_filepath, "rb"))
+            hist_callbacks = pickle.load(open(hist_filepath, "rb"))
+            self.lr_sc = lr_callbacks
+            self.hist = hist_callbacks
+            self.model = keras.models.load_model(self.checkpoint.filepath)
+            initial_epoch = max(hist_callbacks.epoch) + 1
+            callbacks = [{"callback": self.lr_sc, "path": lr_filepath},
+                         {"callback": self.hist, "path": hist_filepath}]
+            self.checkpoint.callbacks = callbacks
+            history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+                                               steps_per_epoch=len(self.train_set),
+                                               epochs=self.epochs,
+                                               verbose=2,
+                                               validation_data=self.val_set.distribute_on_batches(),
+                                               validation_steps=len(self.val_set),
+                                               callbacks=[self.checkpoint, self.lr_sc, self.hist],
+                                               initial_epoch=initial_epoch)
+            history = self.hist
         self.save_callbacks(history)
         self.load_best_model(self.checkpoint.filepath)
         self.create_monitoring_plots(history, self.lr_sc)
@@ -123,6 +150,7 @@ class Training(RunEnvironment):
             json.dump(history.history, f)
         with open(os.path.join(path, "history_lr.json"), "w") as f:
             json.dump(self.lr_sc.lr, f)
+            # json.dump(self.callbacks["learning_rate"].lr, f)
 
     def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
         """
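
Note (illustrative sketch, not part of the patch): the resume branch added to Training.train() derives the restart epoch from the pickled history callback. Keras stores zero-based epoch indices in History.epoch and fit_generator's initial_epoch names the first epoch to run, so max(hist.epoch) + 1 continues directly after the last completed epoch. The helper name load_resume_state below is invented for this sketch only; it assumes keras 2.x and a pickled history callback like HistoryAdvanced.

import os
import pickle
import keras

def load_resume_state(checkpoint_path, hist_pickle_path):
    """Return (model, initial_epoch); (None, 0) means start a fresh training run."""
    if not os.path.exists(checkpoint_path):
        return None, 0
    # models with custom layers/losses may additionally need custom_objects here
    model = keras.models.load_model(checkpoint_path)
    with open(hist_pickle_path, "rb") as f:
        hist = pickle.load(f)  # e.g. a pickled HistoryAdvanced instance
    # hist.epoch holds zero-based indices of completed epochs
    return model, max(hist.epoch) + 1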