Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
AMBS
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
AMBS
Commits
73a7c6f3
Commit
73a7c6f3
authored
3 years ago
by
BING GONG
Browse files
Options
Downloads
Patches
Plain Diff
Impelment the skill scores plots
parent
4fb114c9
Branches
Branches containing commit
Tags
Tags containing commit
No related merge requests found
Pipeline
#90897
failed
3 years ago
Stage: build
Stage: test
Stage: deploy
Changes
1
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
video_prediction_tools/main_scripts/main_meta_postprocess.py
+106
-43
106 additions, 43 deletions
video_prediction_tools/main_scripts/main_meta_postprocess.py
with
106 additions
and
43 deletions
video_prediction_tools/main_scripts/main_meta_postprocess.py
+
106
−
43
View file @
73a7c6f3
...
@@ -26,7 +26,7 @@ import xarray as xr
...
@@ -26,7 +26,7 @@ import xarray as xr
class
MetaPostprocess
(
object
):
class
MetaPostprocess
(
object
):
def
__init__
(
self
,
root_dir
:
str
=
"
/p/project/deepacf/deeprain/video_prediction_shared_folder/
"
,
def
__init__
(
self
,
root_dir
:
str
=
"
/p/project/deepacf/deeprain/video_prediction_shared_folder/
"
,
analysis_config
:
str
=
None
,
metric
:
str
=
"
mse
"
,
exp_id
=
None
,
enable_skill_scores
=
False
):
analysis_config
:
str
=
None
,
metric
:
str
=
"
mse
"
,
exp_id
:
str
=
None
,
enable_skill_scores
:
bool
=
False
):
"""
"""
This class is used for calculating the evaluation metric, analyize the models
'
results and make comparsion
This class is used for calculating the evaluation metric, analyize the models
'
results and make comparsion
args:
args:
...
@@ -35,6 +35,7 @@ class MetaPostprocess(object):
...
@@ -35,6 +35,7 @@ class MetaPostprocess(object):
analysis_dir :str, the path to save the analysis results
analysis_dir :str, the path to save the analysis results
metric :str, based on which evalution metric for comparison,
"
mse
"
,
"
ssim
"
,
"
texture
"
and
"
acc
"
metric :str, based on which evalution metric for comparison,
"
mse
"
,
"
ssim
"
,
"
texture
"
and
"
acc
"
exp_id :str, the given exp_id which is used as the name of postfix of the folder to store the plot
exp_id :str, the given exp_id which is used as the name of postfix of the folder to store the plot
enable_skill_scores:bool, the
"""
"""
self
.
root_dir
=
root_dir
self
.
root_dir
=
root_dir
self
.
analysis_config
=
analysis_config
self
.
analysis_config
=
analysis_config
...
@@ -43,21 +44,26 @@ class MetaPostprocess(object):
...
@@ -43,21 +44,26 @@ class MetaPostprocess(object):
self
.
exp_id
=
exp_id
self
.
exp_id
=
exp_id
self
.
enable_skill_scores
=
enable_skill_scores
self
.
enable_skill_scores
=
enable_skill_scores
self
.
models_type
=
[]
self
.
models_type
=
[]
self
.
metric_values
=
[]
# return the shape: [num_results, persi_values, model_values]
self
.
skill_scores
=
[]
# contain the calculated skill scores [num_results, skill_scores_values]
def
__call__
(
self
):
def
__call__
(
self
):
self
.
sanity_check
()
self
.
sanity_check
()
self
.
create_analysis_dir
()
self
.
create_analysis_dir
()
self
.
copy_analysis_config
()
self
.
copy_analysis_config
()
self
.
load_analysis_config
()
self
.
load_analysis_config
()
metric_values
=
self
.
get_metrics_values
()
self
.
get_metrics_values
()
self
.
plot_scores
(
metric_values
)
self
.
calculate_skill_scores
()
# self.calculate_skill_scores()
if
self
.
enable_skill_scores
:
# self.plot_scores()
self
.
plot_skill_scores
()
else
:
self
.
plot_abs_scores
()
def
sanity_check
(
self
):
def
sanity_check
(
self
):
available_metrics
=
[
"
mse
"
,
"
ssim
"
,
"
texture
"
,
"
acc
"
]
self
.
available_metrics
=
[
"
mse
"
,
"
ssim
"
,
"
texture
"
,
"
acc
"
]
if
self
.
metric
not
in
available_metrics
:
if
self
.
metric
not
in
self
.
available_metrics
:
raise
(
"
The
'
metric
'
must be one of the following:
"
,
available_metrics
)
raise
(
"
The
'
metric
'
must be one of the following:
"
,
available_metrics
)
if
type
(
self
.
exp_id
)
is
not
str
:
if
type
(
self
.
exp_id
)
is
not
str
:
raise
(
"'
exp_id
'
must be
'
str
'
type
"
)
raise
(
"'
exp_id
'
must be
'
str
'
type
"
)
...
@@ -112,18 +118,23 @@ class MetaPostprocess(object):
...
@@ -112,18 +118,23 @@ class MetaPostprocess(object):
return
None
return
None
def
get_metrics_values
(
self
):
def
get_metrics_values
(
self
):
"""
get the evaluation metric values of all the results, return a list [results,persi, model]
"""
self
.
get_meta_info
()
self
.
get_meta_info
()
metric_values
=
[]
for
i
,
result_dir
in
enumerate
(
self
.
f
[
"
results
"
].
values
()):
for
i
,
result_dir
in
enumerate
(
self
.
f
[
"
results
"
].
values
()):
vals
=
MetaPostprocess
.
get_one_metric_values
(
result_dir
,
self
.
metric
,
self
.
models_type
[
i
])
vals
=
MetaPostprocess
.
get_one_metric_values
(
result_dir
,
self
.
metric
,
self
.
models_type
[
i
])
metric_values
.
append
(
vals
)
# return the shape: [result_id, persi_values,model_values]
self
.
metric_values
.
append
(
vals
)
print
(
"
4. Get metrics values success
"
)
print
(
"
4. Get metrics values success
"
)
return
metric_values
return
self
.
metric_values
@staticmethod
@staticmethod
def
get_one_metric_values
(
result_dir
:
str
=
None
,
metric
:
str
=
"
mse
"
,
model
:
str
=
None
):
def
get_one_metric_values
(
result_dir
:
str
=
None
,
metric
:
str
=
"
mse
"
,
model
:
str
=
None
):
"""
"""
obtain the metric values (persistence and DL model) in the
"
evaluation_metrics.nc
"
file
obtain the metric values (persistence and DL model) in the
"
evaluation_metrics.nc
"
file
return: list contains the evaluatioin metrics of one result. [persi,model]
"""
"""
filename
=
'
evaluation_metrics.nc
'
filename
=
'
evaluation_metrics.nc
'
filepath
=
os
.
path
.
join
(
result_dir
,
filename
)
filepath
=
os
.
path
.
join
(
result_dir
,
filename
)
...
@@ -138,22 +149,41 @@ class MetaPostprocess(object):
...
@@ -138,22 +149,41 @@ class MetaPostprocess(object):
print
(
e
)
print
(
e
)
def
calculate_skill_scores
(
self
):
def
calculate_skill_scores
(
self
):
if
self
.
enable_skill_scores
:
"""
calculate the skill scores
"""
if
self
.
metric_values
is
None
:
raise
(
"
metric_values should be a list but None is provided
"
)
best_score
=
0
if
self
.
metric
==
"
mse
"
:
pass
pass
# do sometthing
elif
self
.
metric
in
[
"
ssim
"
,
"
acc
"
,
"
texture
"
]:
best_score
=
1
else
:
else
:
pass
raise
(
"
The metric should be one of the following available metrics :
"
,
self
.
available_metrics
)
if
self
.
enable_skill_scores
:
for
i
in
range
(
len
(
self
.
metric_values
)):
skill_val
=
skill_score
(
self
.
metric_values
[
i
][
1
],
self
.
metric_values
[
i
][
0
],
best_score
)
self
.
skill_scores
.
append
(
skill_val
)
return
self
.
skill_scores
else
:
return
None
def
get_lead_time_labels
(
metric_values
:
list
=
None
):
def
get_lead_time_labels
(
self
):
leadtimes
=
metric_values
[
0
][
0
].
shape
[
1
]
leadtimes
=
self
.
metric_values
[
0
][
0
].
shape
[
1
]
leadtimelist
=
[
"
leadhour
"
+
str
(
i
+
1
)
for
i
in
range
(
leadtimes
)]
leadtimelist
=
[
"
leadhour
"
+
str
(
i
+
1
)
for
i
in
range
(
leadtimes
)]
return
leadtimelist
return
leadtimelist
def
config_plots
(
self
,
metric_values
):
def
config_plots
(
self
):
self
.
leadtimelist
=
MetaPostprocess
.
get_lead_time_labels
(
metric_values
)
self
.
leadtimelist
=
self
.
get_lead_time_labels
()
self
.
labels
=
self
.
get_labels
()
self
.
labels
=
self
.
get_labels
()
self
.
markers
=
self
.
f
[
"
markers
"
]
self
.
markers
=
self
.
f
[
"
markers
"
]
self
.
colors
=
self
.
f
[
"
colors
"
]
self
.
colors
=
self
.
f
[
"
colors
"
]
self
.
n_leadtime
=
len
(
self
.
leadtimelist
)
@staticmethod
@staticmethod
def
map_ylabels
(
metric
):
def
map_ylabels
(
metric
):
...
@@ -169,35 +199,27 @@ class MetaPostprocess(object):
...
@@ -169,35 +199,27 @@ class MetaPostprocess(object):
raise
(
"
The metric is not correct!
"
)
raise
(
"
The metric is not correct!
"
)
return
ylabel
return
ylabel
def
plot_scores
(
self
,
metric_values
):
def
plot_abs_scores
(
self
):
self
.
config_plots
()
self
.
config_plots
(
metric_values
)
if
self
.
enable_skill_scores
:
self
.
plot_skill_scores
(
metric_values
)
else
:
self
.
plot_abs_scores
(
metric_values
)
def
plot_abs_scores
(
self
,
metric_values
:
list
=
None
):
n_leadtime
=
len
(
self
.
leadtimelist
)
fig
=
plt
.
figure
(
figsize
=
(
8
,
6
))
fig
=
plt
.
figure
(
figsize
=
(
8
,
6
))
ax
=
fig
.
add_axes
([
0.1
,
0.1
,
0.8
,
0.8
])
ax
=
fig
.
add_axes
([
0.1
,
0.1
,
0.8
,
0.8
])
for
i
in
range
(
len
(
metric_values
)):
for
i
in
range
(
len
(
self
.
metric_values
)):
score_plot
=
np
.
nanquantile
(
metric_values
[
i
][
1
],
0.5
,
axis
=
0
)
score_plot
=
np
.
nanquantile
(
self
.
metric_values
[
i
][
1
],
0.5
,
axis
=
0
)
plt
.
plot
(
np
.
arange
(
1
,
1
+
n_leadtime
),
score_plot
,
label
=
self
.
labels
[
i
],
color
=
self
.
colors
[
i
],
plt
.
plot
(
np
.
arange
(
1
,
1
+
self
.
n_leadtime
),
score_plot
,
label
=
self
.
labels
[
i
],
color
=
self
.
colors
[
i
],
marker
=
self
.
markers
[
i
],
markeredgecolor
=
'
k
'
,
linewidth
=
1.2
)
marker
=
self
.
markers
[
i
],
markeredgecolor
=
'
k
'
,
linewidth
=
1.2
)
plt
.
fill_between
(
np
.
arange
(
1
,
1
+
n_leadtime
),
plt
.
fill_between
(
np
.
arange
(
1
,
1
+
self
.
n_leadtime
),
np
.
nanquantile
(
metric_values
[
i
][
1
],
0.95
,
axis
=
0
),
np
.
nanquantile
(
self
.
metric_values
[
i
][
1
],
0.95
,
axis
=
0
),
np
.
nanquantile
(
metric_values
[
i
][
1
],
0.05
,
axis
=
0
),
color
=
self
.
colors
[
i
],
alpha
=
0.2
)
np
.
nanquantile
(
self
.
metric_values
[
i
][
1
],
0.05
,
axis
=
0
),
color
=
self
.
colors
[
i
],
alpha
=
0.2
)
if
self
.
models_type
[
i
]
==
"
convLSTM
"
:
if
self
.
models_type
[
i
]
==
"
convLSTM
"
:
score_plot
=
np
.
nanquantile
(
metric_values
[
i
][
0
],
0.5
,
axis
=
0
)
score_plot
=
np
.
nanquantile
(
self
.
metric_values
[
i
][
0
],
0.5
,
axis
=
0
)
plt
.
plot
(
np
.
arange
(
1
,
1
+
n_leadtime
),
score_plot
,
label
=
"
Persi_cv
"
+
str
(
i
),
color
=
self
.
colors
[
i
],
plt
.
plot
(
np
.
arange
(
1
,
1
+
self
.
n_leadtime
),
score_plot
,
label
=
"
Persi_cv
"
+
str
(
i
),
marker
=
"
D
"
,
markeredgecolor
=
'
k
'
,
linewidth
=
1.2
)
color
=
self
.
colors
[
i
],
marker
=
"
D
"
,
markeredgecolor
=
'
k
'
,
linewidth
=
1.2
)
plt
.
fill_between
(
np
.
arange
(
1
,
1
+
n_leadtime
),
plt
.
fill_between
(
np
.
arange
(
1
,
1
+
self
.
n_leadtime
),
np
.
nanquantile
(
metric_values
[
i
][
0
],
0.95
,
axis
=
0
),
np
.
nanquantile
(
self
.
metric_values
[
i
][
0
],
0.95
,
axis
=
0
),
np
.
nanquantile
(
metric_values
[
i
][
0
],
0.05
,
axis
=
0
),
color
=
"
b
"
,
alpha
=
0.2
)
np
.
nanquantile
(
self
.
metric_values
[
i
][
0
],
0.05
,
axis
=
0
),
color
=
"
b
"
,
alpha
=
0.2
)
plt
.
yticks
(
fontsize
=
16
)
plt
.
yticks
(
fontsize
=
16
)
plt
.
xticks
(
np
.
arange
(
1
,
13
),
np
.
arange
(
1
,
13
,
1
),
fontsize
=
16
)
plt
.
xticks
(
np
.
arange
(
1
,
13
),
np
.
arange
(
1
,
13
,
1
),
fontsize
=
16
)
...
@@ -206,18 +228,59 @@ class MetaPostprocess(object):
...
@@ -206,18 +228,59 @@ class MetaPostprocess(object):
ylabel
=
MetaPostprocess
.
map_ylabels
(
self
.
metric
)
ylabel
=
MetaPostprocess
.
map_ylabels
(
self
.
metric
)
ax
.
set_xlabel
(
"
Lead time (hours)
"
,
fontsize
=
21
)
ax
.
set_xlabel
(
"
Lead time (hours)
"
,
fontsize
=
21
)
ax
.
set_ylabel
(
ylabel
,
fontsize
=
21
)
ax
.
set_ylabel
(
ylabel
,
fontsize
=
21
)
fig_path
=
os
.
path
.
join
(
self
.
analysis_dir
,
self
.
metric
+
"
abs_values.png
"
)
fig_path
=
os
.
path
.
join
(
self
.
analysis_dir
,
self
.
metric
+
"
_
abs_values.png
"
)
# fig_path = os.path.join(prefix,fig_name)
# fig_path = os.path.join(prefix,fig_name)
plt
.
savefig
(
fig_path
,
bbox_inches
=
"
tight
"
)
plt
.
savefig
(
fig_path
,
bbox_inches
=
"
tight
"
)
plt
.
show
()
plt
.
show
()
plt
.
close
()
plt
.
close
()
print
(
"
The plot saved to {}
"
.
format
(
fig_path
))
print
(
"
The plot saved to {}
"
.
format
(
fig_path
))
def
plot_skill_scores
(
self
):
"""
Plot the skill scores once the enable_skill is True
"""
self
.
config_plots
()
fig
=
plt
.
figure
(
figsize
=
(
8
,
6
))
ax
=
fig
.
add_axes
([
0.1
,
0.1
,
0.8
,
0.8
])
for
i
in
range
(
len
(
self
.
skill_scores
)):
if
self
.
models_type
[
i
]
==
"
convLSTM
"
:
c
=
"
r
"
elif
self
.
models_type
[
i
]
==
"
savp
"
:
c
=
"
b
"
else
:
raise
(
"
current only support convLSTM and SAVP for plotinig the skil scores
"
)
plt
.
boxplot
(
self
.
skill_scores
[
i
],
positions
=
np
.
arange
(
1
,
self
.
n_leadtime
+
1
),
medianprops
=
{
'
color
'
:
c
},
capprops
=
{
'
color
'
:
c
},
boxprops
=
{
'
color
'
:
c
},
showfliers
=
False
)
score_plot
=
np
.
nanquantile
(
self
.
skill_scores
[
i
],
0.5
,
axis
=
0
)
plt
.
plot
(
np
.
arange
(
1
,
1
+
self
.
n_leadtime
),
score_plot
,
color
=
c
,
linewidth
=
1.2
,
label
=
self
.
labels
[
i
])
legend
=
ax
.
legend
(
loc
=
'
upper right
'
,
bbox_to_anchor
=
(
1.46
,
0.95
),
fontsize
=
14
)
plt
.
yticks
(
fontsize
=
16
)
plt
.
xticks
(
np
.
arange
(
1
,
13
),
np
.
arange
(
1
,
13
,
1
),
fontsize
=
16
)
ax
.
set_xlabel
(
"
Lead time (hours)
"
,
fontsize
=
21
)
ax
.
set_ylabel
(
"
Skill scores of {}
"
.
format
(
self
.
metric
),
fontsize
=
21
)
fig_path
=
os
.
path
.
join
(
self
.
analysis_dir
,
self
.
metric
+
"
_skill_scores.png
"
)
plt
.
savefig
(
fig_path
,
bbox_inches
=
"
tight
"
)
plt
.
show
()
plt
.
close
()
print
(
"
The plot saved to {}
"
.
format
(
fig_path
))
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"
--analysis_config
"
,
type
=
str
,
required
=
True
,
help
=
"
The path points to the meta_postprocess configuration file.
"
,
default
=
"
../meta_postprocess_config/meta_config.json
"
)
parser
.
add_argument
(
"
--metric
"
,
help
=
"
Based on which the models are compared, the value should be in one of [mse,ssim,acc,texture]
"
,
default
=
"
mse
"
)
parser
.
add_argument
(
"
--exp_id
"
,
help
=
"
The experiment id which will be used as postfix of the output directory
"
,
default
=
"
exp1
"
)
parser
.
add_argument
(
"
--enable_skill_scores
"
,
help
=
"
compared by skill scores or the absolute evaluation values
"
,
default
=
True
)
args
=
parser
.
parse_args
()
meta
=
MetaPostprocess
(
analysis_config
=
args
.
analysis_config
,
metric
=
args
.
metric
,
exp_id
=
args
.
metric
,
enable_skill_scores
=
args
.
enable_skill_scores
)
meta
()
if
__name__
==
'
__main__
'
:
main
()
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment