# test_training.py
import copy
import glob
import json
import time
import logging
import os
import shutil
from typing import Callable
import tensorflow.keras as keras
import mock
import pytest
from tensorflow.keras.callbacks import History
from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
from mlair.helpers import PyTestRegex
from mlair.model_modules.fully_connected_networks import FCN
from mlair.model_modules.flatten import flatten_tail
from mlair.model_modules.inception_model import InceptionModelBase
from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback
from mlair.run_modules.run_environment import RunEnvironment
from mlair.run_modules.training import Training
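# Helper building a small Keras model for tests: one inception block followed by a
# flattened output tail, with an optional minor branch to exercise multi-output models.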
def my_test_model(activation, window_history_size, channels, output_size, dropout_rate, add_minor_branch=False):
inception_model = InceptionModelBase()
conv_settings_dict1 = {
'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation}, }
pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
if add_minor_branch:
out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
output_activation='linear', reduction_filter=64,
name='Minor_1', dropout_rate=dropout_rate,
)]
else:
out = []
X_in = keras.layers.Dropout(dropout_rate)(X_in)
out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=output_size,
output_activation='linear', reduction_filter=64,
name='Main', dropout_rate=dropout_rate,
))
return keras.Model(inputs=X_input, outputs=out)
class TestTraining:
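# Builds a Training instance via object.__new__ (skipping Training.__init__), backed by
# mocked data collections and temporary experiment/batch/model/plot directories that are
# removed again once the test finishes.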
@pytest.fixture
def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler, model_path, batch_path):
obj = object.__new__(Training)
super(Training, obj).__init__()
obj.model = model
obj.train_set = None
obj.val_set = None
obj.test_set = None
obj.batch_size = 256
obj.epochs = 2
clbk, hist, lr = callbacks
obj.callbacks = clbk
obj.lr_sc = lr
obj.hist = hist
obj.experiment_name = "TestExperiment"
obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_train_gen"), "general.train")
obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_val_gen"), "general.val")
obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_test_gen"), "general.test")
if not os.path.exists(path):
os.makedirs(path)
obj.data_store.set("experiment_path", path, "general")
os.makedirs(batch_path)
obj.data_store.set("batch_path", batch_path, "general")
os.makedirs(model_path)
obj.data_store.set("model_path", model_path, "general")
obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
obj.data_store.set("experiment_name", "TestExperiment", "general")
path_plot = os.path.join(path, "plots")
os.makedirs(path_plot)
obj.data_store.set("plot_path", path_plot, "general")
obj._train_model = True
obj._create_new_model = False
try:
yield obj
finally:
if os.path.exists(path):
shutil.rmtree(path)
try:
RunEnvironment().__del__()
except AssertionError:
pass
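# Lightweight callback states (learning rate decay, Keras History, epoch timing) used by
# the save_callbacks_* and monitoring plot tests.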
@pytest.fixture
def learning_rate(self):
lr = LearningRateDecay()
lr.lr = {"lr": [0.01, 0.0094]}
return lr
@pytest.fixture
def history(self):
h = History()
h.epoch = [0, 1]
h.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
'loss': [0.6795708956961347, 0.45963566494176616],
'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
h.model = mock.MagicMock()
return h
@pytest.fixture
def epo_timing(self):
epo_timing = EpoTimingCallback()
epo_timing.epoch = [0, 1]
epo_timing.epo_timing = {"epo_timing": [0.1, 0.2]}
# return the callback so the tests receive the configured object instead of None
return epo_timing
@pytest.fixture
def path(self):
return os.path.join(os.path.dirname(__file__), "TestExperiment")
@pytest.fixture
def model_path(self, path):
return os.path.join(path, "model")
@pytest.fixture
def batch_path(self, path):
return os.path.join(path, "batch")
@pytest.fixture
def window_history_size(self):
return 7
@pytest.fixture
def window_lead_time(self):
return 2
@pytest.fixture
def statistics_per_var(self):
return {'o3': 'dma8eu', 'temp': 'maximum'}
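# Real DataCollection with a single DefaultDataHandler for station DEBW107, built with the
# window and statistics settings from the fixtures above.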
@pytest.fixture
def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var):
data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(path, 'data'),
experiment_path=os.path.join(path, 'exp_path'),
statistics_per_var=statistics_per_var, station_type="background",
network="AIRBASE", sampling="daily", target_dim="variables",
target_var="o3", time_dim="datetime",
window_history_size=window_history_size,
window_lead_time=window_lead_time, name_affix="train")
return DataCollection([data_prep])
@pytest.fixture
def model(self, window_history_size, window_lead_time, statistics_per_var):
channels = len(list(statistics_per_var.keys()))
return FCN([(window_history_size + 1, 1, channels)], [window_lead_time])
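# CallbackHandler wiring the history, learning rate and epoch timing callbacks plus a
# model checkpoint (best val_loss only) into the experiment path.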
@pytest.fixture
def callbacks(self, path):
clbk = CallbackHandler()
hist = HistoryAdvanced()
epo_timing = EpoTimingCallback()
clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
lr = LearningRateDecay()
clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
save_best_only=True)
return clbk, hist, lr
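# Training object with a compiled model and KerasIterator train/val sets, ready for train().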
@pytest.fixture
def ready_to_train(self, data_collection: DataCollection, init_without_run: Training, batch_path: str):
batch_size = init_without_run.batch_size
model = init_without_run.model
init_without_run.train_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="train")
init_without_run.val_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="val")
init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
return init_without_run
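# Training object whose data store holds real data collections and whose model is compiled,
# so the full _run() can be executed.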
@pytest.fixture
def ready_to_run(self, data_collection, init_without_run):
obj = init_without_run
obj.data_store.set("data_collection", data_collection, "general.train")
obj.data_store.set("data_collection", data_collection, "general.val")
obj.data_store.set("data_collection", data_collection, "general.test")
obj.model.compile(**obj.model.compile_options)
keras.utils.get_custom_objects().update(obj.model.custom_objects)
return obj
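# RunEnvironment whose data store provides everything Training.__init__ expects.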
@pytest.fixture
def ready_to_init(self, data_collection, model, callbacks, path, model_path, batch_path):
if not os.path.exists(path):
os.makedirs(path)
os.makedirs(model_path)
obj = RunEnvironment()
obj.data_store.set("data_collection", data_collection, "general.train")
obj.data_store.set("data_collection", data_collection, "general.val")
obj.data_store.set("data_collection", data_collection, "general.test")
model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
obj.data_store.set("model", model, "general.model")
obj.data_store.set("model_path", model_path, "general")
obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
obj.data_store.set("batch_size", 256, "general")
obj.data_store.set("epochs", 2, "general")
clbk, hist, lr = callbacks
obj.data_store.set("callbacks", clbk, "general.model")
obj.data_store.set("lr_decay", lr, "general.model")
obj.data_store.set("hist", hist, "general.model")
obj.data_store.set("experiment_name", "TestExperiment", "general")
obj.data_store.set("experiment_path", path, "general")
obj.data_store.set("train_model", True, "general")
obj.data_store.set("create_new_model", True, "general")
os.makedirs(batch_path)
obj.data_store.set("batch_path", batch_path, "general")
path_plot = os.path.join(path, "plots")
os.makedirs(path_plot)
obj.data_store.set("plot_path", path_plot, "general")
yield obj
if os.path.exists(path):
shutil.rmtree(path)
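# Builds a fully wired Training object (model, callbacks, data store entries, paths) for the
# requested number of epochs; used by test_resume_training1 to train and then continue training.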
@staticmethod
def create_training_obj(epochs, path, data_collection, batch_path, model_path,
statistics_per_var, window_history_size, window_lead_time) -> Training:
channels = len(list(statistics_per_var.keys()))
model = FCN([(window_history_size + 1, 1, channels)], [window_lead_time])
obj = object.__new__(Training)
super(Training, obj).__init__()
obj.model = model
obj.train_set = None
obj.val_set = None
obj.test_set = None
obj.batch_size = 256
obj.epochs = epochs
clbk = CallbackHandler()
hist = HistoryAdvanced()
epo_timing = EpoTimingCallback()
clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
lr = LearningRateDecay()
clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
save_best_only=True)
obj.callbacks = clbk
obj.lr_sc = lr
obj.hist = hist
obj.experiment_name = "TestExperiment"
obj.data_store.set("data_collection", data_collection, "general.train")
obj.data_store.set("data_collection", data_collection, "general.val")
obj.data_store.set("data_collection", data_collection, "general.test")
if not os.path.exists(path):
os.makedirs(path)
obj.data_store.set("experiment_path", path, "general")
os.makedirs(batch_path, exist_ok=True)
obj.data_store.set("batch_path", batch_path, "general")
os.makedirs(model_path, exist_ok=True)
obj.data_store.set("model_path", model_path, "general")
obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
obj.data_store.set("experiment_name", "TestExperiment", "general")
path_plot = os.path.join(path, "plots")
os.makedirs(path_plot, exist_ok=True)
obj.data_store.set("plot_path", path_plot, "general")
obj._train_model = True
obj._create_new_model = False
obj.model.compile(**obj.model.compile_options)
return obj
def test_init(self, ready_to_init):
assert isinstance(Training(), Training) # just test, if nothing fails
def test_no_training(self, ready_to_init, caplog):
caplog.set_level(logging.INFO)
ready_to_init.data_store.set("train_model", False)
Training()
message = "No training has started, because train_model parameter was false."
assert caplog.record_tuples[-2] == ("root", 20, message)
def test_run(self, ready_to_run):
assert ready_to_run._run() is None # just test, if nothing fails
def test_make_predict_function(self, init_without_run):
assert hasattr(init_without_run.model, "predict_function") is True
assert init_without_run.model.predict_function is None
init_without_run.make_predict_function()
assert isinstance(init_without_run.model.predict_function, Callable)
def test_set_gen(self, init_without_run):
assert init_without_run.train_set is None
init_without_run._set_gen("train")
assert isinstance(init_without_run.train_set, KerasIterator)
assert init_without_run.train_set._collection.return_value == "mock_train_gen"
def test_set_generators(self, init_without_run):
sets = ["train", "val"]
assert all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
init_without_run.set_generators()
assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
assert all(
[getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets])
def test_train(self, ready_to_train, path):
assert ready_to_train.model.history is None
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
ready_to_train.train()
assert sorted(list(ready_to_train.model.history.history.keys())) == ["loss", "val_loss"]
assert ready_to_train.model.history.epoch == [0, 1]
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
def test_save_model(self, init_without_run, model_path, caplog):
caplog.set_level(logging.DEBUG)
model_name = "test_model.h5"
assert model_name not in os.listdir(model_path)
init_without_run.save_model()
message = PyTestRegex(f"save model to {os.path.join(model_path, model_name)}")
assert caplog.record_tuples[1] == ("root", 10, message)
assert model_name in os.listdir(model_path)
def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
assert "history.json" in os.listdir(model_path)
def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
assert "history_lr.json" in os.listdir(model_path)
def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
with open(os.path.join(model_path, "history.json")) as jfile:
hist = json.load(jfile)
assert hist == history.history
def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
with open(os.path.join(model_path, "history_lr.json")) as jfile:
lr = json.load(jfile)
assert lr == learning_rate.lr
def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
history.model.output_names = mock.MagicMock(return_value=["Main"])
history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"])
init_without_run.create_monitoring_plots(history, learning_rate, epoch_best=1)
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
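# Run training for 4 epochs, then build a second Training object with 8 epochs on the same
# paths so training continues from the stored callback and checkpoint state.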
def test_resume_training1(self, path: str, model_path, batch_path, data_collection, statistics_per_var,
window_history_size, window_lead_time):
obj_1st = self.create_training_obj(4, path, data_collection, batch_path, model_path, statistics_per_var,
window_history_size, window_lead_time)
keras.utils.get_custom_objects().update(obj_1st.model.custom_objects)
assert obj_1st._run() is None
obj_2nd = self.create_training_obj(8, path, data_collection, batch_path, model_path, statistics_per_var,
window_history_size, window_lead_time)
assert obj_2nd._run() is None