Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
MLAir
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
MLAir
Commits
15b95a11
Commit
15b95a11
authored
5 years ago
by
lukas leufen
Browse files
Options
Downloads
Patches
Plain Diff
implemented first methods and their tests
parent
430cc664
Branches
Branches containing commit
Tags
Tags containing commit
2 merge requests
!9
new version v0.2.0
,
!8
data generator
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/data_generator.py
+68
-0
68 additions, 0 deletions
src/data_generator.py
test/test_data_generator.py
+49
-0
49 additions, 0 deletions
test/test_data_generator.py
with
117 additions
and
0 deletions
src/data_generator.py
0 → 100644
+
68
−
0
View file @
15b95a11
# Module metadata: contributing authors and the date string recorded for this module.
__author__ = 'Felix Kleinert, Lukas Leufen'
__date__ = '2019-11-07'
import
keras
from
src
import
helpers
import
os
from
typing
import
Union
,
List
import
decimal
import
numpy
as
np
class DataGenerator(keras.utils.Sequence):
    """
    Station-wise data generator skeleton for use with keras.

    The class is meant to be handed to keras' fit_generator and predict_generator, with the individual stations
    acting as the iterables. The intended contract is that an item, addressed either by position (integer) or by
    station id (string), yields X, y for that station; lists holding exactly one integer or string are meant to be
    accepted as well.

    NOTE: in this revision only construction, ``__repr__``, ``__len__`` and ``threshold_setup`` are implemented.
    ``__getitem__`` and ``__next__`` still raise ``NotImplementedError``.
    """

    def __init__(self, path: str, network: str, stations: Union[str, List[str]], variables: List[str], dim: str,
                 target_dim: str, target_var: str, **kwargs):
        """
        Store the generator configuration and derive the threshold list.

        :param path: data directory; stored as an absolute path
        :param network: name of the measurement network
        :param stations: one station id or a list of station ids; passed through ``helpers.to_list``
            (presumably wraps a single id into a list - confirm against helpers)
        :param variables: names of the variables to use
        :param dim: name of the dimension, stored unchanged
        :param target_dim: name of the target dimension, stored unchanged
        :param target_var: name of the target variable, stored unchanged
        :param kwargs: optional settings; ``thr_min``, ``thr_max`` and ``thr_number_of_steps`` are read by
            ``threshold_setup``, everything else is only stored
        """
        # Initialise the keras base class. Called without arguments on purpose: **kwargs carries settings of this
        # class (e.g. thr_min) that the base class does not know about.
        super().__init__()
        self.path = os.path.abspath(path)
        self.network = network
        self.stations = helpers.to_list(stations)
        self.variables = variables
        self.dim = dim
        self.target_dim = target_dim
        self.target_var = target_var
        self.kwargs = kwargs
        # Position used for iteration; defined here so the attribute exists before the first iter() call.
        self.iterator = 0
        self.threshold = self.threshold_setup()

    def __repr__(self):
        """
        Return a string showing all constructor arguments of this instance.
        """
        return f"DataGenerator(path='{self.path}', network='{self.network}', stations={self.stations}, " \
               f"variables={self.variables}, dim='{self.dim}', target_dim='{self.target_dim}', target_var='" \
               f"{self.target_var}', **{self.kwargs})"

    def __len__(self):
        """
        Return the number of stations.
        """
        return len(self.stations)

    def __iter__(self):
        """
        Reset the iteration position to the first station and return the generator itself.
        """
        self.iterator = 0
        return self

    def __next__(self):
        """
        Not implemented yet.

        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def __getitem__(self, item):
        """
        Not implemented yet.

        :param item: intended to be a position (integer) or a station id (string)
        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def threshold_setup(self) -> List[str]:
        """
        Build the list of thresholds from the optional settings in ``self.kwargs``.

        The values are evenly spaced between ``thr_min`` (default 0) and ``thr_max`` (default 100), using
        ``thr_number_of_steps`` values (default 200) as produced by ``np.linspace``.

        :return: thresholds as strings, each formatted with four decimal places
        """
        thr_min = self.kwargs.get('thr_min', 0)
        thr_max = self.kwargs.get('thr_max', 100)
        thr_number_of_steps = self.kwargs.get('thr_number_of_steps', 200)
        # Fix the precision to four decimals first, then take the string form of the resulting Decimal.
        return [str(decimal.Decimal("%.4f" % e)) for e in np.linspace(thr_min, thr_max, thr_number_of_steps)]
This diff is collapsed.
Click to expand it.
test/test_data_generator.py
0 → 100644
+
49
−
0
View file @
15b95a11
import
pytest
import
os
from
src.data_generator
import
DataGenerator
import
logging
import
numpy
as
np
import
xarray
as
xr
import
datetime
as
dt
import
pandas
as
pd
from
operator
import
itemgetter
class TestDataGenerator:
    """
    Unit tests for the parts of DataGenerator that are implemented so far (construction, repr, len and
    threshold setup).
    """

    @staticmethod
    def _data_path() -> str:
        """Return the absolute path of the 'data' directory located next to this test file."""
        return os.path.abspath(os.path.join(os.path.dirname(__file__), 'data'))

    @pytest.fixture
    def gen(self):
        """
        Provide a generator for a single station.

        The data path is anchored at this test file rather than at the current working directory, so the tests
        give the same result no matter where pytest is started from.
        """
        return DataGenerator(self._data_path(), 'UBA', 'DEBW107', ['o3', 'temp'], 'datetime', 'datetime', 'o3')

    def test_init(self, gen):
        """All constructor arguments are stored; a single station id ends up as a one-element list."""
        assert gen.path == self._data_path()
        assert gen.network == 'UBA'
        assert gen.stations == ['DEBW107']
        assert gen.variables == ['o3', 'temp']
        assert gen.dim == 'datetime'
        assert gen.target_dim == 'datetime'
        assert gen.target_var == 'o3'
        assert gen.threshold is not None

    def test_repr(self, gen):
        """The repr lists every constructor argument, with an empty dict for the unused kwargs."""
        path = self._data_path()
        assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], " \
                                          f"variables=['o3', 'temp'], dim='datetime', target_dim='datetime', " \
                                          f"target_var='o3', **{{}})".rstrip()

    def test_len(self, gen):
        """len() reflects the current number of stations."""
        assert len(gen) == 1
        gen.stations = ['station1', 'station2', 'station3']
        assert len(gen) == 3

    def test_threshold_setup(self, gen):
        """Defaults and each threshold setting are honoured; settings accumulate from one check to the next."""

        def res(arg=None, val=None):
            # Optionally set one kwarg on the generator, then return the thresholds as floats.
            if arg is not None:
                gen.kwargs[arg] = val
            return list(map(float, gen.threshold_setup()))

        # assert_array_almost_equal raises AssertionError on mismatch, so no extra assert is needed.
        compare = np.testing.assert_array_almost_equal
        compare(res(), np.linspace(0, 100, 200), decimal=3)
        compare(res('thr_min', 10), np.linspace(10, 100, 200), decimal=3)
        compare(res('thr_max', 40), np.linspace(10, 40, 200), decimal=3)
        compare(res('thr_number_of_steps', 10), np.linspace(10, 40, 10), decimal=3)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment