Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
MLAir
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
MLAir
Commits
caf82558
Commit
caf82558
authored
5 years ago
by
lukas leufen
Browse files
Options
Downloads
Patches
Plain Diff
can load data during pre-processing. /close
#13
parent
e1144c96
Branches
Branches containing commit
Tags
Tags containing commit
2 merge requests
!17
update to v0.4.0
,
!15
new feat split subsets
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/modules.py
+72
-27
72 additions, 27 deletions
src/modules.py
test/test_modules.py
+71
-0
71 additions, 0 deletions
test/test_modules.py
with
143 additions
and
27 deletions
src/modules.py
+
72
−
27
View file @
caf82558
...
@@ -4,6 +4,7 @@ import time
...
@@ -4,6 +4,7 @@ import time
from
src.data_generator
import
DataGenerator
from
src.data_generator
import
DataGenerator
from
src.experiment_setup
import
ExperimentSetup
from
src.experiment_setup
import
ExperimentSetup
import
argparse
import
argparse
from
typing
import
Dict
,
List
class
run
(
object
):
class
run
(
object
):
...
@@ -12,52 +13,96 @@ class run(object):
...
@@ -12,52 +13,96 @@ class run(object):
after finishing the measurement. The duration result is logged.
after finishing the measurement. The duration result is logged.
"""
"""
del_by_exit
=
False
def
__init__
(
self
):
def
__init__
(
self
):
"""
Starts time tracking automatically and logs as info.
"""
self
.
time
=
TimeTracking
()
self
.
time
=
TimeTracking
()
logging
.
info
(
f
"
{
self
.
__class__
.
__name__
}
started
"
)
logging
.
info
(
f
"
{
self
.
__class__
.
__name__
}
started
"
)
def
__del__
(
self
):
def
__del__
(
self
):
"""
This is the class finalizer. The code is not executed if already called by exit method to prevent duplicated
logging (__exit__ is always executed before __del__) it this class was used in a with statement.
"""
if
not
self
.
del_by_exit
:
self
.
time
.
stop
()
self
.
time
.
stop
()
logging
.
info
(
f
"
{
self
.
__class__
.
__name__
}
finished after
{
self
.
time
}
"
)
logging
.
info
(
f
"
{
self
.
__class__
.
__name__
}
finished after
{
self
.
time
}
"
)
self
.
del_by_exit
=
True
def
__enter__
(
self
):
def
__enter__
(
self
):
pass
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
pass
self
.
__del__
()
def
do_stuff
(
self
):
@staticmethod
time
.
sleep
(
2
)
def
do_stuff
(
length
=
2
):
time
.
sleep
(
length
)
class
PreProcessing
(
run
):
class
PreProcessing
(
run
):
def
__init__
(
self
,
setup
):
"""
Pre-process your data by using this class. It includes time tracking and uses the experiment setup to look for data
and stores it if not already in local disk. Further, it provides this data as a generator and checks for valid
stations (in this context: valid=data available). Finally, it splits the data into valid training, validation and
testing subsets.
"""
def
__init__
(
self
,
experiment_setup
:
ExperimentSetup
):
super
().
__init__
()
super
().
__init__
()
self
.
setup
=
setup
self
.
setup
=
experiment_
setup
self
.
kwargs
=
None
self
.
kwargs
=
None
self
.
valid_stations
=
[]
self
.
_run
()
self
.
_run
()
def
_run
(
self
):
def
_run
(
self
):
self
.
kwargs
=
{
'
start
'
:
'
1997-01-01
'
,
'
end
'
:
'
2017-12-31
'
,
'
limit
'
:
1
,
'
window_history
'
:
13
,
kwargs
=
{
'
start
'
:
'
1997-01-01
'
,
'
end
'
:
'
2017-12-31
'
,
'
limit
_nan_fill
'
:
1
,
'
window_history
'
:
13
,
'
window_lead_time
'
:
3
,
'
method
'
:
'
linear
'
,
'
window_lead_time
'
:
3
,
'
interpolate_
method
'
:
'
linear
'
,
'
statistics_per_var
'
:
self
.
setup
.
var_all_dict
,
}
'
statistics_per_var
'
:
self
.
setup
.
var_all_dict
,
}
self
.
check_valid_stations
()
valid_stations
=
self
.
check_valid_stations
(
self
.
setup
.
__dict__
,
kwargs
,
self
.
setup
.
stations
)
def
check_valid_stations
(
self
):
t
=
TimeTracking
logging
.
debug
(
"
check valid stations started
"
)
window_lead_time
=
self
.
kwargs
.
get
(
"
window_lead_time
"
,
None
)
valid_stations
=
[]
for
s
in
self
.
setup
.
stations
:
valid
=
False
args
=
self
.
setup
.
__dict__
args
=
self
.
setup
.
__dict__
args
[
"
stations
"
]
=
s
args
[
"
stations
"
]
=
valid_stations
data_gen
=
DataGenerator
(
**
args
,
**
kwargs
)
train
,
val
,
test
=
self
.
split_train_val_test
()
h
=
DataGenerator
(
**
args
,
**
self
.
kwargs
)
@staticmethod
da_it
=
h
.
get_data_generator
(
s
)
def
split_train_val_test
():
print
(
'
hi
'
)
return
None
,
None
,
None
@staticmethod
def
check_valid_stations
(
args
:
Dict
,
kwargs
:
Dict
,
all_stations
:
List
[
str
]):
"""
Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the given
time range (is included in `kwargs`). The shape and the loading time are logged in debug mode.
:param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`,
`variables`, `interpolate_dim`, `target_dim`, `target_var`).
:param kwargs: positional parameters for the DataGenerator class (e.g. `start`, `interpolate_method`,
`window_lead_time`).
:param all_stations: All stations to check.
:return: Corrected list containing only valid station IDs.
"""
t_outer
=
TimeTracking
()
t_inner
=
TimeTracking
(
start
=
False
)
logging
.
info
(
"
check valid stations started
"
)
valid_stations
=
[]
# all required arguments of the DataGenerator can be found in args, positional arguments in args and kwargs
data_gen
=
DataGenerator
(
**
args
,
**
kwargs
)
for
station
in
all_stations
:
t_inner
.
run
()
try
:
(
history
,
label
)
=
data_gen
[
station
]
valid_stations
.
append
(
station
)
logging
.
debug
(
f
"
{
station
}
: history_shape =
{
history
.
shape
}
"
)
logging
.
debug
(
f
"
{
station
}
: loading time =
{
t_inner
}
"
)
except
AttributeError
:
continue
logging
.
info
(
f
"
run for
{
t_outer
}
to check
{
len
(
all_stations
)
}
station(s)
"
)
return
valid_stations
class
Training
(
run
):
class
Training
(
run
):
...
@@ -82,7 +127,7 @@ if __name__ == "__main__":
...
@@ -82,7 +127,7 @@ if __name__ == "__main__":
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--experiment_date
'
,
metavar
=
'
--exp_date
'
,
type
=
str
,
nargs
=
1
,
default
=
None
,
parser
.
add_argument
(
'
--experiment_date
'
,
metavar
=
'
--exp_date
'
,
type
=
str
,
nargs
=
1
,
default
=
None
,
help
=
"
set experiment date as string
"
)
help
=
"
set experiment date as string
"
)
args
=
parser
.
parse_args
()
parser_
args
=
parser
.
parse_args
()
with
run
():
with
run
():
setup
=
ExperimentSetup
(
args
,
test
=
True
)
setup
=
ExperimentSetup
(
parser_args
,
stations
=
[
'
DEBW107
'
,
'
DEBY081
'
,
'
DEBW013
'
,
'
DEBW076
'
,
'
DEBW087
'
]
)
PreProcessing
(
setup
)
PreProcessing
(
setup
)
This diff is collapsed.
Click to expand it.
test/test_modules.py
0 → 100644
+
71
−
0
View file @
caf82558
import
pytest
import
logging
from
src.modules
import
run
,
PreProcessing
from
src.helpers
import
TimeTracking
import
src.helpers
from
src.experiment_setup
import
ExperimentSetup
import
re
import
mock
class
pytest_regex
:
"""
Assert that a given string meets some expectations.
"""
def
__init__
(
self
,
pattern
,
flags
=
0
):
self
.
_regex
=
re
.
compile
(
pattern
,
flags
)
def
__eq__
(
self
,
actual
):
return
bool
(
self
.
_regex
.
match
(
actual
))
def
__repr__
(
self
):
return
self
.
_regex
.
pattern
class
TestRun
:
def
test_enter_exit
(
self
,
caplog
):
caplog
.
set_level
(
logging
.
INFO
)
with
run
()
as
r
:
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
'
run started
'
)
assert
isinstance
(
r
.
time
,
TimeTracking
)
r
.
do_stuff
(
0.1
)
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
pytest_regex
(
"
run finished after \d+\.\d+s
"
))
def
test_init_del
(
self
,
caplog
):
caplog
.
set_level
(
logging
.
INFO
)
r
=
run
()
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
'
run started
'
)
r
.
do_stuff
(
0.2
)
del
r
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
pytest_regex
(
"
run finished after \d+\.\d+s
"
))
class
TestPreProcessing
:
def
test_init
(
self
,
caplog
):
caplog
.
set_level
(
logging
.
INFO
)
setup
=
ExperimentSetup
({},
stations
=
[
'
DEBW107
'
,
'
DEBY081
'
,
'
DEBW013
'
,
'
DEBW076
'
,
'
DEBW087
'
])
pre
=
PreProcessing
(
setup
)
assert
caplog
.
record_tuples
[
0
]
==
(
'
root
'
,
20
,
'
PreProcessing started
'
)
assert
caplog
.
record_tuples
[
1
]
==
(
'
root
'
,
20
,
'
check valid stations started
'
)
assert
caplog
.
record_tuples
[
2
]
==
(
'
root
'
,
20
,
pytest_regex
(
'
run for \d+\.\d+s to check 5 station\(s\)
'
))
def
test_run
(
self
):
pre_processing
=
object
.
__new__
(
PreProcessing
)
pre_processing
.
setup
=
ExperimentSetup
({},
stations
=
[
'
DEBW107
'
,
'
DEBY081
'
,
'
DEBW013
'
,
'
DEBW076
'
,
'
DEBW087
'
])
assert
pre_processing
.
_run
()
is
None
def
test_split_train_val_test
(
self
):
pass
def
test_check_valid_stations
(
self
,
caplog
):
caplog
.
set_level
(
logging
.
INFO
)
pre
=
object
.
__new__
(
PreProcessing
)
pre
.
setup
=
ExperimentSetup
({},
stations
=
[
'
DEBW107
'
,
'
DEBY081
'
,
'
DEBW013
'
,
'
DEBW076
'
,
'
DEBW087
'
])
kwargs
=
{
'
start
'
:
'
1997-01-01
'
,
'
end
'
:
'
2017-12-31
'
,
'
limit_nan_fill
'
:
1
,
'
window_history
'
:
13
,
'
window_lead_time
'
:
3
,
'
interpolate_method
'
:
'
linear
'
,
'
statistics_per_var
'
:
pre
.
setup
.
var_all_dict
,
}
valids
=
pre
.
check_valid_stations
(
pre
.
setup
.
__dict__
,
kwargs
,
pre
.
setup
.
stations
)
assert
valids
==
pre
.
setup
.
stations
assert
caplog
.
record_tuples
[
0
]
==
(
'
root
'
,
20
,
'
check valid stations started
'
)
assert
caplog
.
record_tuples
[
1
]
==
(
'
root
'
,
20
,
pytest_regex
(
'
run for \d+\.\d+s to check 5 station\(s\)
'
))
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment