esde / machine-learning / AMBS · Commits

Commit 88041cad, authored 3 years ago by Bing Gong

    update main_meta_postprocess to disable the persistence prediction in the plot

Parent: daef8d16
Branches containing commit; no related tags or merge requests found.
Pipeline #92842: canceled, 3 years ago (stages: build, test, deploy)
Changes: 1 changed file
video_prediction_tools/main_scripts/main_meta_postprocess.py (+28 additions, −17 deletions)
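In short, the commit threads a new enable_persit_plot switch through MetaPostprocess so that the persistence forecast is only drawn when explicitly requested, makes calculate_skill_scores() conditional on enable_skill_scores, derives the x-axis ticks from self.n_leadtime instead of a hard-coded 12, and fixes main() to pass exp_id=args.exp_id rather than args.metric. A minimal usage sketch of the extended constructor (the import path and argument values are illustrative; the class name and keywords come from the diff):

    from main_meta_postprocess import MetaPostprocess  # hypothetical import path

    meta = MetaPostprocess(root_dir="/p/project/deepacf/deeprain/video_prediction_shared_folder/",
                           analysis_config="../meta_postprocess_config/meta_config.json",
                           metric="mse",
                           exp_id="exp1",
                           enable_skill_scores=False,  # plot absolute scores
                           enable_persit_plot=False)   # leave the persistence curve out of the plot
    meta()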
@@ -31,7 +31,7 @@ def skill_score(tar_score,ref_score,best_score):
 class MetaPostprocess(object):
     def __init__(self, root_dir: str = "/p/project/deepacf/deeprain/video_prediction_shared_folder/",
-                 analysis_config: str = None, metric: str = "mse", exp_id: str = None, enable_skill_scores: bool = False):
+                 analysis_config: str = None, metric: str = "mse", exp_id: str = None, enable_skill_scores: bool = False, enable_persit_plot: bool = False):
         """
         This class is used for calculating the evaluation metrics, analyzing the models' results and making comparisons
         args:
@@ -40,13 +40,15 @@ class MetaPostprocess(object):
         analysis_dir        :str, the path to save the analysis results
         metric              :str, the evaluation metric used for comparison: "mse", "ssim", "texture" or "acc"
         exp_id              :str, the given exp_id, used as the postfix of the folder that stores the plot
-        enable_skill_scores :bool, the
+        enable_skill_scores :bool, enable the skill scores plot
+        enable_persit_plot  :bool, enable the persistence prediction in the plot
         """
         self.root_dir = root_dir
         self.analysis_config = analysis_config
         self.analysis_dir = os.path.join(root_dir, "meta_postprocess", exp_id)
         self.metric = metric
         self.exp_id = exp_id
+        self.persist = enable_persit_plot
         self.enable_skill_scores = enable_skill_scores
         self.models_type = []
         self.metric_values = []  # shape: [num_results, persi_values, model_values]
@@ -59,8 +61,8 @@ class MetaPostprocess(object):
         self.copy_analysis_config()
         self.load_analysis_config()
         self.get_metrics_values()
-        self.calculate_skill_scores()
         if self.enable_skill_scores:
+            self.calculate_skill_scores()
             self.plot_skill_scores()
         else:
             self.plot_abs_scores()
@@ -129,13 +131,13 @@ class MetaPostprocess(object):
         self.get_meta_info()
         for i, result_dir in enumerate(self.f["results"].values()):
-            vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i])
+            vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i], self.enable_skill_scores)
             self.metric_values.append(vals)
         print("4. Get metrics values success")
         return self.metric_values

     @staticmethod
-    def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None):
+    def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None, enable_skill_scores: bool = False):
         """
         obtain the metric values (persistence and DL model) from the "evaluation_metrics.nc" file
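The static method now receives the flag from its caller; an illustrative call (the result directory and model name are hypothetical):

    vals = MetaPostprocess.get_one_metric_values("/path/to/result_dir", metric="mse",
                                                 model="convLSTM", enable_skill_scores=False)
    persi, model_vals = vals  # persi is an empty list when skill scores are disabled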
@@ -145,7 +147,10 @@ class MetaPostprocess(object):
         filepath = os.path.join(result_dir, filename)
         try:
             with xr.open_dataset(filepath) as dfiles:
-                persi = np.array(dfiles['2t_persistence_{}_bootstrapped'.format(metric)][:])
+                if enable_skill_scores:
+                    persi = np.array(dfiles['2t_persistence_{}_bootstrapped'.format(metric)][:])
+                else:
+                    persi = []
                 model = np.array(dfiles['2t_{}_{}_bootstrapped'.format(model, metric)][:])
                 print("The values for evaluation metric '{}' are obtained from file {}".format(metric, filepath))
                 return [persi, model]
@@ -179,7 +184,8 @@ class MetaPostprocess(object):
             return None

     def get_lead_time_labels(self):
-        leadtimes = self.metric_values[0][0].shape[1]
+        assert len(self.metric_values) == 2
+        leadtimes = np.array(self.metric_values[0][1]).shape[1]
         leadtimelist = ["leadhour" + str(i + 1) for i in range(leadtimes)]
         return leadtimelist
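With the reworked indexing, the number of lead times is read from the model-score array (index 1 of each [persi, model] pair) rather than the possibly empty persistence array. A sketch with illustrative shapes:

    import numpy as np

    metric_values = [[[], np.zeros((100, 12))], [[], np.zeros((100, 12))]]  # two results, each [persi, model]
    leadtimes = np.array(metric_values[0][1]).shape[1]  # 12
    leadtimelist = ["leadhour" + str(i + 1) for i in range(leadtimes)]
    # ['leadhour1', 'leadhour2', ..., 'leadhour12']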
@@ -209,16 +215,20 @@ class MetaPostprocess(object):
         fig = plt.figure(figsize=(8, 6))
         ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
         for i in range(len(self.metric_values)):
+            # loop over the number of test samples
+            assert len(self.metric_values) == 2
             score_plot = np.nanquantile(self.metric_values[i][1], 0.5, axis=0)
-            plt.plot(np.arange(1, 1 + self.n_leadtime), score_plot, label=self.labels[i], color=self.colors[i],
+            assert len(score_plot) == self.n_leadtime
+            plt.plot(np.arange(1, 1 + self.n_leadtime), list(score_plot), label=self.labels[i], color=self.colors[i],
                      marker=self.markers[i], markeredgecolor='k', linewidth=1.2)
             plt.fill_between(np.arange(1, 1 + self.n_leadtime),
                              np.nanquantile(self.metric_values[i][1], 0.95, axis=0),
                              np.nanquantile(self.metric_values[i][1], 0.05, axis=0), color=self.colors[i],
                              alpha=0.2)
-            if self.models_type[i] == "convLSTM":
+            # only plot the persistence prediction when it is enabled
+            if self.persist:
                 score_plot = np.nanquantile(self.metric_values[i][0], 0.5, axis=0)
                 plt.plot(np.arange(1, 1 + self.n_leadtime), score_plot, label="Persi_cv" + str(i),
                          color=self.colors[i], marker="D", markeredgecolor='k', linewidth=1.2)
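Each curve drawn here is the bootstrap median per lead time with a 5-95% quantile band around it; a self-contained sketch of the same pattern on synthetic data (all names and values are illustrative):

    import numpy as np
    import matplotlib.pyplot as plt

    rng = np.random.default_rng(0)
    scores = rng.normal(loc=1.0, scale=0.2, size=(100, 12))  # (bootstrap samples, lead times)
    lead = np.arange(1, scores.shape[1] + 1)

    plt.plot(lead, np.nanquantile(scores, 0.5, axis=0), marker="o", markeredgecolor="k", linewidth=1.2)
    plt.fill_between(lead,
                     np.nanquantile(scores, 0.95, axis=0),
                     np.nanquantile(scores, 0.05, axis=0), alpha=0.2)
    plt.xticks(lead)
    plt.xlabel("Lead time (hours)")
    plt.show()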
@@ -227,7 +237,7 @@ class MetaPostprocess(object):
                              np.nanquantile(self.metric_values[i][0], 0.05, axis=0), color="b", alpha=0.2)
         plt.yticks(fontsize=16)
-        plt.xticks(np.arange(1, 13), np.arange(1, 13, 1), fontsize=16)
+        plt.xticks(np.arange(1, self.n_leadtime + 1), np.arange(1, self.n_leadtime + 1, 1), fontsize=16)
         legend = ax.legend(loc='upper right', bbox_to_anchor=(1.46, 0.95),
                            fontsize=14)  # 'upper right', bbox_to_anchor=(1.38, 0.8),
         ylabel = MetaPostprocess.map_ylabels(self.metric)
@@ -262,7 +272,7 @@ class MetaPostprocess(object):
         legend = ax.legend(loc='upper right', bbox_to_anchor=(1.46, 0.95), fontsize=14)
         plt.yticks(fontsize=16)
-        plt.xticks(np.arange(1, 13), np.arange(1, 13, 1), fontsize=16)
+        plt.xticks(np.arange(1, self.n_leadtime + 1), np.arange(1, self.n_leadtime + 1, 1), fontsize=16)
         ax.set_xlabel("Lead time (hours)", fontsize=21)
         ax.set_ylabel("Skill scores of {}".format(self.metric), fontsize=21)
         fig_path = os.path.join(self.analysis_dir, self.metric + "_skill_scores.png")
@@ -279,11 +289,12 @@ def main():
                         default="../meta_postprocess_config/meta_config.json")
     parser.add_argument("--metric", help="The metric by which the models are compared; one of [mse, ssim, acc, texture]", default="mse")
     parser.add_argument("--exp_id", help="The experiment id which will be used as the postfix of the output directory", default="exp1")
-    parser.add_argument("--enable_skill_scores", help="Compare by skill scores or by the absolute evaluation values", default=True)
+    parser.add_argument("--enable_skill_scores", help="Compare by skill scores or by the absolute evaluation values", default=False)
+    parser.add_argument("--enable_persit_plot", help="Whether to plot persistent forecasts", default=False)
     args = parser.parse_args()
-    meta = MetaPostprocess(root_dir=args.root_dir, analysis_config=args.analysis_config, metric=args.metric, exp_id=args.metric,
-                           enable_skill_scores=args.enable_skill_scores)
+    meta = MetaPostprocess(root_dir=args.root_dir, analysis_config=args.analysis_config, metric=args.metric, exp_id=args.exp_id,
+                           enable_skill_scores=args.enable_skill_scores,
+                           enable_persit_plot=args.enable_persit_plot)
     meta()
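One caveat with the new flags: since neither --enable_skill_scores nor --enable_persit_plot declares a type, any value passed on the command line arrives as a non-empty string, so "--enable_persit_plot False" would still be truthy inside the class. A hedged alternative sketch (str2bool is not part of the script, only an illustration):

    import argparse

    def str2bool(value: str) -> bool:
        # map common spellings onto a real boolean
        return str(value).lower() in ("yes", "true", "t", "1")

    parser = argparse.ArgumentParser()
    parser.add_argument("--enable_skill_scores", type=str2bool, default=False)
    parser.add_argument("--enable_persit_plot", type=str2bool, default=False)
    args = parser.parse_args(["--enable_persit_plot", "False"])
    print(args.enable_persit_plot)  # False, as intended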