Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
MLAir
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
MLAir
Commits
ab80fd42
Commit
ab80fd42
authored
5 years ago
by
lukas leufen
Browse files
Options
Downloads
Patches
Plain Diff
added docs for some generator methods
parent
202f2311
Branches
Branches containing commit
Tags
Tags containing commit
2 merge requests
!9
new version v0.2.0
,
!8
data generator
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/data_generator.py
+24
-8
24 additions, 8 deletions
src/data_generator.py
test/test_data_generator.py
+3
-3
3 additions, 3 deletions
test/test_data_generator.py
with
27 additions
and
11 deletions
src/data_generator.py
+
24
−
8
View file @
ab80fd42
...
@@ -5,9 +5,10 @@ import keras
...
@@ -5,9 +5,10 @@ import keras
from
src
import
helpers
from
src
import
helpers
from
src.data_preparation
import
DataPrep
from
src.data_preparation
import
DataPrep
import
os
import
os
from
typing
import
Union
,
List
from
typing
import
Union
,
List
,
Tuple
import
decimal
import
decimal
import
numpy
as
np
import
numpy
as
np
import
xarray
as
xr
class
DataGenerator
(
keras
.
utils
.
Sequence
):
class
DataGenerator
(
keras
.
utils
.
Sequence
):
...
@@ -52,14 +53,24 @@ class DataGenerator(keras.utils.Sequence):
...
@@ -52,14 +53,24 @@ class DataGenerator(keras.utils.Sequence):
"""
"""
return
len
(
self
.
stations
)
return
len
(
self
.
stations
)
def
__iter__
(
self
):
def
__iter__
(
self
)
->
"
DataGenerator
"
:
self
.
iterator
=
0
"""
Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute
`_iterator` to 0.
:return:
"""
self
.
_iterator
=
0
return
self
return
self
def
__next__
(
self
):
def
__next__
(
self
)
->
Tuple
[
xr
.
DataArray
,
xr
.
DataArray
]:
if
self
.
iterator
<
self
.
__len__
():
"""
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
the history and label data of this generator.
:return:
"""
if
self
.
_iterator
<
self
.
__len__
():
data
=
self
.
get_data_generator
()
data
=
self
.
get_data_generator
()
self
.
iterator
+=
1
self
.
_
iterator
+=
1
if
data
.
history
is
not
None
and
data
.
label
is
not
None
:
if
data
.
history
is
not
None
and
data
.
label
is
not
None
:
return
data
.
history
.
transpose
(
"
datetime
"
,
"
window
"
,
"
Stations
"
,
"
variables
"
),
\
return
data
.
history
.
transpose
(
"
datetime
"
,
"
window
"
,
"
Stations
"
,
"
variables
"
),
\
data
.
label
.
squeeze
(
"
Stations
"
).
transpose
(
"
datetime
"
,
"
window
"
)
data
.
label
.
squeeze
(
"
Stations
"
).
transpose
(
"
datetime
"
,
"
window
"
)
...
@@ -68,7 +79,12 @@ class DataGenerator(keras.utils.Sequence):
...
@@ -68,7 +79,12 @@ class DataGenerator(keras.utils.Sequence):
else
:
else
:
raise
StopIteration
raise
StopIteration
def
__getitem__
(
self
,
item
:
Union
[
str
,
int
]):
def
__getitem__
(
self
,
item
:
Union
[
str
,
int
])
->
Tuple
[
xr
.
DataArray
,
xr
.
DataArray
]:
"""
Defines the get item method for this generator. Retrieve data from generator and return history and labels.
:param item: station key to choose the data generator.
:return: The generator's time series of history data and its labels
"""
data
=
self
.
get_data_generator
(
key
=
item
)
data
=
self
.
get_data_generator
(
key
=
item
)
return
data
.
history
.
transpose
(
"
datetime
"
,
"
window
"
,
"
Stations
"
,
"
variables
"
),
\
return
data
.
history
.
transpose
(
"
datetime
"
,
"
window
"
,
"
Stations
"
,
"
variables
"
),
\
data
.
label
.
squeeze
(
"
Stations
"
).
transpose
(
"
datetime
"
,
"
window
"
)
data
.
label
.
squeeze
(
"
Stations
"
).
transpose
(
"
datetime
"
,
"
window
"
)
...
@@ -113,7 +129,7 @@ class DataGenerator(keras.utils.Sequence):
...
@@ -113,7 +129,7 @@ class DataGenerator(keras.utils.Sequence):
raise
KeyError
(
f
"
More than one key was given:
{
key
}
"
)
raise
KeyError
(
f
"
More than one key was given:
{
key
}
"
)
# return station name either from key or the recent element from iterator
# return station name either from key or the recent element from iterator
if
key
is
None
:
if
key
is
None
:
return
self
.
stations
[
self
.
iterator
]
return
self
.
stations
[
self
.
_
iterator
]
else
:
else
:
if
isinstance
(
key
,
int
):
if
isinstance
(
key
,
int
):
if
key
<
self
.
__len__
():
if
key
<
self
.
__len__
():
...
...
This diff is collapsed.
Click to expand it.
test/test_data_generator.py
+
3
−
3
View file @
ab80fd42
...
@@ -46,12 +46,12 @@ class TestDataGenerator:
...
@@ -46,12 +46,12 @@ class TestDataGenerator:
assert
hasattr
(
gen
,
'
iterator
'
)
is
False
assert
hasattr
(
gen
,
'
iterator
'
)
is
False
iter
(
gen
)
iter
(
gen
)
assert
hasattr
(
gen
,
'
iterator
'
)
assert
hasattr
(
gen
,
'
iterator
'
)
assert
gen
.
iterator
==
0
assert
gen
.
_
iterator
==
0
def
test_next
(
self
,
gen
):
def
test_next
(
self
,
gen
):
gen
.
kwargs
=
{
'
statistics_per_var
'
:
{
'
o3
'
:
'
dma8eu
'
,
'
temp
'
:
'
maximum
'
}}
gen
.
kwargs
=
{
'
statistics_per_var
'
:
{
'
o3
'
:
'
dma8eu
'
,
'
temp
'
:
'
maximum
'
}}
for
i
,
d
in
enumerate
(
gen
,
start
=
1
):
for
i
,
d
in
enumerate
(
gen
,
start
=
1
):
assert
i
==
gen
.
iterator
assert
i
==
gen
.
_
iterator
def
test_getitem
(
self
,
gen
):
def
test_getitem
(
self
,
gen
):
gen
.
kwargs
=
{
'
statistics_per_var
'
:
{
'
o3
'
:
'
dma8eu
'
,
'
temp
'
:
'
maximum
'
}}
gen
.
kwargs
=
{
'
statistics_per_var
'
:
{
'
o3
'
:
'
dma8eu
'
,
'
temp
'
:
'
maximum
'
}}
...
@@ -74,7 +74,7 @@ class TestDataGenerator:
...
@@ -74,7 +74,7 @@ class TestDataGenerator:
def
test_get_key_representation
(
self
,
gen
):
def
test_get_key_representation
(
self
,
gen
):
gen
.
stations
.
append
(
"
DEBW108
"
)
gen
.
stations
.
append
(
"
DEBW108
"
)
f
=
gen
.
__iter__
()
.
get_station_key
f
=
gen
.
__iter__
.
get_station_key
assert
f
(
None
)
==
"
DEBW107
"
assert
f
(
None
)
==
"
DEBW107
"
assert
f
([
None
])
==
"
DEBW107
"
assert
f
([
None
])
==
"
DEBW107
"
with
pytest
.
raises
(
KeyError
)
as
e
:
with
pytest
.
raises
(
KeyError
)
as
e
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment