Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
MLAir
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
MLAir
Commits
a43bbba0
Commit
a43bbba0
authored
3 years ago
by
leufen1
Browse files
Options
Downloads
Patches
Plain Diff
tests should pass now, at least for training run module
parent
e4796194
No related branches found
No related tags found
5 merge requests
!413
update release branch
,
!412
Resolve "release v2.0.0"
,
!361
name of pdf starts now with feature_importance, there is now also another...
,
!350
Resolve "upgrade code to TensorFlow V2"
,
!335
Resolve "upgrade code to TensorFlow V2"
Pipeline
#82162
failed
3 years ago
Stage: test
Stage: docs
Stage: pages
Stage: deploy
Changes
2
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
mlair/model_modules/keras_extensions.py
+11
-2
11 additions, 2 deletions
mlair/model_modules/keras_extensions.py
test/test_run_modules/test_training.py
+34
-10
34 additions, 10 deletions
test/test_run_modules/test_training.py
with
45 additions
and
12 deletions
mlair/model_modules/keras_extensions.py
+
11
−
2
View file @
a43bbba0
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
__author__
=
'
Lukas Leufen, Felix Kleinert
'
__author__
=
'
Lukas Leufen, Felix Kleinert
'
__date__
=
'
2020-01-31
'
__date__
=
'
2020-01-31
'
import
copy
import
logging
import
logging
import
math
import
math
import
pickle
import
pickle
...
@@ -199,12 +200,18 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
...
@@ -199,12 +200,18 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
if
self
.
verbose
>
0
:
# pragma: no branch
if
self
.
verbose
>
0
:
# pragma: no branch
print
(
'
\n
Epoch %05d: save to %s
'
%
(
epoch
+
1
,
file_path
))
print
(
'
\n
Epoch %05d: save to %s
'
%
(
epoch
+
1
,
file_path
))
with
open
(
file_path
,
"
wb
"
)
as
f
:
with
open
(
file_path
,
"
wb
"
)
as
f
:
pickle
.
dump
(
callback
[
"
callback
"
],
f
)
c
=
copy
.
copy
(
callback
[
"
callback
"
])
if
hasattr
(
c
,
"
model
"
):
c
.
model
=
None
pickle
.
dump
(
c
,
f
)
else
:
else
:
with
open
(
file_path
,
"
wb
"
)
as
f
:
with
open
(
file_path
,
"
wb
"
)
as
f
:
if
self
.
verbose
>
0
:
# pragma: no branch
if
self
.
verbose
>
0
:
# pragma: no branch
print
(
'
\n
Epoch %05d: save to %s
'
%
(
epoch
+
1
,
file_path
))
print
(
'
\n
Epoch %05d: save to %s
'
%
(
epoch
+
1
,
file_path
))
pickle
.
dump
(
callback
[
"
callback
"
],
f
)
c
=
copy
.
copy
(
callback
[
"
callback
"
])
if
hasattr
(
c
,
"
model
"
):
c
.
model
=
None
pickle
.
dump
(
c
,
f
)
clbk_type
=
TypedDict
(
"
clbk_type
"
,
{
"
name
"
:
str
,
str
:
Callback
,
"
path
"
:
str
})
clbk_type
=
TypedDict
(
"
clbk_type
"
,
{
"
name
"
:
str
,
str
:
Callback
,
"
path
"
:
str
})
...
@@ -346,6 +353,8 @@ class CallbackHandler:
...
@@ -346,6 +353,8 @@ class CallbackHandler:
for
pos
,
callback
in
enumerate
(
self
.
__callbacks
):
for
pos
,
callback
in
enumerate
(
self
.
__callbacks
):
path
=
callback
[
"
path
"
]
path
=
callback
[
"
path
"
]
clb
=
pickle
.
load
(
open
(
path
,
"
rb
"
))
clb
=
pickle
.
load
(
open
(
path
,
"
rb
"
))
if
clb
.
model
is
None
:
clb
.
model
=
self
.
_checkpoint
.
model
self
.
_update_callback
(
pos
,
clb
)
self
.
_update_callback
(
pos
,
clb
)
def
update_checkpoint
(
self
,
history_name
:
str
=
"
hist
"
)
->
None
:
def
update_checkpoint
(
self
,
history_name
:
str
=
"
hist
"
)
->
None
:
...
...
This diff is collapsed.
Click to expand it.
test/test_run_modules/test_training.py
+
34
−
10
View file @
a43bbba0
import
copy
import
glob
import
glob
import
json
import
json
import
logging
import
logging
import
os
import
os
import
shutil
import
shutil
from
typing
import
Callable
import
tensorflow.keras
as
keras
import
tensorflow.keras
as
keras
import
mock
import
mock
...
@@ -76,10 +78,24 @@ class TestTraining:
...
@@ -76,10 +78,24 @@ class TestTraining:
obj
.
data_store
.
set
(
"
plot_path
"
,
path_plot
,
"
general
"
)
obj
.
data_store
.
set
(
"
plot_path
"
,
path_plot
,
"
general
"
)
obj
.
_train_model
=
True
obj
.
_train_model
=
True
obj
.
_create_new_model
=
False
obj
.
_create_new_model
=
False
try
:
yield
obj
yield
obj
finally
:
if
os
.
path
.
exists
(
path
):
if
os
.
path
.
exists
(
path
):
shutil
.
rmtree
(
path
)
shutil
.
rmtree
(
path
)
try
:
RunEnvironment
().
__del__
()
RunEnvironment
().
__del__
()
except
AssertionError
:
pass
# try:
# yield obj
# finally:
# if os.path.exists(path):
# shutil.rmtree(path)
# try:
# RunEnvironment().__del__()
# except AssertionError:
# pass
@pytest.fixture
@pytest.fixture
def
learning_rate
(
self
):
def
learning_rate
(
self
):
...
@@ -223,9 +239,10 @@ class TestTraining:
...
@@ -223,9 +239,10 @@ class TestTraining:
assert
ready_to_run
.
_run
()
is
None
# just test, if nothing fails
assert
ready_to_run
.
_run
()
is
None
# just test, if nothing fails
def
test_make_predict_function
(
self
,
init_without_run
):
def
test_make_predict_function
(
self
,
init_without_run
):
assert
hasattr
(
init_without_run
.
model
,
"
predict_function
"
)
is
False
assert
hasattr
(
init_without_run
.
model
,
"
predict_function
"
)
is
True
assert
init_without_run
.
model
.
predict_function
is
None
init_without_run
.
make_predict_function
()
init_without_run
.
make_predict_function
()
assert
hasattr
(
init_without_run
.
model
,
"
predict_function
"
)
assert
isinstance
(
init_without_run
.
model
.
predict_function
,
Callable
)
def
test_set_gen
(
self
,
init_without_run
):
def
test_set_gen
(
self
,
init_without_run
):
assert
init_without_run
.
train_set
is
None
assert
init_without_run
.
train_set
is
None
...
@@ -242,10 +259,10 @@ class TestTraining:
...
@@ -242,10 +259,10 @@ class TestTraining:
[
getattr
(
init_without_run
,
f
"
{
obj
}
_set
"
).
_collection
.
return_value
==
f
"
mock_
{
obj
}
_gen
"
for
obj
in
sets
])
[
getattr
(
init_without_run
,
f
"
{
obj
}
_set
"
).
_collection
.
return_value
==
f
"
mock_
{
obj
}
_gen
"
for
obj
in
sets
])
def
test_train
(
self
,
ready_to_train
,
path
):
def
test_train
(
self
,
ready_to_train
,
path
):
assert
not
hasattr
(
ready_to_train
.
model
,
"
history
"
)
assert
ready_to_train
.
model
.
history
is
None
assert
len
(
glob
.
glob
(
os
.
path
.
join
(
path
,
"
plots
"
,
"
TestExperiment_history_*.pdf
"
)))
==
0
assert
len
(
glob
.
glob
(
os
.
path
.
join
(
path
,
"
plots
"
,
"
TestExperiment_history_*.pdf
"
)))
==
0
ready_to_train
.
train
()
ready_to_train
.
train
()
assert
list
(
ready_to_train
.
model
.
history
.
history
.
keys
())
==
[
"
val_
loss
"
,
"
loss
"
]
assert
sorted
(
list
(
ready_to_train
.
model
.
history
.
history
.
keys
())
)
==
[
"
loss
"
,
"
val_
loss
"
]
assert
ready_to_train
.
model
.
history
.
epoch
==
[
0
,
1
]
assert
ready_to_train
.
model
.
history
.
epoch
==
[
0
,
1
]
assert
len
(
glob
.
glob
(
os
.
path
.
join
(
path
,
"
plots
"
,
"
TestExperiment_history_*.pdf
"
)))
==
2
assert
len
(
glob
.
glob
(
os
.
path
.
join
(
path
,
"
plots
"
,
"
TestExperiment_history_*.pdf
"
)))
==
2
...
@@ -260,8 +277,8 @@ class TestTraining:
...
@@ -260,8 +277,8 @@ class TestTraining:
def
test_load_best_model_no_weights
(
self
,
init_without_run
,
caplog
):
def
test_load_best_model_no_weights
(
self
,
init_without_run
,
caplog
):
caplog
.
set_level
(
logging
.
DEBUG
)
caplog
.
set_level
(
logging
.
DEBUG
)
init_without_run
.
load_best_model
(
"
notExisting
"
)
init_without_run
.
load_best_model
(
"
notExisting
.h5
"
)
assert
caplog
.
record_tuples
[
0
]
==
(
"
root
"
,
10
,
PyTestRegex
(
"
load best model: notExisting
"
))
assert
caplog
.
record_tuples
[
0
]
==
(
"
root
"
,
10
,
PyTestRegex
(
"
load best model: notExisting
.h5
"
))
assert
caplog
.
record_tuples
[
1
]
==
(
"
root
"
,
20
,
PyTestRegex
(
"
no weights to reload...
"
))
assert
caplog
.
record_tuples
[
1
]
==
(
"
root
"
,
20
,
PyTestRegex
(
"
no weights to reload...
"
))
def
test_save_callbacks_history_created
(
self
,
init_without_run
,
history
,
learning_rate
,
epo_timing
,
model_path
):
def
test_save_callbacks_history_created
(
self
,
init_without_run
,
history
,
learning_rate
,
epo_timing
,
model_path
):
...
@@ -290,3 +307,10 @@ class TestTraining:
...
@@ -290,3 +307,10 @@ class TestTraining:
history
.
model
.
metrics_names
=
mock
.
MagicMock
(
return_value
=
[
"
loss
"
,
"
mean_squared_error
"
])
history
.
model
.
metrics_names
=
mock
.
MagicMock
(
return_value
=
[
"
loss
"
,
"
mean_squared_error
"
])
init_without_run
.
create_monitoring_plots
(
history
,
learning_rate
)
init_without_run
.
create_monitoring_plots
(
history
,
learning_rate
)
assert
len
(
glob
.
glob
(
os
.
path
.
join
(
path
,
"
plots
"
,
"
TestExperiment_history_*.pdf
"
)))
==
2
assert
len
(
glob
.
glob
(
os
.
path
.
join
(
path
,
"
plots
"
,
"
TestExperiment_history_*.pdf
"
)))
==
2
def
test_resume_training
(
self
,
ready_to_run
):
with
copy
.
copy
(
ready_to_run
)
as
pre_run
:
assert
pre_run
.
_run
()
is
None
# rune once to create model
ready_to_run
.
epochs
=
4
# continue train up to epoch 4
assert
ready_to_run
.
_run
()
is
None
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment