Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
AMBS
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
AMBS
Commits
b80c3305
Commit
b80c3305
authored
Feb 28, 2022
by
Michael Langguth
Browse files
Options
Downloads
Patches
Plain Diff
Adapt postprocessing and training to allow for bootstrapping on tiny datasets.
parent
c8adb6bc
No related branches found
No related tags found
No related merge requests found
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
video_prediction_tools/main_scripts/main_train_models.py
+8
-3
8 additions, 3 deletions
video_prediction_tools/main_scripts/main_train_models.py
video_prediction_tools/main_scripts/main_visualize_postprocess.py
+10
-5
10 additions, 5 deletions
...ediction_tools/main_scripts/main_visualize_postprocess.py
with
18 additions
and
8 deletions
video_prediction_tools/main_scripts/main_train_models.py
+
8
−
3
View file @
b80c3305
...
@@ -562,7 +562,8 @@ class BestModelSelector(object):
...
@@ -562,7 +562,8 @@ class BestModelSelector(object):
Class to select the best performing model from multiple checkpoints created during training
Class to select the best performing model from multiple checkpoints created during training
"""
"""
def
__init__
(
self
,
model_dir
:
str
,
eval_metric
:
str
,
criterion
:
str
=
"
min
"
,
channel
:
int
=
0
,
seed
:
int
=
42
):
def
__init__
(
self
,
model_dir
:
str
,
eval_metric
:
str
,
ltest
:
bool
,
criterion
:
str
=
"
min
"
,
channel
:
int
=
0
,
seed
:
int
=
42
):
"""
"""
Class to retrieve the best model checkpoint. The last one is also retained.
Class to retrieve the best model checkpoint. The last one is also retained.
:param model_dir: path to directory where checkpoints are saved (the trained model output directory)
:param model_dir: path to directory where checkpoints are saved (the trained model output directory)
...
@@ -570,6 +571,7 @@ class BestModelSelector(object):
...
@@ -570,6 +571,7 @@ class BestModelSelector(object):
:param criterion: set to
'
min
'
(
'
max
'
) for negatively (positively) oriented metrics
:param criterion: set to
'
min
'
(
'
max
'
) for negatively (positively) oriented metrics
:param channel: channel of data used for selection
:param channel: channel of data used for selection
:param seed: seed for the Postprocess-instance
:param seed: seed for the Postprocess-instance
:param ltest: flag to allow bootstrapping in Postprocessing on tiny datasets
"""
"""
method
=
self
.
__class__
.
__name__
method
=
self
.
__class__
.
__name__
# sanity check
# sanity check
...
@@ -581,6 +583,7 @@ class BestModelSelector(object):
...
@@ -581,6 +583,7 @@ class BestModelSelector(object):
self
.
channel
=
channel
self
.
channel
=
channel
self
.
metric
=
eval_metric
self
.
metric
=
eval_metric
self
.
checkpoint_base_dir
=
model_dir
self
.
checkpoint_base_dir
=
model_dir
self
.
ltest
=
ltest
self
.
checkpoints_all
=
BestModelSelector
.
get_checkpoints_dirs
(
model_dir
)
self
.
checkpoints_all
=
BestModelSelector
.
get_checkpoints_dirs
(
model_dir
)
self
.
ncheckpoints
=
len
(
self
.
checkpoints_all
)
self
.
ncheckpoints
=
len
(
self
.
checkpoints_all
)
# evaluate all checkpoints...
# evaluate all checkpoints...
...
@@ -604,7 +607,7 @@ class BestModelSelector(object):
...
@@ -604,7 +607,7 @@ class BestModelSelector(object):
results_dir_eager
=
os
.
path
.
join
(
checkpoint
,
"
results_eager
"
)
results_dir_eager
=
os
.
path
.
join
(
checkpoint
,
"
results_eager
"
)
eager_eval
=
Postprocess
(
results_dir
=
results_dir_eager
,
checkpoint
=
checkpoint
,
data_mode
=
"
val
"
,
batch_size
=
32
,
eager_eval
=
Postprocess
(
results_dir
=
results_dir_eager
,
checkpoint
=
checkpoint
,
data_mode
=
"
val
"
,
batch_size
=
32
,
seed
=
self
.
seed
,
eval_metrics
=
[
eval_metric
],
channel
=
self
.
channel
,
frac_data
=
0.33
,
seed
=
self
.
seed
,
eval_metrics
=
[
eval_metric
],
channel
=
self
.
channel
,
frac_data
=
0.33
,
lquick
=
True
)
lquick
=
True
,
ltest
=
self
.
ltest
)
eager_eval
.
run
()
eager_eval
.
run
()
eager_eval
.
handle_eval_metrics
()
eager_eval
.
handle_eval_metrics
()
...
@@ -728,6 +731,8 @@ def main():
...
@@ -728,6 +731,8 @@ def main():
parser
.
add_argument
(
"
--frac_intv_save
"
,
type
=
float
,
default
=
0.01
,
parser
.
add_argument
(
"
--frac_intv_save
"
,
type
=
float
,
default
=
0.01
,
help
=
"
Fraction of all iteration steps to define the saving interval.
"
)
help
=
"
Fraction of all iteration steps to define the saving interval.
"
)
parser
.
add_argument
(
"
--seed
"
,
default
=
1234
,
type
=
int
)
parser
.
add_argument
(
"
--seed
"
,
default
=
1234
,
type
=
int
)
parser
.
add_argument
(
"
--test_mode
"
,
"
-test
"
,
dest
=
"
test_mode
"
,
default
=
False
,
action
=
"
store_true
"
,
help
=
"
Test mode for postprocessing to allow bootstrapping on small datasets.
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# start timing for the whole run
# start timing for the whole run
...
@@ -753,7 +758,7 @@ def main():
...
@@ -753,7 +758,7 @@ def main():
# select best model
# select best model
if
args
.
dataset
==
"
era5
"
and
args
.
frac_start_save
<
1.
:
if
args
.
dataset
==
"
era5
"
and
args
.
frac_start_save
<
1.
:
_
=
BestModelSelector
(
args
.
output_dir
,
"
mse
"
)
_
=
BestModelSelector
(
args
.
output_dir
,
"
mse
"
,
args
.
test_mode
)
timeit_finish
=
time
.
time
()
timeit_finish
=
time
.
time
()
print
(
"
Selecting the best model checkpoint took {0:.2f} minutes.
"
.
format
((
timeit_finish
-
timeit_after_train
)
/
60.
))
print
(
"
Selecting the best model checkpoint took {0:.2f} minutes.
"
.
format
((
timeit_finish
-
timeit_after_train
)
/
60.
))
else
:
else
:
...
...
This diff is collapsed.
Click to expand it.
video_prediction_tools/main_scripts/main_visualize_postprocess.py
+
10
−
5
View file @
b80c3305
...
@@ -37,8 +37,9 @@ class Postprocess(TrainModel):
...
@@ -37,8 +37,9 @@ class Postprocess(TrainModel):
def
__init__
(
self
,
results_dir
:
str
=
None
,
checkpoint
:
str
=
None
,
data_mode
:
str
=
"
test
"
,
batch_size
:
int
=
None
,
def
__init__
(
self
,
results_dir
:
str
=
None
,
checkpoint
:
str
=
None
,
data_mode
:
str
=
"
test
"
,
batch_size
:
int
=
None
,
gpu_mem_frac
:
float
=
None
,
num_stochastic_samples
:
int
=
1
,
stochastic_plot_id
:
int
=
0
,
gpu_mem_frac
:
float
=
None
,
num_stochastic_samples
:
int
=
1
,
stochastic_plot_id
:
int
=
0
,
seed
:
int
=
None
,
channel
:
int
=
0
,
run_mode
:
str
=
"
deterministic
"
,
lquick
:
bool
=
None
,
seed
:
int
=
None
,
channel
:
int
=
0
,
run_mode
:
str
=
"
deterministic
"
,
lquick
:
bool
=
None
,
frac_data
:
float
=
1.
,
eval_metrics
:
List
=
(
"
mse
"
,
"
psnr
"
,
"
ssim
"
,
"
acc
"
),
args
=
None
,
frac_data
:
float
=
1.
,
eval_metrics
:
List
=
(
"
mse
"
,
"
psnr
"
,
"
ssim
"
,
"
acc
"
),
ltest
=
False
,
clim_path
:
str
=
"
/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/climatology_t2m_1991-2020.nc
"
):
clim_path
:
str
=
"
/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/
"
+
"
climatology_t2m_1991-2020.nc
"
,
args
=
None
):
"""
"""
Initialization of the class instance for postprocessing (generation of forecasts from trained model +
Initialization of the class instance for postprocessing (generation of forecasts from trained model +
basic evaluation).
basic evaluation).
...
@@ -56,6 +57,7 @@ class Postprocess(TrainModel):
...
@@ -56,6 +57,7 @@ class Postprocess(TrainModel):
:param lquick: flag for quick evaluation
:param lquick: flag for quick evaluation
:param frac_data: fraction of dataset to be used for evaluation (only applied when shuffling is active)
:param frac_data: fraction of dataset to be used for evaluation (only applied when shuffling is active)
:param eval_metrics: metrics used to evaluate the trained model
:param eval_metrics: metrics used to evaluate the trained model
:param ltest: flag for test mode to allow bootstrapping on tiny datasets
:param clim_path: the path to the netCDF-file storing climatological data
:param clim_path: the path to the netCDF-file storing climatological data
:param args: namespace of parsed arguments
:param args: namespace of parsed arguments
"""
"""
...
@@ -86,6 +88,7 @@ class Postprocess(TrainModel):
...
@@ -86,6 +88,7 @@ class Postprocess(TrainModel):
self
.
eval_metrics
=
eval_metrics
self
.
eval_metrics
=
eval_metrics
self
.
nboots_block
=
1000
self
.
nboots_block
=
1000
self
.
block_length
=
7
*
24
# this corresponds to a block length of 7 days in case of hourly forecasts
self
.
block_length
=
7
*
24
# this corresponds to a block length of 7 days in case of hourly forecasts
if
ltest
:
self
.
block_length
=
1
# initialize everything to get an executable Postprocess instance
# initialize everything to get an executable Postprocess instance
if
args
is
not
None
:
if
args
is
not
None
:
self
.
save_args_to_option_json
()
# create options.json in results directory
self
.
save_args_to_option_json
()
# create options.json in results directory
...
@@ -1265,8 +1268,10 @@ def main():
...
@@ -1265,8 +1268,10 @@ def main():
help
=
"
(Only) metric to evaluate when quick evaluation (-lquick) is chosen.
"
)
help
=
"
(Only) metric to evaluate when quick evaluation (-lquick) is chosen.
"
)
parser
.
add_argument
(
"
--climatology_file
"
,
"
-clim_fl
"
,
dest
=
"
clim_fl
"
,
type
=
str
,
default
=
False
,
parser
.
add_argument
(
"
--climatology_file
"
,
"
-clim_fl
"
,
dest
=
"
clim_fl
"
,
type
=
str
,
default
=
False
,
help
=
"
The path to the climatology_t2m_1991-2020.nc file
"
)
help
=
"
The path to the climatology_t2m_1991-2020.nc file
"
)
parse
.
add_argument
(
"
--frac_data
"
,
"
-f_dt
"
,
dest
=
"
f_dt
"
,
type
=
float
,
default
=
1
,
parser
.
add_argument
(
"
--frac_data
"
,
"
-f_dt
"
,
dest
=
"
f_dt
"
,
type
=
float
,
default
=
1.
,
help
=
"
fraction of dataset to be used for evaluation (only applied when shuffling is active)
"
)
help
=
"
Fraction of dataset to be used for evaluation (only applied when shuffling is active).
"
)
parser
.
add_argument
(
"
--test_mode
"
,
"
-test
"
,
dest
=
"
test_mode
"
,
default
=
False
,
action
=
"
store_true
"
,
help
=
"
Test mode for postprocessing to allow bootstrapping on small datasets.
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
method
=
os
.
path
.
basename
(
__file__
)
method
=
os
.
path
.
basename
(
__file__
)
...
@@ -1293,7 +1298,7 @@ def main():
...
@@ -1293,7 +1298,7 @@ def main():
batch_size
=
args
.
batch_size
,
num_stochastic_samples
=
args
.
num_stochastic_samples
,
batch_size
=
args
.
batch_size
,
num_stochastic_samples
=
args
.
num_stochastic_samples
,
gpu_mem_frac
=
args
.
gpu_mem_frac
,
seed
=
args
.
seed
,
args
=
args
,
gpu_mem_frac
=
args
.
gpu_mem_frac
,
seed
=
args
.
seed
,
args
=
args
,
eval_metrics
=
eval_metrics
,
channel
=
args
.
channel
,
lquick
=
args
.
lquick
,
eval_metrics
=
eval_metrics
,
channel
=
args
.
channel
,
lquick
=
args
.
lquick
,
clim_path
=
args
.
clim_fl
,
frac_data
=
args
.
frac_data
)
clim_path
=
args
.
clim_fl
,
frac_data
=
args
.
frac_data
,
ltest
=
args
.
test_mode
)
# run the postprocessing
# run the postprocessing
postproc_instance
.
run
()
postproc_instance
.
run
()
postproc_instance
.
handle_eval_metrics
()
postproc_instance
.
handle_eval_metrics
()
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
sign in
to comment