diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py index 8898a6b2d0591328f2bb7010ccbfe144a48ca40b..4dc9521abf3569eb57249286e92c1e6a259c667d 100644 --- a/mlair/model_modules/abstract_model_class.py +++ b/mlair/model_modules/abstract_model_class.py @@ -38,6 +38,11 @@ class AbstractModelClass(ABC): self._input_shape = input_shape self._output_shape = self.__extract_from_tuple(output_shape) + def load_model(self, name: str): + hist = self.model.history + self.model = keras.models.load_model(name) + self.model.history = hist + def __getattr__(self, name: str) -> Any: """ Is called if __getattribute__ is not able to find requested attribute. diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index a64ce917393dda9605e81cbef1dfd60fc57beaa5..98263eb732d8067fba0950c7a4882fb3ef020995 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -134,7 +134,7 @@ class ModelSetup(RunEnvironment): def load_model(self): """Try to load model from disk or skip if not possible.""" try: - self.model = keras.models.load_model(self.model_name) + self.model.load_model(self.model_name) logging.info(f"reload model {self.model_name} from disk ...") except OSError: logging.info('no local model to load...') diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 64b323fa74b3770698d40364ee7defae88a01b4c..0d875766926e870349337a0597e2b3612a93ee07 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -149,7 +149,7 @@ class Training(RunEnvironment): logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.") self.callbacks.load_callbacks() self.callbacks.update_checkpoint() - self.model = keras.models.load_model(checkpoint.filepath) + self.model.load_model(checkpoint.filepath) hist: History = self.callbacks.get_callback_by_name("hist") initial_epoch = max(hist.epoch) + 1 _ = self.model.fit(self.train_set, @@ -179,6 +179,7 @@ class Training(RunEnvironment): model_name = self.data_store.get("model_name", "model") logging.debug(f"save best model to {model_name}") self.model.save(model_name, save_format='h5') + self.model.save(model_name) self.data_store.set("best_model", self.model) def load_best_model(self, name: str) -> None: @@ -189,7 +190,7 @@ class Training(RunEnvironment): """ logging.debug(f"load best model: {name}") try: - self.model = keras.models.load_model(name) + self.model.load_model(name) logging.info('reload weights...') except OSError: logging.info('no weights to reload...') diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py index 46a5ba92b78cd4711e8f48c8df8e3721cf5d52cc..44e664e4f47dfd842ed956fcf7f7e56becb758ef 100644 --- a/test/test_run_modules/test_training.py +++ b/test/test_run_modules/test_training.py @@ -13,6 +13,7 @@ from tensorflow.keras.callbacks import History from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler from mlair.helpers import PyTestRegex +from mlair.model_modules.fully_connected_networks import FCN from mlair.model_modules.flatten import flatten_tail from mlair.model_modules.inception_model import InceptionModelBase from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback @@ -160,7 +161,10 @@ class TestTraining: @pytest.fixture def model(self, window_history_size, window_lead_time, statistics_per_var): channels = len(list(statistics_per_var.keys())) - return my_test_model(keras.layers.PReLU, window_history_size, channels, window_lead_time, 0.1, False) + + return FCN([(window_history_size + 1, 1, channels)], [window_lead_time]) + + # return my_test_model(keras.layers.PReLU, window_history_size, channels, window_lead_time, 0.1, False) @pytest.fixture def callbacks(self, path):