Project: esde / machine-learning / AMBS

Commit ebded8a4 authored 3 years ago by BING GONG
Create checkpoint folders and save checkpoint at certain step
Parent: e0de9e0b
Pipeline #76838 passed 3 years ago · stages: build, test, deploy
Showing 1 changed file: video_prediction_tools/main_scripts/main_train_models.py (+53 additions, −11 deletions)
@@ -24,12 +24,13 @@ import matplotlib.pyplot as plt
 import pickle as pkl
 from model_modules.video_prediction.utils import tf_utils
 from general_utils import *
+import math


 class TrainModel(object):
     def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None,
                  model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None,
-                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01):
+                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01, frac_save_model_start: float = None, prob_save_model: float = None):
         """
         Class instance for training the models
         :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located
@@ -46,6 +47,8 @@ class TrainModel(object):
                                steps per epoch is denoted here, e.g. 0.01 with 1000 iteration steps per epoch results
                                into a diagnozing intreval of 10 iteration steps (= interval over which validation loss
                                is averaged to identify best model performance)
+        :param frac_save_model_start: fraction of total iterations steps as the start point to save checkpoints
+        :param prob_save_model: probabability that model are saved to checkpoint (control the frequences of saving model0)
         """
         self.input_dir = os.path.normpath(input_dir)
         self.output_dir = os.path.normpath(output_dir)
@@ -58,6 +61,8 @@ class TrainModel(object):
         self.seed = seed
         self.args = args
         self.diag_intv_frac = diag_intv_frac
+        self.frac_save_model_start = frac_save_model_start
+        self.prob_save_model = prob_save_model
         # for diagnozing and saving the model during training
         self.saver_loss = None  # set in create_fetches_for_train-method
         self.saver_loss_name = None  # set in create_fetches_for_train-method
@@ -77,6 +82,8 @@ class TrainModel(object):
         self.create_saver_and_writer()
         self.setup_gpu_config()
         self.calculate_samples_and_epochs()
+        self.calculate_checkpoint_saver_conf()

     def set_seed(self):
         """
@@ -240,6 +247,25 @@ class TrainModel(object):
         print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}"
               .format(method, batch_size, max_epochs, self.num_examples, self.steps_per_epoch, self.total_steps))

+    def calculate_checkpoint_saver_conf(self):
+        """
+        Calculate the start step for saving the checkpoint, and the frequences steps to save model
+        """
+        method = TrainModel.calculate_checkpoint_saver_conf.__name__
+
+        if hasattr(self.total_steps, "attr_name"):
+            raise SyntaxError("function 'calculate_sample_and_epochs' is required to call to calcualte the total_step before all function {}".format(method))
+        if self.prob_save_model > 1 or self.prob_save_model < 0:
+            raise ValueError("pro_save_model should be less than 1 and larger than 0")
+        if self.frac_save_model_start > 1 or self.frac_save_model_start < 0:
+            raise ValueError("frac_save_model_start should be less than 1 and larger than 0")
+
+        self.start_checkpoint_step = int(math.ceil(self.total_steps * self.frac_save_model_start))
+        self.saver_interval_step = int(math.ceil(self.total_steps * self.prob_save_model))
+
     def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
         """
         Restore the models checkpoints if the checkpoints is given
@@ -283,6 +309,22 @@ class TrainModel(object):
             val_losses = pkl.load(f)

         return train_losses, val_losses

+    def create_checkpoints_folder(self, step: int = None):
+        """
+        Create a folder to store checkpoint at certain step
+        :param step: the step you want to save the checkpoint
+        return : dir path to save model
+        """
+        dir_name = "checkpoint_" + str(step)
+        full_dir_name = os.path.join(self.output_dir, dir_name)
+        if os.path.isfile(os.path.join(full_dir_name, "checkpoints")):
+            print("The checkpoint at step {} exists".format(step))
+        else:
+            os.mkdir(full_dir_name)
+        return full_dir_name
+
     def train_model(self):
         """
         Start session and train the model by looping over all iteration steps
@@ -303,7 +345,7 @@ class TrainModel(object):
         # initialize auxiliary variables
         time_per_iteration = []
         run_start_time = time.time()
         val_loss_min = 999.

         # perform iteration
         for step in range(start_step, self.total_steps):
             timeit_start = time.time()
@@ -324,12 +366,12 @@ class TrainModel(object):
             time_iter = time.time() - timeit_start
             time_per_iteration.append(time_iter)
             print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter))
-            if step > self.diag_intv_step and (step % self.diag_intv_step == 0 or step == self.total_steps - 1):
-                lsave, val_loss_min = TrainModel.set_model_saver_flag(val_losses, val_loss_min, self.diag_intv_step)
-                # save best and final model state
-                if lsave or step == self.total_steps - 1:
-                    self.saver.save(sess, os.path.join(self.output_dir, "model_best" if lsave else "model_last"),
-                                    global_step=step)
+            if step > self.start_checkpoint_step and (step % self.saver_interval_step == 0 or step == self.total_steps - 1):
+                # create a checkpoint folder for step
+                full_dir_name = self.create_checkpoints_folder(step=step)
+                self.saver.save(sess, os.path.join(full_dir_name, "model_"),
+                                global_step=step)
                 # pickle file and plots are always created
                 TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
                 TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)

[remaining hunks of this diff were collapsed in the original view and are not shown]
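The two new constructor arguments only take effect once calculate_checkpoint_saver_conf() turns them into a start step and a saving interval, and create_checkpoints_folder() turns each qualifying step into its own folder. Below is a minimal standalone sketch of that arithmetic under assumed example values; the helper name checkpoint_schedule, the concrete numbers and the output path are illustrative and not part of this commit.

# Sketch only: reproduces the scheduling arithmetic of calculate_checkpoint_saver_conf
# and the folder naming of create_checkpoints_folder. All concrete numbers and the
# output_dir path are assumed example values, not defaults taken from the repository.
import math
import os


def checkpoint_schedule(total_steps: int, frac_save_model_start: float, prob_save_model: float):
    """Return (first step at which checkpoints are written, interval between checkpoints)."""
    start_checkpoint_step = int(math.ceil(total_steps * frac_save_model_start))
    saver_interval_step = int(math.ceil(total_steps * prob_save_model))
    return start_checkpoint_step, saver_interval_step


if __name__ == "__main__":
    total_steps = 1000                                   # e.g. steps_per_epoch * max_epochs
    start, interval = checkpoint_schedule(total_steps, frac_save_model_start=0.6, prob_save_model=0.01)
    print(start, interval)                               # -> 600 10: save from step 600 on, every 10 steps

    # Steps that would trigger a save in train_model, and the folders they would land in
    output_dir = "/tmp/ambs_run"                         # hypothetical output directory
    save_steps = [s for s in range(total_steps)
                  if s > start and (s % interval == 0 or s == total_steps - 1)]
    print([os.path.join(output_dir, "checkpoint_" + str(s)) for s in save_steps[:3]])

With these example values the first checkpoint folder would be written shortly after step 600 and a new checkpoint_<step> folder roughly every 10 steps, plus one at the final step.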