Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
MLAir
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
MLAir
Commits
5accca18
Commit
5accca18
authored
5 years ago
by
lukas leufen
Browse files
Options
Downloads
Patches
Plain Diff
a little bit more docs and additional log if training is resumed
parent
09f7e2d2
No related branches found
No related tags found
2 merge requests
!37
include new development
,
!29
Lukas issue030 feat continue training
Pipeline
#28883
passed
5 years ago
Stage: test
Stage: pages
Stage: deploy
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/run_modules/training.py
+6
-3
6 additions, 3 deletions
src/run_modules/training.py
with
6 additions
and
3 deletions
src/run_modules/training.py
+
6
−
3
View file @
5accca18
...
...
@@ -24,7 +24,6 @@ class Training(RunEnvironment):
self
.
batch_size
=
self
.
data_store
.
get
(
"
batch_size
"
,
"
general.model
"
)
self
.
epochs
=
self
.
data_store
.
get
(
"
epochs
"
,
"
general.model
"
)
self
.
checkpoint
:
ModelCheckpointAdvanced
=
self
.
data_store
.
get
(
"
checkpoint
"
,
"
general.model
"
)
# self.callbacks = self.data_store.get("callbacks", "general.model")
self
.
lr_sc
=
self
.
data_store
.
get
(
"
lr_decay
"
,
"
general.model
"
)
self
.
hist
=
self
.
data_store
.
get
(
"
hist
"
,
"
general.model
"
)
self
.
experiment_name
=
self
.
data_store
.
get
(
"
experiment_name
"
,
"
general
"
)
...
...
@@ -38,7 +37,7 @@ class Training(RunEnvironment):
2) make_predict_function():
create predict function before distribution on multiple nodes (detailed information in method description)
3) train():
train
model and save callbacks
start or resume training of
model and save callbacks
4) save_model():
save best model from training as final model
"""
...
...
@@ -76,7 +75,10 @@ class Training(RunEnvironment):
def
train
(
self
)
->
None
:
"""
Perform training using keras fit_generator(). Callbacks are stored locally in the experiment directory. Best
model from training is saved for class variable model.
model from training is saved for class variable model. If the file path of checkpoint is not empty, this method
assumes, that this is not a new training starting from the very beginning, but a resumption from a previous
started but interrupted training (or a stopped and now continued training). Train will automatically load the
locally stored information and the corresponding model and proceed with the already started training.
"""
logging
.
info
(
f
"
Train with
{
len
(
self
.
train_set
)
}
mini batches.
"
)
if
not
os
.
path
.
exists
(
self
.
checkpoint
.
filepath
):
...
...
@@ -88,6 +90,7 @@ class Training(RunEnvironment):
validation_steps
=
len
(
self
.
val_set
),
callbacks
=
[
self
.
lr_sc
,
self
.
hist
,
self
.
checkpoint
])
else
:
logging
.
info
(
"
Found locally stored model and checkpoints. Training is resumed from the last checkpoint.
"
)
lr_filepath
=
self
.
checkpoint
.
callbacks
[
0
][
"
path
"
]
hist_filepath
=
self
.
checkpoint
.
callbacks
[
1
][
"
path
"
]
self
.
lr_sc
=
pickle
.
load
(
open
(
lr_filepath
,
"
rb
"
))
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment