Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
MLAir
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
MLAir
Commits
811bd7d6
Commit
811bd7d6
authored
5 years ago
by
lukas leufen
Browse files
Options
Downloads
Patches
Plain Diff
some renaming and first tests for bootstrap generator
parent
46b3b49f
Branches
Branches containing commit
Tags
Tags containing commit
3 merge requests
!90
WIP: new release update
,
!89
Resolve "release branch / CI on gpu"
,
!61
Resolve "REFAC: clean-up bootstrap workflow"
Pipeline
#31451
passed
5 years ago
Stage: test
Stage: pages
Stage: deploy
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/data_handling/bootstraps.py
+45
-24
45 additions, 24 deletions
src/data_handling/bootstraps.py
test/test_data_handling/test_bootstraps.py
+57
-2
57 additions, 2 deletions
test/test_data_handling/test_bootstraps.py
with
102 additions
and
26 deletions
src/data_handling/bootstraps.py
+
45
−
24
View file @
811bd7d6
...
...
@@ -10,30 +10,20 @@ import xarray as xr
import
os
import
re
from
src
import
helpers
from
typing
import
List
from
typing
import
List
,
Union
class
BootStrapGenerator
:
def
__init__
(
self
,
orig_generator
,
boots
,
chunksize
,
bootstrap_path
):
def
__init__
(
self
,
orig_generator
,
number_of_boots
,
bootstrap_path
):
self
.
orig_generator
:
DataGenerator
=
orig_generator
self
.
stations
=
self
.
orig_generator
.
stations
self
.
variables
=
self
.
orig_generator
.
variables
self
.
boots
=
boots
self
.
chunksize
=
chunksize
self
.
number_of_boots
=
number_of_boots
self
.
bootstrap_path
=
bootstrap_path
self
.
_iterator
=
0
def
__len__
(
self
):
"""
display the number of stations
"""
return
len
(
self
.
orig_generator
)
*
self
.
boots
*
len
(
self
.
variables
)
def
get_labels
(
self
,
key
):
_
,
label
=
self
.
orig_generator
[
key
]
for
_
in
range
(
self
.
boots
):
yield
label
return
len
(
self
.
orig_generator
)
*
self
.
number_of_boots
*
len
(
self
.
variables
)
def
get_generator
(
self
):
"""
...
...
@@ -46,10 +36,10 @@ class BootStrapGenerator:
station
=
self
.
orig_generator
.
get_station_key
(
i
)
logging
.
info
(
f
"
station:
{
station
}
"
)
hist
,
label
=
data
shuffled_data
=
self
.
load_
boot
_data
(
station
)
shuffled_data
=
self
.
load_
shuffled
_data
(
station
,
self
.
variables
)
for
var
in
self
.
variables
:
logging
.
info
(
f
"
var:
{
var
}
"
)
for
boot
in
range
(
self
.
boots
):
logging
.
debug
(
f
"
var:
{
var
}
"
)
for
boot
in
range
(
self
.
number_of_
boots
):
logging
.
debug
(
f
"
boot:
{
boot
}
"
)
boot_hist
=
hist
.
sel
(
variables
=
helpers
.
list_pop
(
self
.
variables
,
var
))
shuffled_var
=
shuffled_data
.
sel
(
variables
=
var
,
boots
=
boot
).
expand_dims
(
"
variables
"
).
drop
(
"
boots
"
).
transpose
(
"
datetime
"
,
"
window
"
,
"
Stations
"
,
"
variables
"
)
...
...
@@ -67,23 +57,54 @@ class BootStrapGenerator:
for
station
in
self
.
stations
:
label
=
self
.
orig_generator
.
get_data_generator
(
station
).
get_transposed_label
()
for
var
in
self
.
variables
:
for
boot
in
range
(
self
.
boots
):
for
boot
in
range
(
self
.
number_of_
boots
):
bootstrap_meta
.
extend
([[
var
,
station
]]
*
len
(
label
))
return
bootstrap_meta
def
get_orig_prediction
(
self
,
path
,
file_name
,
prediction_name
=
"
CNN
"
):
def
get_labels
(
self
,
key
:
Union
[
str
,
int
]):
"""
Repeats labels for given key by the number of boots and yields them one by one.
:param key: key of station (either station name as string or the position in generator as integer)
:return: yields labels for length of boots
"""
_
,
label
=
self
.
orig_generator
[
key
]
for
_
in
range
(
self
.
number_of_boots
):
yield
label
def
get_orig_prediction
(
self
,
path
:
str
,
file_name
:
str
,
prediction_name
:
str
=
"
CNN
"
):
"""
Repeats predictions from given file(_name) in path by the number of boots.
:param path: path to file
:param file_name: file name
:param prediction_name: name of the prediction to select from loaded file
:return: yields predictions for length of boots
"""
file
=
os
.
path
.
join
(
path
,
file_name
)
data
=
xr
.
open_dataarray
(
file
)
for
_
in
range
(
self
.
boots
):
for
_
in
range
(
self
.
number_of_
boots
):
yield
data
.
sel
(
type
=
prediction_name
).
squeeze
()
def
load_boot_data
(
self
,
station
):
def
load_shuffled_data
(
self
,
station
:
str
,
variables
:
List
[
str
])
->
xr
.
DataArray
:
"""
Load shuffled data from bootstrap path. Data is stored as
'
<station>_<var1>_<var2>_..._hist<histsize>_nboots<nboots>_shuffled.nc
'
, e.g.
'
DEBW107_cloudcover_no_no2_temp_u_v_hist13_nboots20_shuffled.nc
'
:param station:
:param variables:
:return: shuffled data as xarray
"""
files
=
os
.
listdir
(
self
.
bootstrap_path
)
regex
=
re
.
compile
(
rf
"
{
station
}
_\w*\.nc
"
)
regex
=
self
.
create_file_regex
(
station
,
variables
)
file_name
=
os
.
path
.
join
(
self
.
bootstrap_path
,
list
(
filter
(
regex
.
search
,
files
))[
0
])
shuffled_data
=
xr
.
open_dataarray
(
file_name
,
chunks
=
100
)
return
shuffled_data
@staticmethod
def
create_file_regex
(
station
,
variables
):
var_regex
=
""
.
join
([
rf
'
(_\w+)*_
{
v
}
(_\w+)*
'
for
v
in
sorted
(
variables
)])
regex
=
re
.
compile
(
rf
"
{
station
}{
var_regex
}
_shuffled\.nc
"
)
return
regex
class
BootStraps
:
...
...
@@ -93,7 +114,7 @@ class BootStraps:
self
.
bootstrap_path
=
bootstrap_path
self
.
chunks
=
self
.
get_chunk_size
()
self
.
create_shuffled_data
()
self
.
_boot_strap_generator
=
BootStrapGenerator
(
self
.
data
,
self
.
number_bootstraps
,
self
.
chunks
,
self
.
bootstrap_path
)
self
.
_boot_strap_generator
=
BootStrapGenerator
(
self
.
data
,
self
.
number_bootstraps
,
self
.
bootstrap_path
)
def
get_boot_strap_meta
(
self
):
return
self
.
_boot_strap_generator
.
get_bootstrap_meta
()
...
...
@@ -135,7 +156,7 @@ class BootStraps:
randomly selected variables. If there is a suitable local file for requested window size and number of
bootstraps, no additional file will be created inside this function.
"""
logging
.
info
(
"
create shuffled bootstrap data
"
)
logging
.
info
(
"
create
/ check
shuffled bootstrap data
"
)
variables_str
=
'
_
'
.
join
(
sorted
(
self
.
data
.
variables
))
window
=
self
.
data
.
window_history_size
for
station
in
self
.
data
.
stations
:
...
...
This diff is collapsed.
Click to expand it.
test/test_data_handling/test_bootstraps.py
+
57
−
2
View file @
811bd7d6
from
src.data_handling.bootstraps
import
BootStraps
from
src.data_handling.bootstraps
import
BootStraps
,
BootStrapGenerator
from
src.data_handling.data_generator
import
DataGenerator
import
pytest
import
os
import
numpy
as
np
import
xarray
as
xr
class
TestBootstraps
:
...
...
@@ -61,4 +63,57 @@ class TestBootstraps:
assert
set
(
np
.
unique
(
res
)).
issubset
({
1
,
2
,
3
})
def
test_create_shuffled_data
(
self
):
pass
\ No newline at end of file
pass
class
TestBootstrapGenerator
:
@pytest.fixture
def
orig_generator
(
self
):
return
DataGenerator
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'
data
'
),
'
AIRBASE
'
,
[
'
DEBW107
'
,
'
DEBW013
'
],
[
'
o3
'
,
'
temp
'
],
'
datetime
'
,
'
variables
'
,
'
o3
'
,
start
=
2010
,
end
=
2014
)
@pytest.fixture
def
boot_gen
(
self
,
orig_generator
):
path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'
data
'
)
dummy_content
=
xr
.
DataArray
([
1
,
2
,
3
],
dims
=
"
dummy
"
)
dummy_content
.
to_netcdf
(
os
.
path
.
join
(
path
,
"
DEBW107_o3_temp_shuffled.nc
"
))
dummy_content
.
to_netcdf
(
os
.
path
.
join
(
path
,
"
DEBW013_o3_temp_shuffled.nc
"
))
return
BootStrapGenerator
(
orig_generator
,
20
,
path
)
def
test_init
(
self
,
orig_generator
):
gen
=
BootStrapGenerator
(
orig_generator
,
20
,
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'
data
'
))
assert
gen
.
stations
==
[
"
DEBW107
"
,
"
DEBW013
"
]
assert
gen
.
variables
==
[
"
o3
"
,
"
temp
"
]
assert
gen
.
number_of_boots
==
20
assert
gen
.
bootstrap_path
==
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'
data
'
)
def
test_len
(
self
,
boot_gen
):
assert
len
(
boot_gen
)
==
80
def
test_get_generator
(
self
,
boot_gen
):
pass
def
test_get_bootstrap_meta
(
self
,
boot_gen
):
pass
def
test_get_labels
(
self
,
boot_gen
):
pass
def
test_get_orig_prediction
(
self
,
boot_gen
):
pass
def
test_load_shuffled_data
(
self
,
boot_gen
):
shuffled_data
=
boot_gen
.
load_shuffled_data
(
"
DEBW107
"
,
[
"
o3
"
,
"
temp
"
])
assert
isinstance
(
shuffled_data
,
xr
.
DataArray
)
assert
all
(
shuffled_data
.
compute
().
values
==
[
1
,
2
,
3
])
def
test_create_file_regex
(
self
,
boot_gen
):
regex
=
boot_gen
.
create_file_regex
(
"
DEBW108
"
,
[
"
o3
"
,
"
temp
"
,
"
h2o
"
])
test_list
=
[
"
DEBW108_o3_test23_test_shuffled.nc
"
,
"
DEBW107_o3_test23_test_shuffled.nc
"
,
"
DEBW108_o3_test23_test.nc
"
,
"
DEBW108_h2o_o3_temp_test_shuffled.nc
"
,
"
DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc
"
]
assert
list
(
filter
(
regex
.
search
,
test_list
))
==
[
"
DEBW108_h2o_o3_temp_test_shuffled.nc
"
,
"
DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc
"
]
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment