Skip to content
Snippets Groups Projects
Commit 5cb2ba3a authored by leufen1's avatar leufen1
Browse files

update test for train resuming

parent 1c57e71b
No related branches found
No related tags found
5 merge requests!413update release branch,!412Resolve "release v2.0.0",!361name of pdf starts now with feature_importance, there is now also another...,!350Resolve "upgrade code to TensorFlow V2",!335Resolve "upgrade code to TensorFlow V2"
Pipeline #83524 failed
import copy import copy
import glob import glob
import json import json
import time
import logging import logging
import os import os
import shutil import shutil
...@@ -161,11 +163,8 @@ class TestTraining: ...@@ -161,11 +163,8 @@ class TestTraining:
@pytest.fixture
def model(self, window_history_size, window_lead_time, statistics_per_var):
    """Return an FCN model whose input shape matches the test data.

    The number of input channels equals the number of variables in
    ``statistics_per_var``; the input window is the history size plus the
    current time step.
    """
    channels = len(statistics_per_var)  # one input channel per variable
    return FCN([(window_history_size + 1, 1, channels)], [window_lead_time])
@pytest.fixture @pytest.fixture
def callbacks(self, path): def callbacks(self, path):
clbk = CallbackHandler() clbk = CallbackHandler()
...@@ -194,7 +193,7 @@ class TestTraining: ...@@ -194,7 +193,7 @@ class TestTraining:
obj.data_store.set("data_collection", data_collection, "general.train") obj.data_store.set("data_collection", data_collection, "general.train")
obj.data_store.set("data_collection", data_collection, "general.val") obj.data_store.set("data_collection", data_collection, "general.val")
obj.data_store.set("data_collection", data_collection, "general.test") obj.data_store.set("data_collection", data_collection, "general.test")
obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) obj.model.compile(**obj.model.compile_options)
return obj return obj
@pytest.fixture @pytest.fixture
...@@ -229,6 +228,57 @@ class TestTraining: ...@@ -229,6 +228,57 @@ class TestTraining:
if os.path.exists(path): if os.path.exists(path):
shutil.rmtree(path) shutil.rmtree(path)
@staticmethod
def create_training_obj(epochs, path, data_collection, batch_path, model_path,
                        statistics_per_var, window_history_size, window_lead_time) -> Training:
    """Build a fully wired ``Training`` instance for the resume tests.

    Bypasses ``Training.__init__`` (which would pull its configuration from
    a populated data store) and instead assembles the object field by field:
    model, callback handler, data-store entries and the directory layout on
    disk that ``_run`` expects.

    :param epochs: number of epochs the returned object will train for
    :param path: experiment path (created if missing); also holds the
        callback checkpoint files and the "plots" sub-directory
    :param data_collection: collection used for train, val and test alike
    :param batch_path: directory for generated batches (created if missing)
    :param model_path: directory the model is stored in (created if missing)
    :param statistics_per_var: mapping of statistics per variable; its
        length defines the number of input channels
    :param window_history_size: length of the input history window
    :param window_lead_time: length of the forecast window
    :return: a compiled, ready-to-run ``Training`` object
    """
    channels = len(statistics_per_var)  # one input channel per variable
    model = FCN([(window_history_size + 1, 1, channels)], [window_lead_time])

    # Create the instance without running Training.__init__, but still
    # initialise the base class so the shared data store is available.
    obj = object.__new__(Training)
    super(Training, obj).__init__()
    obj.model = model
    obj.train_set = None
    obj.val_set = None
    obj.test_set = None
    obj.batch_size = 256
    obj.epochs = epochs

    # Callback handler with history, learning-rate decay and epoch timing;
    # the pickle checkpoints are what allow a second run to resume the first.
    clbk = CallbackHandler()
    hist = HistoryAdvanced()
    epo_timing = EpoTimingCallback()
    clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
    lr = LearningRateDecay()
    clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
    clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
    clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
                                 save_best_only=True)
    obj.callbacks = clbk
    obj.lr_sc = lr
    obj.hist = hist
    obj.experiment_name = "TestExperiment"

    # The same collection serves all three subsets; sufficient for the test.
    obj.data_store.set("data_collection", data_collection, "general.train")
    obj.data_store.set("data_collection", data_collection, "general.val")
    obj.data_store.set("data_collection", data_collection, "general.test")

    # Required directory layout and data-store path entries.
    os.makedirs(path, exist_ok=True)
    obj.data_store.set("experiment_path", path, "general")
    os.makedirs(batch_path, exist_ok=True)
    obj.data_store.set("batch_path", batch_path, "general")
    os.makedirs(model_path, exist_ok=True)
    obj.data_store.set("model_path", model_path, "general")
    obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
    obj.data_store.set("experiment_name", "TestExperiment", "general")
    path_plot = os.path.join(path, "plots")
    os.makedirs(path_plot, exist_ok=True)
    obj.data_store.set("plot_path", path_plot, "general")

    # Train an existing model instead of creating a new one (resume case).
    obj._train_model = True
    obj._create_new_model = False
    obj.model.compile(**obj.model.compile_options)
    return obj
def test_init(self, ready_to_init):
    """Smoke test: constructing a Training object must not raise."""
    training_instance = Training()
    assert isinstance(training_instance, Training)
...@@ -312,58 +362,13 @@ class TestTraining: ...@@ -312,58 +362,13 @@ class TestTraining:
init_without_run.create_monitoring_plots(history, learning_rate) init_without_run.create_monitoring_plots(history, learning_rate)
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2 assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
def test_resume_training1(self, path: str, model_path, batch_path, data_collection, statistics_per_var,
                          window_history_size, window_lead_time):
    """Train once, then resume with a fresh object and a larger epoch count."""
    # First run: train for 2 epochs, writing model and callback checkpoints.
    first_run = self.create_training_obj(2, path, data_collection, batch_path, model_path, statistics_per_var,
                                         window_history_size, window_lead_time)
    keras.utils.get_custom_objects().update(first_run.model.custom_objects)
    assert first_run._run() is None
    # Second run: identical setup but 4 epochs, resuming from the checkpoints.
    second_run = self.create_training_obj(4, path, data_collection, batch_path, model_path, statistics_per_var,
                                          window_history_size, window_lead_time)
    assert second_run._run() is None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment