diff --git a/src/modules/model_setup.py b/src/modules/model_setup.py index 6cd0d6879333fee775a48b1a70a09c91796dbb24..925d83d733b8cb5715b515856e09e773737778a9 100644 --- a/src/modules/model_setup.py +++ b/src/modules/model_setup.py @@ -9,6 +9,7 @@ from keras.regularizers import l2 from keras.optimizers import Adam, SGD import tensorflow as tf import logging +import os from src.modules.run_environment import RunEnvironment from src.helpers import l_p_loss, LearningRateDecay @@ -25,6 +26,7 @@ class ModelSetup(RunEnvironment): self.model = None self.model_name = self.data_store.get("experiment_name", "general") + "model-best.h5" self.scope = "general.model" + self._run() def _run(self): @@ -51,61 +53,63 @@ class ModelSetup(RunEnvironment): optimizer = self.data_store.get("optimizer", self.scope) loss = self.data_store.get("loss", self.scope) self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"]) + self.data_store.put("model", self.model, self.scope) def _set_checkpoint(self): - ModelCheckpoint(self.model_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto') + checkpoint = ModelCheckpoint(self.model_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto') + self.data_store.put("checkpoint", checkpoint, self.scope) def load_weights(self): - #try: - logging.debug('reload weights...') - self.model.load_weights(self.model_name) - #except: - # print('no weights to reload...') + try: + logging.debug('reload weights...') + self.model.load_weights(self.model_name) + except OSError: + logging.debug('no weights to reload...') def build_model(self): - args_list = ["activation", "window_size", "channels", "regularizer", "dropout_rate", "window_lead_time"] + args_list = ["activation", "window_history_size", "channels", "regularizer", "dropout_rate", "window_lead_time"] args = self.data_store.create_args_dict(args_list, self.scope) self.model = my_model(**args) - def plot_model(self): + def plot_model(self): # pragma: no cover with tf.device("/cpu:0"): - file_name = self.data_store.get("experiment_name", "general") + "model.pdf" + path = self.data_store.get("experiment_path", "general") + name = self.data_store.get("experiment_name", "general") + "model.pdf" + file_name = os.path.join(path, name) keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) def my_model_settings(self): - scope = "general.model" - # channels - X, y = self.data_store.get("generator", "general.train")[0] + X, _ = self.data_store.get("generator", "general.train")[0] channels = X.shape[-1] # input variables - self.data_store.put("channels", channels, scope) + self.data_store.put("channels", channels, self.scope) # dropout - self.data_store.put("dropout_rate", 0.1, scope) + self.data_store.put("dropout_rate", 0.1, self.scope) # regularizer - self.data_store.put("regularizer", l2(0.1), scope) + self.data_store.put("regularizer", l2(0.1), self.scope) # learning rate initial_lr = 1e-2 - self.data_store.put("initial_lr", initial_lr, scope) + self.data_store.put("initial_lr", initial_lr, self.scope) optimizer = SGD(lr=initial_lr, momentum=0.9) # optimizer=Adam(lr=initial_lr, amsgrad=True) - self.data_store.put("optimizer", optimizer, scope) - self.data_store.put("lr_decay", LearningRateDecay(base_lr=initial_lr, drop=.94, epochs_drop=10), scope) + self.data_store.put("optimizer", optimizer, self.scope) + self.data_store.put("lr_decay", LearningRateDecay(base_lr=initial_lr, drop=.94, epochs_drop=10), self.scope) # learning settings - self.data_store.put("epochs", 2, scope) - self.data_store.put("batch_size", int(256), scope) + self.data_store.put("epochs", 2, self.scope) + self.data_store.put("batch_size", int(256), self.scope) # activation activation = layers.PReLU # ELU #LeakyReLU keras.activations.tanh # - self.data_store.put("activation", activation, scope) + self.data_store.put("activation", activation, self.scope) # set los loss_all = my_loss() - self.data_store.put("loss", loss_all, scope) + self.data_store.put("loss", loss_all, self.scope) def my_loss(): @@ -115,7 +119,7 @@ def my_loss(): return loss_all -def my_model(activation, window_size, channels, regularizer, dropout_rate, window_lead_time): +def my_model(activation, window_history_size, channels, regularizer, dropout_rate, window_lead_time): conv_settings_dict1 = { 'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation}, @@ -147,7 +151,7 @@ def my_model(activation, window_size, channels, regularizer, dropout_rate, windo ########################################## inception_model = InceptionModelBase() - X_input = layers.Input(shape=(window_size + 1, 1, channels)) # add 1 to window_size to include current time step t0 + X_input = layers.Input(shape=(window_history_size + 1, 1, channels)) # add 1 to window_size to include current time step t0 X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1, regularizer=regularizer, batch_normalisation=True) diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py new file mode 100644 index 0000000000000000000000000000000000000000..d595463ff0f42818d634018fa8ddb03e5b54a2db --- /dev/null +++ b/test/test_modules/test_model_setup.py @@ -0,0 +1,56 @@ +import logging +import pytest +import os +import keras + +from src.modules.model_setup import ModelSetup +from src.modules.run_environment import RunEnvironment +from src.data_generator import DataGenerator + + +class TestModelSetup: + + @pytest.fixture + def setup(self): + obj = object.__new__(ModelSetup) + super(ModelSetup, obj).__init__() + obj.scope = "general.modeltest" + obj.model = None + yield obj + RunEnvironment().__del__() + + @pytest.fixture + def gen(self): + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], + 'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + + @pytest.fixture + def setup_with_gen(self, setup, gen): + setup.data_store.put("generator", gen, "general.train") + return setup + + def test_set_checkpoint(self, setup): + assert "general.modeltest" not in setup.data_store.search_name("checkpoint") + setup.model_name = "TestName" + setup._set_checkpoint() + assert "general.modeltest" in setup.data_store.search_name("checkpoint") + + def test_my_model_settings(self, setup_with_gen): + setup_with_gen.my_model_settings() + expected = {"channels", "dropout_rate", "regularizer", "initial_lr", "optimizer", "lr_decay", "epochs", + "batch_size", "activation", "loss"} + assert expected <= set(setup_with_gen.data_store.search_scope(setup_with_gen.scope, current_scope_only=True)) + + def test_build_model(self, setup_with_gen): + setup_with_gen.my_model_settings() + setup_with_gen.data_store.put("window_lead_time", 2, "general") + assert setup_with_gen.model is None + setup_with_gen.build_model() + assert isinstance(setup_with_gen.model, keras.Model) + + def test_load_weights(self): + pass + + def test_compile_model(self): + pass +