Commit a43bbba0 authored by leufen1

tests should pass now, at least for training run module

parent e4796194
Pipeline #82162 failed
@@ -3,6 +3,7 @@
 __author__ = 'Lukas Leufen, Felix Kleinert'
 __date__ = '2020-01-31'
 
+import copy
 import logging
 import math
 import pickle
@@ -199,12 +200,18 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                     if self.verbose > 0:  # pragma: no branch
                         print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                     with open(file_path, "wb") as f:
-                        pickle.dump(callback["callback"], f)
+                        c = copy.copy(callback["callback"])
+                        if hasattr(c, "model"):
+                            c.model = None
+                        pickle.dump(c, f)
             else:
                 with open(file_path, "wb") as f:
                     if self.verbose > 0:  # pragma: no branch
                         print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
-                    pickle.dump(callback["callback"], f)
+                    c = copy.copy(callback["callback"])
+                    if hasattr(c, "model"):
+                        c.model = None
+                    pickle.dump(c, f)
 
 
 clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str})
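For context on the change above: tf.keras models generally cannot be pickled, so the callback is shallow-copied and its model reference is dropped before dumping. A minimal standalone sketch of that pattern (function and file names are illustrative, not part of this commit):

import copy
import pickle

import tensorflow as tf


def dump_callback(callback: tf.keras.callbacks.Callback, file_path: str) -> None:
    # Shallow-copy so the live callback keeps its model attached; only the copy
    # is stripped of the (unpicklable) model reference before writing to disk.
    c = copy.copy(callback)
    if hasattr(c, "model"):
        c.model = None
    with open(file_path, "wb") as f:
        pickle.dump(c, f)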
@@ -346,6 +353,8 @@ class CallbackHandler:
         for pos, callback in enumerate(self.__callbacks):
             path = callback["path"]
             clb = pickle.load(open(path, "rb"))
+            if clb.model is None:
+                clb.model = self._checkpoint.model
             self._update_callback(pos, clb)
 
     def update_checkpoint(self, history_name: str = "hist") -> None:
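The counterpart on load is sketched below: because the model was set to None before pickling, it has to be reattached from the current checkpoint after unpickling. This is a generic sketch assuming plain pickle and tf.keras; `load_callback` is an illustrative name, not this project's API.

import pickle

import tensorflow as tf


def load_callback(file_path: str, model: tf.keras.Model) -> tf.keras.callbacks.Callback:
    with open(file_path, "rb") as f:
        clb = pickle.load(f)
    if getattr(clb, "model", None) is None:
        clb.model = model  # reattach the live model, mirroring the hunk above
    return clb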
...
+import copy
 import glob
 import json
 import logging
 import os
 import shutil
+from typing import Callable
 
 import tensorflow.keras as keras
 import mock
@@ -76,10 +78,24 @@ class TestTraining:
         obj.data_store.set("plot_path", path_plot, "general")
         obj._train_model = True
         obj._create_new_model = False
+        try:
             yield obj
+        finally:
             if os.path.exists(path):
                 shutil.rmtree(path)
+            try:
                 RunEnvironment().__del__()
+            except AssertionError:
+                pass
+        # try:
+        #     yield obj
+        # finally:
+        #     if os.path.exists(path):
+        #         shutil.rmtree(path)
+        #     try:
+        #         RunEnvironment().__del__()
+        #     except AssertionError:
+        #         pass
 
     @pytest.fixture
     def learning_rate(self):
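The fixture change above wraps `yield obj` in try/finally so teardown always runs, and swallows the AssertionError that RunEnvironment().__del__() can raise when the environment is already closed. A generic sketch of that pytest pattern (fixture name and paths are illustrative):

import os
import shutil

import pytest


@pytest.fixture
def experiment_dir(tmp_path):
    path = str(tmp_path / "experiment")
    os.makedirs(path, exist_ok=True)
    try:
        yield path  # test body runs here
    finally:
        # teardown runs even if the test body raised
        if os.path.exists(path):
            shutil.rmtree(path)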
@@ -223,9 +239,10 @@ class TestTraining:
         assert ready_to_run._run() is None  # just test, if nothing fails
 
     def test_make_predict_function(self, init_without_run):
-        assert hasattr(init_without_run.model, "predict_function") is False
+        assert hasattr(init_without_run.model, "predict_function") is True
+        assert init_without_run.model.predict_function is None
         init_without_run.make_predict_function()
-        assert hasattr(init_without_run.model, "predict_function")
+        assert isinstance(init_without_run.model.predict_function, Callable)
 
     def test_set_gen(self, init_without_run):
         assert init_without_run.train_set is None
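The updated assertions reflect tf.keras behaviour in TF 2.x: `predict_function` exists as a model attribute and stays None until `make_predict_function()` builds it. A quick standalone check of that assumption (the model definition is illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
assert model.predict_function is None  # attribute exists but is not built yet
fn = model.make_predict_function()     # builds the function used by predict()
assert callable(fn)
assert model.predict_function is not None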
@@ -242,10 +259,10 @@ class TestTraining:
             [getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets])
 
     def test_train(self, ready_to_train, path):
-        assert not hasattr(ready_to_train.model, "history")
+        assert ready_to_train.model.history is None
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
         ready_to_train.train()
-        assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"]
+        assert sorted(list(ready_to_train.model.history.history.keys())) == ["loss", "val_loss"]
         assert ready_to_train.model.history.epoch == [0, 1]
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
@@ -260,8 +277,8 @@ class TestTraining:
     def test_load_best_model_no_weights(self, init_without_run, caplog):
         caplog.set_level(logging.DEBUG)
-        init_without_run.load_best_model("notExisting")
-        assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
+        init_without_run.load_best_model("notExisting.h5")
+        assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5"))
         assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
 
     def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
@@ -290,3 +307,10 @@ class TestTraining:
         history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"])
         init_without_run.create_monitoring_plots(history, learning_rate)
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
+
+    def test_resume_training(self, ready_to_run):
+        with copy.copy(ready_to_run) as pre_run:
+            assert pre_run._run() is None  # run once to create model
+
+        ready_to_run.epochs = 4  # continue training up to epoch 4
+        assert ready_to_run._run() is None
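The new test runs the training stage twice with an increased epoch count to exercise resuming from a previously saved state. A minimal sketch of the underlying idea in plain tf.keras (not this project's API; function and parameter names are illustrative):

import tensorflow as tf


def continue_training(model_path, x, y, total_epochs, done_epochs):
    # Reload the previously saved model and continue where training stopped.
    model = tf.keras.models.load_model(model_path)
    return model.fit(x, y, epochs=total_epochs, initial_epoch=done_epochs)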