Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
MLAir
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
MLAir
Merge requests
!29
Lukas issue030 feat continue training
Code
Review changes
Check out branch
Download
Patches
Plain diff
Merged
Lukas issue030 feat continue training
lukas_issue030_feat_continue-training
into
develop
Overview
0
Commits
7
Pipelines
1
Changes
14
Merged
Ghost User
requested to merge
lukas_issue030_feat_continue-training
into
develop
5 years ago
Overview
0
Commits
7
Pipelines
1
Changes
14
Expand
0
0
Merge request reports
Compare
develop
develop (base)
and
latest version
latest version
5accca18
7 commits,
5 years ago
14 files
+
288
−
137
Inline
Compare changes
Side-by-side
Inline
Show whitespace changes
Show one file at a time
Files
14
Search (e.g. *.vue) (Ctrl+P)
src/model_modules/keras_extensions.py
0 → 100644
+
150
−
0
Options
__author__
=
'
Lukas Leufen, Felix Kleinert
'
__date__
=
'
2020-01-31
'
import
logging
import
math
import
pickle
from
typing
import
Union
import
numpy
as
np
from
keras
import
backend
as
K
from
keras.callbacks
import
History
,
ModelCheckpoint
class
HistoryAdvanced
(
History
):
"""
This is almost an identical clone of the original History class. The only difference is that attributes epoch and
history are instantiated during the init phase and not during on_train_begin. This is required to resume an already
started but disrupted training from an saved state. This HistoryAdvanced callback needs to be added separately as
additional callback. To get the full history use this object for further steps instead of the default return of
training methods like fit_generator().
hist = HistoryAdvanced()
history = model.fit_generator(generator=.... , callbacks=[hist])
history = hist
If training was started from beginning this class is identical to the returned history class object.
"""
def
__init__
(
self
):
self
.
epoch
=
[]
self
.
history
=
{}
super
().
__init__
()
def
on_train_begin
(
self
,
logs
=
None
):
pass
class
LearningRateDecay
(
History
):
"""
Decay learning rate during model training. Start with a base learning rate and lower this rate after every
n(=epochs_drop) epochs by drop value (0, 1], drop value = 1 means no decay in learning rate.
"""
def
__init__
(
self
,
base_lr
:
float
=
0.01
,
drop
:
float
=
0.96
,
epochs_drop
:
int
=
8
):
super
().
__init__
()
self
.
lr
=
{
'
lr
'
:
[]}
self
.
base_lr
=
self
.
check_param
(
base_lr
,
'
base_lr
'
)
self
.
drop
=
self
.
check_param
(
drop
,
'
drop
'
)
self
.
epochs_drop
=
self
.
check_param
(
epochs_drop
,
'
epochs_drop
'
,
upper
=
None
)
self
.
epoch
=
[]
self
.
history
=
{}
@staticmethod
def
check_param
(
value
:
float
,
name
:
str
,
lower
:
Union
[
float
,
None
]
=
0
,
upper
:
Union
[
float
,
None
]
=
1
):
"""
Check if given value is in interval. The left (lower) endpoint is open, right (upper) endpoint is closed. To
only one side of the interval, set the other endpoint to None. If both ends are set to None, just return the
value without any check.
:param value: value to check
:param name: name of the variable to display in error message
:param lower: left (lower) endpoint of interval, opened
:param upper: right (upper) endpoint of interval, closed
:return: unchanged value or raise ValueError
"""
if
lower
is
None
:
lower
=
-
np
.
inf
if
upper
is
None
:
upper
=
np
.
inf
if
lower
<
value
<=
upper
:
return
value
else
:
raise
ValueError
(
f
"
{
name
}
is out of allowed range (
{
lower
}
,
{
upper
}{
'
)
'
if
upper
==
np
.
inf
else
'
]
'
}
:
"
f
"
{
name
}
=
{
value
}
"
)
def
on_train_begin
(
self
,
logs
=
None
):
pass
def
on_epoch_begin
(
self
,
epoch
:
int
,
logs
=
None
):
"""
Lower learning rate every epochs_drop epochs by factor drop.
:param epoch: current epoch
:param logs: ?
:return: update keras learning rate
"""
current_lr
=
self
.
base_lr
*
math
.
pow
(
self
.
drop
,
math
.
floor
(
epoch
/
self
.
epochs_drop
))
K
.
set_value
(
self
.
model
.
optimizer
.
lr
,
current_lr
)
self
.
lr
[
'
lr
'
].
append
(
current_lr
)
logging
.
info
(
f
"
Set learning rate to
{
current_lr
}
"
)
return
K
.
get_value
(
self
.
model
.
optimizer
.
lr
)
class
ModelCheckpointAdvanced
(
ModelCheckpoint
):
"""
Enhance the standard ModelCheckpoint class by additional saves of given callbacks. Specify this callbacks as follow:
lr = CustomLearningRate()
hist = CustomHistory()
callbacks_name =
"
your_custom_path_%s.pickle
"
callbacks = [{
"
callback
"
: lr,
"
path
"
: callbacks_name %
"
lr
"
},
{
"
callback
"
: hist,
"
path
"
: callbacks_name %
"
hist
"
}]
ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks)
Add this ckpt_callbacks as all other additional callbacks to the callback list. IMPORTANT: Always add ckpt_callbacks
as last callback to properly update all tracked callbacks, e.g.
fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks])
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
callbacks
=
kwargs
.
pop
(
"
callbacks
"
)
super
().
__init__
(
*
args
,
**
kwargs
)
def
update_best
(
self
,
hist
):
"""
Update internal best on resuming a training process. Otherwise best is set to +/- inf depending on the
performance metric and the first trained model (first of the resuming training process) will always saved as
best model because its performance will be better than infinity. To prevent this behaviour and compare the
performance with the best model performance, call this method before resuming the training process.
:param hist: The History object from the previous (interrupted) training.
"""
self
.
best
=
hist
.
history
.
get
(
self
.
monitor
)[
-
1
]
def
update_callbacks
(
self
,
callbacks
):
"""
Update all stored callback objects. The argument callbacks needs to follow the same convention like described
in the class description (list of dictionaries). Must be run before resuming a training process.
"""
self
.
callbacks
=
callbacks
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
"""
Save model as usual (see ModelCheckpoint class), but also save additional callbacks.
"""
super
().
on_epoch_end
(
epoch
,
logs
)
for
callback
in
self
.
callbacks
:
file_path
=
callback
[
"
path
"
]
if
self
.
epochs_since_last_save
==
0
and
epoch
!=
0
:
if
self
.
save_best_only
:
current
=
logs
.
get
(
self
.
monitor
)
if
current
==
self
.
best
:
if
self
.
verbose
>
0
:
print
(
'
\n
Epoch %05d: save to %s
'
%
(
epoch
+
1
,
file_path
))
with
open
(
file_path
,
"
wb
"
)
as
f
:
pickle
.
dump
(
callback
[
"
callback
"
],
f
)
else
:
with
open
(
file_path
,
"
wb
"
)
as
f
:
if
self
.
verbose
>
0
:
print
(
'
\n
Epoch %05d: save to %s
'
%
(
epoch
+
1
,
file_path
))
pickle
.
dump
(
callback
[
"
callback
"
],
f
)
Loading