Skip to content
Snippets Groups Projects
Commit a43bbba0 authored by leufen1's avatar leufen1
Browse files

tests should pass now, at least for training run module

parent e4796194
Branches
Tags
5 merge requests!413update release branch,!412Resolve "release v2.0.0",!361name of pdf starts now with feature_importance, there is now also another...,!350Resolve "upgrade code to TensorFlow V2",!335Resolve "upgrade code to TensorFlow V2"
Pipeline #82162 failed
......@@ -3,6 +3,7 @@
__author__ = 'Lukas Leufen, Felix Kleinert'
__date__ = '2020-01-31'
import copy
import logging
import math
import pickle
......@@ -199,12 +200,18 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
with open(file_path, "wb") as f:
pickle.dump(callback["callback"], f)
c = copy.copy(callback["callback"])
if hasattr(c, "model"):
c.model = None
pickle.dump(c, f)
else:
with open(file_path, "wb") as f:
if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
pickle.dump(callback["callback"], f)
c = copy.copy(callback["callback"])
if hasattr(c, "model"):
c.model = None
pickle.dump(c, f)
clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str})
......@@ -346,6 +353,8 @@ class CallbackHandler:
for pos, callback in enumerate(self.__callbacks):
path = callback["path"]
clb = pickle.load(open(path, "rb"))
if clb.model is None:
clb.model = self._checkpoint.model
self._update_callback(pos, clb)
def update_checkpoint(self, history_name: str = "hist") -> None:
......
import copy
import glob
import json
import logging
import os
import shutil
from typing import Callable
import tensorflow.keras as keras
import mock
......@@ -76,10 +78,24 @@ class TestTraining:
obj.data_store.set("plot_path", path_plot, "general")
obj._train_model = True
obj._create_new_model = False
try:
yield obj
finally:
if os.path.exists(path):
shutil.rmtree(path)
try:
RunEnvironment().__del__()
except AssertionError:
pass
# try:
# yield obj
# finally:
# if os.path.exists(path):
# shutil.rmtree(path)
# try:
# RunEnvironment().__del__()
# except AssertionError:
# pass
@pytest.fixture
def learning_rate(self):
......@@ -223,9 +239,10 @@ class TestTraining:
assert ready_to_run._run() is None # just test, if nothing fails
def test_make_predict_function(self, init_without_run):
assert hasattr(init_without_run.model, "predict_function") is False
assert hasattr(init_without_run.model, "predict_function") is True
assert init_without_run.model.predict_function is None
init_without_run.make_predict_function()
assert hasattr(init_without_run.model, "predict_function")
assert isinstance(init_without_run.model.predict_function, Callable)
def test_set_gen(self, init_without_run):
assert init_without_run.train_set is None
......@@ -242,10 +259,10 @@ class TestTraining:
[getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets])
def test_train(self, ready_to_train, path):
assert not hasattr(ready_to_train.model, "history")
assert ready_to_train.model.history is None
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
ready_to_train.train()
assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"]
assert sorted(list(ready_to_train.model.history.history.keys())) == ["loss", "val_loss"]
assert ready_to_train.model.history.epoch == [0, 1]
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
......@@ -260,8 +277,8 @@ class TestTraining:
def test_load_best_model_no_weights(self, init_without_run, caplog):
caplog.set_level(logging.DEBUG)
init_without_run.load_best_model("notExisting")
assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
init_without_run.load_best_model("notExisting.h5")
assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5"))
assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
......@@ -290,3 +307,10 @@ class TestTraining:
history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"])
init_without_run.create_monitoring_plots(history, learning_rate)
assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
def test_resume_training(self, ready_to_run):
with copy.copy(ready_to_run) as pre_run:
assert pre_run._run() is None # rune once to create model
ready_to_run.epochs = 4 # continue train up to epoch 4
assert ready_to_run._run() is None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment