diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py index 44e664e4f47dfd842ed956fcf7f7e56becb758ef..51ea1cd344c1ff1899af818c6b38a2cbb93b733a 100644 --- a/test/test_run_modules/test_training.py +++ b/test/test_run_modules/test_training.py @@ -1,6 +1,8 @@ import copy import glob import json +import time + import logging import os import shutil @@ -161,11 +163,8 @@ class TestTraining: @pytest.fixture def model(self, window_history_size, window_lead_time, statistics_per_var): channels = len(list(statistics_per_var.keys())) - return FCN([(window_history_size + 1, 1, channels)], [window_lead_time]) - # return my_test_model(keras.layers.PReLU, window_history_size, channels, window_lead_time, 0.1, False) - @pytest.fixture def callbacks(self, path): clbk = CallbackHandler() @@ -194,7 +193,7 @@ class TestTraining: obj.data_store.set("data_collection", data_collection, "general.train") obj.data_store.set("data_collection", data_collection, "general.val") obj.data_store.set("data_collection", data_collection, "general.test") - obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) + obj.model.compile(**obj.model.compile_options) return obj @pytest.fixture @@ -229,6 +228,57 @@ class TestTraining: if os.path.exists(path): shutil.rmtree(path) + @staticmethod + def create_training_obj(epochs, path, data_collection, batch_path, model_path, + statistics_per_var, window_history_size, window_lead_time) -> Training: + + channels = len(list(statistics_per_var.keys())) + model = FCN([(window_history_size + 1, 1, channels)], [window_lead_time]) + + obj = object.__new__(Training) + super(Training, obj).__init__() + obj.model = model + obj.train_set = None + obj.val_set = None + obj.test_set = None + obj.batch_size = 256 + obj.epochs = epochs + + clbk = CallbackHandler() + hist = HistoryAdvanced() + epo_timing = EpoTimingCallback() + clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist") + lr = LearningRateDecay() + clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr") + clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing") + clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss', + save_best_only=True) + obj.callbacks = clbk + obj.lr_sc = lr + obj.hist = hist + obj.experiment_name = "TestExperiment" + obj.data_store.set("data_collection", data_collection, "general.train") + obj.data_store.set("data_collection", data_collection, "general.val") + obj.data_store.set("data_collection", data_collection, "general.test") + if not os.path.exists(path): + os.makedirs(path) + obj.data_store.set("experiment_path", path, "general") + os.makedirs(batch_path, exist_ok=True) + obj.data_store.set("batch_path", batch_path, "general") + os.makedirs(model_path, exist_ok=True) + obj.data_store.set("model_path", model_path, "general") + obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model") + obj.data_store.set("experiment_name", "TestExperiment", "general") + + path_plot = os.path.join(path, "plots") + os.makedirs(path_plot, exist_ok=True) + obj.data_store.set("plot_path", path_plot, "general") + obj._train_model = True + obj._create_new_model = False + + obj.model.compile(**obj.model.compile_options) + return obj + def test_init(self, ready_to_init): assert isinstance(Training(), Training) # just test, if nothing fails @@ -312,58 +362,13 @@ class TestTraining: init_without_run.create_monitoring_plots(history, learning_rate) assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2 - def test_resume_training(self, ready_to_run, path: str, model: keras.Model, model_path, - batch_path, data_collection): - with ready_to_run as run_obj: - assert run_obj._run() is None # rune once to create model - - # init new object - obj = object.__new__(Training) - super(Training, obj).__init__() - obj.model = model - obj.train_set = None - obj.val_set = None - obj.test_set = None - obj.batch_size = 256 - obj.epochs = 4 - - clbk = CallbackHandler() - hist = HistoryAdvanced() - epo_timing = EpoTimingCallback() - clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist") - lr = LearningRateDecay() - clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr") - clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing") - clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss', - save_best_only=True) - obj.callbacks = clbk - obj.lr_sc = lr - obj.hist = hist - obj.experiment_name = "TestExperiment" - obj.data_store.set("data_collection", data_collection, "general.train") - obj.data_store.set("data_collection", data_collection, "general.val") - obj.data_store.set("data_collection", data_collection, "general.test") - obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) - if not os.path.exists(path): - os.makedirs(path) - obj.data_store.set("experiment_path", path, "general") - os.makedirs(batch_path, exist_ok=True) - obj.data_store.set("batch_path", batch_path, "general") - os.makedirs(model_path, exist_ok=True) - obj.data_store.set("model_path", model_path, "general") - obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model") - obj.data_store.set("experiment_name", "TestExperiment", "general") - - path_plot = os.path.join(path, "plots") - os.makedirs(path_plot, exist_ok=True) - obj.data_store.set("plot_path", path_plot, "general") - obj._train_model = True - obj._create_new_model = False - - - assert obj._run() is None - assert 1 == 1 - assert 1 == 1 - - + def test_resume_training1(self, path: str, model_path, batch_path, data_collection, statistics_per_var, + window_history_size, window_lead_time): + obj_1st = self.create_training_obj(2, path, data_collection, batch_path, model_path, statistics_per_var, + window_history_size, window_lead_time) + keras.utils.get_custom_objects().update(obj_1st.model.custom_objects) + assert obj_1st._run() is None + obj_2nd = self.create_training_obj(4, path, data_collection, batch_path, model_path, statistics_per_var, + window_history_size, window_lead_time) + assert obj_2nd._run() is None