Commit 7fb1d8aa, authored 2 years ago by masak1112

add metrics_filename as argument in meta-postprocessing step

parent 50e26227
Pipeline #104589 passed 2 years ago (stages: build, test, deploy)
Showing 1 changed file: video_prediction_tools/main_scripts/main_meta_postprocess.py (+23 additions, −13 deletions)
@@ -31,7 +31,8 @@ def skill_score(tar_score,ref_score,best_score):
 class MetaPostprocess(object):
-    def __init__(self, root_dir: str = "/p/project/deepacf/deeprain/video_prediction_shared_folder/", analysis_config: str = None, metric: str = "mse", exp_id: str = None, enable_skill_scores: bool = False, enable_persit_plot: bool = False):
+    def __init__(self, root_dir: str = "/p/project/deepacf/deeprain/video_prediction_shared_folder/",
+                 analysis_config: str = None, metric: str = "mse", exp_id: str = None, enable_skill_scores: bool = False, enable_persit_plot: bool = False, metrics_filename="evaluation_metrics.nc"):
         """
         This class is used for calculating the evaluation metrics, analyzing the models' results and making comparisons
         args:
@@ -42,6 +43,7 @@ class MetaPostprocess(object):
         exp_id :str, the given exp_id which is used as the postfix of the folder to store the plot
         enable_skill_scores:bool, enable the skill scores plot
         enable_persis_plot: bool, enable the persistence prediction in the plot
+        metrics_filename :str, the .nc file that stores the evaluation metrics
         """
         self.root_dir = root_dir
         self.analysis_config = analysis_config
@@ -50,6 +52,7 @@ class MetaPostprocess(object):
         self.exp_id = exp_id
         self.persist = enable_persit_plot
         self.enable_skill_scores = enable_skill_scores
+        self.metrics_filename = metrics_filename
         self.models_type = []
         self.metric_values = []  # return the shape: [num_results, persi_values, model_values]
         self.skill_scores = []  # contain the calculated skill scores [num_results, skill_scores_values]
@@ -132,27 +135,31 @@ class MetaPostprocess(object):
         self.get_meta_info()
         for i, result_dir in enumerate(self.f["results"].values()):
-            vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i], self.enable_skill_scores)
+            vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i], self.enable_skill_scores, self.metrics_filename)
             self.metric_values.append(vals)
         print("Get metrics values success")
         return self.metric_values

     @staticmethod
-    def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None, enable_skill_scores: bool = False):
+    def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None, enable_skill_scores: bool = False, metrics_filename: str = "evaluation_metrics.nc"):
         """
         obtain the metric values (persistence and DL model) from the metrics .nc file
         return: a list containing the evaluation metrics of one result: [persi, model]
         """
-        filename = 'evaluation_metrics.nc'
+        filename = metrics_filename
         filepath = os.path.join(result_dir, filename)
         try:
-            with xr.open_dataset(filepath) as dfiles:
+            with xr.open_dataset(filepath, engine="netcdf4") as dfiles:
                 if enable_skill_scores:
                     persi = np.array(dfiles['2t_persistence_{}_bootstrapped'.format(metric)][:])
                     if persi.shape[0] < 30:  # 20210713T143850_gong1_savp_t2opt_3vars/evaluation_metrics_72x44.nc shape is not correct
                         persi = np.transpose(persi)
                 else:
                     persi = []
                 model = np.array(dfiles['2t_{}_{}_bootstrapped'.format(model, metric)][:])
                 if model.shape[0] < 30:
                     model = np.transpose(model)
                 print("The values for evaluation metric '{}' are obtained from file {}".format(metric, filepath))
                 return [persi, model]
         except Exception as e:
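With metrics_filename threaded through, the static method can also be exercised directly. A minimal sketch, assuming main_meta_postprocess.py is on the import path; the result directory, model name, and metrics file name below are hypothetical:

    from main_meta_postprocess import MetaPostprocess

    # Read the bootstrapped persistence and model metric arrays from a
    # non-default metrics file (hypothetical directory and model name).
    persi, model = MetaPostprocess.get_one_metric_values(
        result_dir="/path/to/one_result_dir",            # hypothetical
        metric="mse",
        model="savp",                                    # hypothetical model name
        enable_skill_scores=True,
        metrics_filename="evaluation_metrics_72x44.nc",  # the new argument
    )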
@@ -201,7 +208,7 @@ class MetaPostprocess(object):
     @staticmethod
     def map_ylabels(metric):
         if metric == "mse":
-            ylabel = "MSE"
+            ylabel = "MSE [K$^2$]"
         elif metric == "acc":
             ylabel = "ACC"
         elif metric == "ssim":
@@ -220,7 +227,8 @@ class MetaPostprocess(object):
         for i in range(len(self.metric_values)):  # loop over the number of test samples
             assert len(self.metric_values[0]) == 2
             score_plot = np.nanquantile(self.metric_values[i][1], 0.5, axis=0)
+            print("score_plot", len(score_plot))
+            print("self.n_leadtime", self.n_leadtime)
             assert len(score_plot) == self.n_leadtime
             plt.plot(np.arange(1, 1 + self.n_leadtime), list(score_plot), label=self.labels[i], color=self.colors[i], marker=self.markers[i], markeredgecolor='k', linewidth=1.2)
@@ -240,11 +248,12 @@ class MetaPostprocess(object):
         plt.yticks(fontsize=16)
         plt.xticks(np.arange(1, self.n_leadtime + 1), np.arange(1, self.n_leadtime + 1, 1), fontsize=16)
-        legend = ax.legend(loc='upper right', bbox_to_anchor=(1.46, 0.95), fontsize=14)  # 'upper right', bbox_to_anchor=(1.38, 0.8),
+        legend = ax.legend(loc='upper right', bbox_to_anchor=(0.92, 0.40), fontsize=12)  # 'upper right', bbox_to_anchor=(1.38, 0.8),
         ylabel = MetaPostprocess.map_ylabels(self.metric)
         ax.set_xlabel("Lead time (hours)", fontsize=21)
         ax.set_ylabel(ylabel, fontsize=21)
+        plt.title("Sensitivity analysis for domain sizes", fontsize=16)
         fig_path = os.path.join(self.analysis_dir, self.metric + "_abs_values.png")
         # fig_path = os.path.join(prefix, fig_name)
         plt.savefig(fig_path, bbox_inches="tight")
@@ -293,10 +302,11 @@ def main():
parser
.
add_argument
(
"
--exp_id
"
,
help
=
"
The experiment id which will be used as postfix of the output directory
"
,
default
=
"
exp1
"
)
parser
.
add_argument
(
"
--enable_skill_scores
"
,
help
=
"
compared by skill scores or the absolute evaluation values
"
,
default
=
False
)
parser
.
add_argument
(
"
--enable_persit_plot
"
,
help
=
"
If plot persistent foreasts
"
,
default
=
False
)
parser
.
add_argument
(
"
--metrics_filename
"
,
help
=
"
The .nc file contain the evaluation metrics
"
,
default
=
"
evaluation_metrics.nc
"
)
args
=
parser
.
parse_args
()
meta
=
MetaPostprocess
(
root_dir
=
args
.
root_dir
,
analysis_config
=
args
.
analysis_config
,
metric
=
args
.
metric
,
exp_id
=
args
.
exp_id
,
enable_skill_scores
=
args
.
enable_skill_scores
,
enable_persit_plot
=
args
.
enable_persit_plot
)
enable_skill_scores
=
args
.
enable_skill_scores
,
enable_persit_plot
=
args
.
enable_persit_plot
,
metrics_filename
=
args
.
metrics_filename
)
meta
()
...
...
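From the command line, the new flag selects the metrics file; for example (other arguments elided, values illustrative):

    python video_prediction_tools/main_scripts/main_meta_postprocess.py --metric mse --exp_id exp1 --metrics_filename evaluation_metrics_72x44.nc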