Skip to content
Snippets Groups Projects
Commit a43bbba0 authored by leufen1's avatar leufen1
Browse files

tests should pass now, at least for training run module

parent e4796194
No related branches found
No related tags found
5 merge requests!413update release branch,!412Resolve "release v2.0.0",!361name of pdf starts now with feature_importance, there is now also another...,!350Resolve "upgrade code to TensorFlow V2",!335Resolve "upgrade code to TensorFlow V2"
Pipeline #82162 failed
......@@ -3,6 +3,7 @@
__author__ = 'Lukas Leufen, Felix Kleinert'
__date__ = '2020-01-31'
import copy
import logging
import math
import pickle
......@@ -199,12 +200,18 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
with open(file_path, "wb") as f:
pickle.dump(callback["callback"], f)
c = copy.copy(callback["callback"])
if hasattr(c, "model"):
c.model = None
pickle.dump(c, f)
else:
with open(file_path, "wb") as f:
if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
pickle.dump(callback["callback"], f)
c = copy.copy(callback["callback"])
if hasattr(c, "model"):
c.model = None
pickle.dump(c, f)
clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str})
......@@ -346,6 +353,8 @@ class CallbackHandler:
for pos, callback in enumerate(self.__callbacks):
path = callback["path"]
clb = pickle.load(open(path, "rb"))
if clb.model is None:
clb.model = self._checkpoint.model
self._update_callback(pos, clb)
def update_checkpoint(self, history_name: str = "hist") -> None:
......
import copy
import glob
import json
import logging
import os
import shutil
from typing import Callable
import tensorflow.keras as keras
import mock
......@@ -76,10 +78,24 @@ class TestTraining:
obj.data_store.set("plot_path", path_plot, "general")
obj._train_model = True
obj._create_new_model = False
try:
yield obj
finally:
if os.path.exists(path):
shutil.rmtree(path)
try:
RunEnvironment().__del__()
except AssertionError:
pass
# try:
# yield obj
# finally:
# if os.path.exists(path):
# shutil.rmtree(path)
# try:
# RunEnvironment().__del__()
# except AssertionError:
# pass
@pytest.fixture
def learning_rate(self):
......@@ -223,9 +239,10 @@ class TestTraining:
assert ready_to_run._run() is None # just test, if nothing fails
def test_make_predict_function(self, init_without_run):
assert hasattr(init_without_run.model, "predict_function") is False
assert hasattr(init_without_run.model, "predict_function") is True
assert init_without_run.model.predict_function is None
init_without_run.make_predict_function()
assert hasattr(init_without_run.model, "predict_function")
assert isinstance(init_without_run.model.predict_function, Callable)
def test_set_gen(self, init_without_run):
assert init_without_run.train_set is None
......@@ -242,10 +259,10 @@ class TestTraining:
[getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets])
def test_train(self, ready_to_train, path):
assert not hasattr(ready_to_train.model, "history")
assert ready_to_train.model.history is None
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
ready_to_train.train()
assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"]
assert sorted(list(ready_to_train.model.history.history.keys())) == ["loss", "val_loss"]
assert ready_to_train.model.history.epoch == [0, 1]
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
......@@ -260,8 +277,8 @@ class TestTraining:
def test_load_best_model_no_weights(self, init_without_run, caplog):
caplog.set_level(logging.DEBUG)
init_without_run.load_best_model("notExisting")
assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
init_without_run.load_best_model("notExisting.h5")
assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5"))
assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
......@@ -290,3 +307,10 @@ class TestTraining:
history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"])
init_without_run.create_monitoring_plots(history, learning_rate)
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
def test_resume_training(self, ready_to_run):
with copy.copy(ready_to_run) as pre_run:
assert pre_run._run() is None # rune once to create model
ready_to_run.epochs = 4 # continue train up to epoch 4
assert ready_to_run._run() is None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment