Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
MLAir
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
MLAir
Commits
aee2fe01
Commit
aee2fe01
authored
5 years ago
by
lukas leufen
Browse files
Options
Downloads
Patches
Plain Diff
implemented CallbackHandler
parent
6bb881f5
No related branches found
No related tags found
2 merge requests
!50
release for v0.7.0
,
!42
implemented CallbackHandler
Pipeline
#29481
passed
5 years ago
Stage: test
Stage: pages
Stage: deploy
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/model_modules/keras_extensions.py
+61
-0
61 additions, 0 deletions
src/model_modules/keras_extensions.py
test/test_model_modules/test_keras_extensions.py
+121
-0
121 additions, 0 deletions
test/test_model_modules/test_keras_extensions.py
with
182 additions
and
0 deletions
src/model_modules/keras_extensions.py
+
61
−
0
View file @
aee2fe01
...
@@ -150,3 +150,64 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
...
@@ -150,3 +150,64 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
if
self
.
verbose
>
0
:
# pragma: no branch
if
self
.
verbose
>
0
:
# pragma: no branch
print
(
'
\n
Epoch %05d: save to %s
'
%
(
epoch
+
1
,
file_path
))
print
(
'
\n
Epoch %05d: save to %s
'
%
(
epoch
+
1
,
file_path
))
pickle
.
dump
(
callback
[
"
callback
"
],
f
)
pickle
.
dump
(
callback
[
"
callback
"
],
f
)
class
CallbackHandler
:
def
__init__
(
self
):
self
.
__callbacks
=
[]
self
.
_checkpoint
=
None
self
.
editable
=
True
@property
def
_callbacks
(
self
):
return
[{
"
callback
"
:
clbk
[
clbk
[
"
name
"
]],
"
path
"
:
clbk
[
"
path
"
]}
for
clbk
in
self
.
__callbacks
]
@_callbacks.setter
def
_callbacks
(
self
,
value
):
name
,
callback
,
callback_path
=
value
self
.
__callbacks
.
append
({
"
name
"
:
name
,
name
:
callback
,
"
path
"
:
callback_path
})
def
_update_callback
(
self
,
pos
,
value
):
name
=
self
.
__callbacks
[
pos
][
"
name
"
]
self
.
__callbacks
[
pos
][
name
]
=
value
def
add_callback
(
self
,
callback
,
callback_path
,
name
=
"
callback
"
):
if
self
.
editable
:
self
.
_callbacks
=
(
name
,
callback
,
callback_path
)
else
:
raise
PermissionError
(
f
"
{
__class__
.
__name__
}
is protected and cannot be edited.
"
)
def
get_callbacks
(
self
,
as_dict
=
True
):
if
as_dict
:
return
self
.
_get_callbacks
()
else
:
return
[
clb
[
"
callback
"
]
for
clb
in
self
.
_get_callbacks
()]
def
get_callback_by_name
(
self
,
obj_name
):
if
obj_name
!=
"
callback
"
:
return
[
clbk
[
clbk
[
"
name
"
]]
for
clbk
in
self
.
__callbacks
if
clbk
[
"
name
"
]
==
obj_name
][
0
]
def
_get_callbacks
(
self
):
clbks
=
self
.
_callbacks
if
self
.
_checkpoint
is
not
None
:
clbks
+=
[{
"
callback
"
:
self
.
_checkpoint
,
"
path
"
:
self
.
_checkpoint
.
filepath
}]
return
clbks
def
get_checkpoint
(
self
):
if
self
.
_checkpoint
is
not
None
:
return
self
.
_checkpoint
def
create_model_checkpoint
(
self
,
**
kwargs
):
self
.
_checkpoint
=
ModelCheckpointAdvanced
(
callbacks
=
self
.
_callbacks
,
**
kwargs
)
self
.
editable
=
False
def
load_callbacks
(
self
):
for
pos
,
callback
in
enumerate
(
self
.
__callbacks
):
path
=
callback
[
"
path
"
]
clb
=
pickle
.
load
(
open
(
path
,
"
rb
"
))
self
.
_update_callback
(
pos
,
clb
)
def
update_checkpoint
(
self
,
history_name
=
"
hist
"
):
self
.
_checkpoint
.
update_callbacks
(
self
.
_callbacks
)
self
.
_checkpoint
.
update_best
(
self
.
get_callback_by_name
(
history_name
))
This diff is collapsed.
Click to expand it.
test/test_model_modules/test_keras_extensions.py
+
121
−
0
View file @
aee2fe01
...
@@ -110,3 +110,124 @@ class TestModelCheckpointAdvanced:
...
@@ -110,3 +110,124 @@ class TestModelCheckpointAdvanced:
assert
"
callback_hist
"
in
os
.
listdir
(
path
)
assert
"
callback_hist
"
in
os
.
listdir
(
path
)
os
.
remove
(
os
.
path
.
join
(
path
,
"
callback_hist
"
))
os
.
remove
(
os
.
path
.
join
(
path
,
"
callback_hist
"
))
os
.
remove
(
os
.
path
.
join
(
path
,
"
callback_lr
"
))
os
.
remove
(
os
.
path
.
join
(
path
,
"
callback_lr
"
))
class
TestCallbackHandler
:
@pytest.fixture
def
clbk_handler
(
self
):
return
CallbackHandler
()
@pytest.fixture
def
clbk_handler_with_dummies
(
self
,
clbk_handler
):
clbk_handler
.
add_callback
(
"
callback_new_instance
"
,
"
this_path
"
)
clbk_handler
.
add_callback
(
"
callback_other
"
,
"
otherpath
"
,
"
other_clbk
"
)
return
clbk_handler
@pytest.fixture
def
callback_handler
(
self
,
clbk_handler
):
clbk_handler
.
add_callback
(
HistoryAdvanced
(),
"
callbacks_hist.pickle
"
,
"
hist
"
)
clbk_handler
.
add_callback
(
LearningRateDecay
(),
"
callbacks_lr.pickle
"
,
"
lr
"
)
return
clbk_handler
@pytest.fixture
def
prepare_pickle_files
(
self
):
hist
=
HistoryAdvanced
()
hist
.
epoch
=
[
1
,
2
,
3
]
hist
.
history
=
{
"
val_loss
"
:
[
10
,
5
,
4
]}
lr
=
LearningRateDecay
()
lr
.
epoch
=
[
1
,
2
,
3
]
pickle
.
dump
(
hist
,
open
(
"
callbacks_hist.pickle
"
,
"
wb
"
))
pickle
.
dump
(
lr
,
open
(
"
callbacks_lr.pickle
"
,
"
wb
"
))
yield
os
.
remove
(
"
callbacks_hist.pickle
"
)
os
.
remove
(
"
callbacks_lr.pickle
"
)
def
test_init
(
self
,
clbk_handler
):
assert
len
(
clbk_handler
.
_CallbackHandler__callbacks
)
==
0
assert
clbk_handler
.
_checkpoint
is
None
assert
clbk_handler
.
editable
is
True
def
test_callbacks_set
(
self
,
clbk_handler
):
clbk_handler
.
_callbacks
=
(
"
default
"
,
"
callback_instance
"
,
"
callback_path
"
)
assert
clbk_handler
.
_CallbackHandler__callbacks
==
[{
"
name
"
:
"
default
"
,
"
default
"
:
"
callback_instance
"
,
"
path
"
:
"
callback_path
"
}]
clbk_handler
.
_callbacks
=
(
"
another
"
,
"
callback_instance2
"
,
"
callback_path
"
)
assert
clbk_handler
.
_CallbackHandler__callbacks
==
[{
"
name
"
:
"
default
"
,
"
default
"
:
"
callback_instance
"
,
"
path
"
:
"
callback_path
"
},
{
"
name
"
:
"
another
"
,
"
another
"
:
"
callback_instance2
"
,
"
path
"
:
"
callback_path
"
}]
def
test_callbacks_get
(
self
,
clbk_handler
):
clbk_handler
.
_callbacks
=
(
"
default
"
,
"
callback_instance
"
,
"
callback_path
"
)
clbk_handler
.
_callbacks
=
(
"
another
"
,
"
callback_instance2
"
,
"
callback_path2
"
)
assert
clbk_handler
.
_callbacks
==
[{
"
callback
"
:
"
callback_instance
"
,
"
path
"
:
"
callback_path
"
},
{
"
callback
"
:
"
callback_instance2
"
,
"
path
"
:
"
callback_path2
"
}]
def
test_update_callback
(
self
,
clbk_handler_with_dummies
):
clbk_handler_with_dummies
.
_update_callback
(
0
,
"
old_instance
"
)
assert
clbk_handler_with_dummies
.
get_callbacks
()
==
[{
"
callback
"
:
"
old_instance
"
,
"
path
"
:
"
this_path
"
},
{
"
callback
"
:
"
callback_other
"
,
"
path
"
:
"
otherpath
"
}]
def
test_add_callback
(
self
,
clbk_handler
):
clbk_handler
.
add_callback
(
"
callback_new_instance
"
,
"
this_path
"
)
assert
clbk_handler
.
_CallbackHandler__callbacks
==
[{
"
name
"
:
"
callback
"
,
"
callback
"
:
"
callback_new_instance
"
,
"
path
"
:
"
this_path
"
}]
clbk_handler
.
add_callback
(
"
callback_other
"
,
"
otherpath
"
,
"
other_clbk
"
)
assert
clbk_handler
.
_CallbackHandler__callbacks
==
[{
"
name
"
:
"
callback
"
,
"
callback
"
:
"
callback_new_instance
"
,
"
path
"
:
"
this_path
"
},
{
"
name
"
:
"
other_clbk
"
,
"
other_clbk
"
:
"
callback_other
"
,
"
path
"
:
"
otherpath
"
}]
def
test_get_callbacks_as_dict
(
self
,
clbk_handler_with_dummies
):
clbk
=
clbk_handler_with_dummies
assert
clbk
.
get_callbacks
()
==
[{
"
callback
"
:
"
callback_new_instance
"
,
"
path
"
:
"
this_path
"
},
{
"
callback
"
:
"
callback_other
"
,
"
path
"
:
"
otherpath
"
}]
assert
clbk
.
get_callbacks
()
==
clbk
.
get_callbacks
(
as_dict
=
True
)
def
test_get_callbacks_no_dict
(
self
,
clbk_handler_with_dummies
):
assert
clbk_handler_with_dummies
.
get_callbacks
(
as_dict
=
False
)
==
[
"
callback_new_instance
"
,
"
callback_other
"
]
def
test_get_callback_by_name
(
self
,
clbk_handler_with_dummies
):
assert
clbk_handler_with_dummies
.
get_callback_by_name
(
"
other_clbk
"
)
==
"
callback_other
"
assert
clbk_handler_with_dummies
.
get_callback_by_name
(
"
callback
"
)
is
None
def
test__get_callbacks
(
self
,
clbk_handler_with_dummies
):
clbk
=
clbk_handler_with_dummies
assert
clbk
.
_get_callbacks
()
==
[{
"
callback
"
:
"
callback_new_instance
"
,
"
path
"
:
"
this_path
"
},
{
"
callback
"
:
"
callback_other
"
,
"
path
"
:
"
otherpath
"
}]
ckpt
=
keras
.
callbacks
.
ModelCheckpoint
(
"
testFilePath
"
)
clbk
.
_checkpoint
=
ckpt
assert
clbk
.
_get_callbacks
()
==
[{
"
callback
"
:
"
callback_new_instance
"
,
"
path
"
:
"
this_path
"
},
{
"
callback
"
:
"
callback_other
"
,
"
path
"
:
"
otherpath
"
},
{
"
callback
"
:
ckpt
,
"
path
"
:
"
testFilePath
"
}]
def
test_get_checkpoint
(
self
,
clbk_handler
):
assert
clbk_handler
.
get_checkpoint
()
is
None
clbk_handler
.
_checkpoint
=
"
testCKPT
"
assert
clbk_handler
.
get_checkpoint
()
==
"
testCKPT
"
def
test_create_model_checkpoint
(
self
,
callback_handler
):
callback_handler
.
create_model_checkpoint
(
filepath
=
"
tester_path
"
,
verbose
=
1
)
assert
callback_handler
.
editable
is
False
assert
isinstance
(
callback_handler
.
_checkpoint
,
ModelCheckpointAdvanced
)
assert
callback_handler
.
_checkpoint
.
filepath
==
"
tester_path
"
assert
callback_handler
.
_checkpoint
.
verbose
==
1
assert
callback_handler
.
_checkpoint
.
monitor
==
"
val_loss
"
def
test_load_callbacks
(
self
,
callback_handler
,
prepare_pickle_files
):
assert
len
(
callback_handler
.
get_callback_by_name
(
"
hist
"
).
epoch
)
==
0
assert
len
(
callback_handler
.
get_callback_by_name
(
"
lr
"
).
epoch
)
==
0
callback_handler
.
load_callbacks
()
assert
len
(
callback_handler
.
get_callback_by_name
(
"
hist
"
).
epoch
)
==
3
assert
len
(
callback_handler
.
get_callback_by_name
(
"
lr
"
).
epoch
)
==
3
def
test_update_checkpoint
(
self
,
callback_handler
,
prepare_pickle_files
):
assert
len
(
callback_handler
.
get_callback_by_name
(
"
hist
"
).
epoch
)
==
0
assert
len
(
callback_handler
.
get_callback_by_name
(
"
lr
"
).
epoch
)
==
0
callback_handler
.
create_model_checkpoint
(
filepath
=
"
tester_path
"
,
verbose
=
1
)
callback_handler
.
load_callbacks
()
callback_handler
.
update_checkpoint
()
assert
len
(
callback_handler
.
get_callback_by_name
(
"
hist
"
).
epoch
)
==
3
assert
len
(
callback_handler
.
get_callback_by_name
(
"
lr
"
).
epoch
)
==
3
assert
callback_handler
.
_checkpoint
.
best
==
4
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment