Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
MLAir
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
MLAir
Commits
468e4383
Commit
468e4383
authored
5 years ago
by
lukas leufen
Browse files
Options
Downloads
Patches
Plain Diff
worked on split methods
parent
ae85c9ec
No related branches found
No related tags found
2 merge requests
!17
update to v0.4.0
,
!15
new feat split subsets
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
run.py
+2
-2
2 additions, 2 deletions
run.py
src/experiment_setup.py
+6
-0
6 additions, 0 deletions
src/experiment_setup.py
src/modules.py
+68
-6
68 additions, 6 deletions
src/modules.py
test/test_modules.py
+52
-4
52 additions, 4 deletions
test/test_modules.py
with
128 additions
and
12 deletions
run.py
+
2
−
2
View file @
468e4383
...
...
@@ -11,7 +11,7 @@ from src.modules import run, PreProcessing, Training, PostProcessing
def
main
():
with
run
():
exp_setup
=
ExperimentSetup
(
args
,
trainable
=
True
)
exp_setup
=
ExperimentSetup
(
args
,
trainable
=
True
,
stations
=
[
'
DEBW107
'
,
'
DEBY081
'
,
'
DEBW013
'
,
'
DEBW076
'
,
'
DEBW087
'
]
)
PreProcessing
(
exp_setup
)
...
...
@@ -23,7 +23,7 @@ def main():
if
__name__
==
"
__main__
"
:
formatter
=
'
%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]
'
logging
.
basicConfig
(
format
=
formatter
,
level
=
logging
.
DEBUG
)
logging
.
basicConfig
(
format
=
formatter
,
level
=
logging
.
INFO
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--experiment_date
'
,
metavar
=
'
--exp_date
'
,
type
=
str
,
nargs
=
1
,
default
=
None
,
...
...
This diff is collapsed.
Click to expand it.
src/experiment_setup.py
+
6
−
0
View file @
468e4383
...
...
@@ -29,6 +29,9 @@ class ExperimentSetup(object):
self
.
interpolate_dim
=
None
self
.
target_dim
=
None
self
.
target_var
=
None
self
.
train_kwargs
=
None
self
.
val_kwargs
=
None
self
.
test_kwargs
=
None
self
.
setup_experiment
(
**
kwargs
)
def
_set_param
(
self
,
param
,
value
,
default
=
None
):
...
...
@@ -86,3 +89,6 @@ class ExperimentSetup(object):
self
.
_set_param
(
"
interpolate_dim
"
,
kwargs
,
default
=
'
datetime
'
)
self
.
_set_param
(
"
target_dim
"
,
kwargs
,
default
=
'
variables
'
)
self
.
_set_param
(
"
target_var
"
,
kwargs
,
default
=
"
o3
"
)
self
.
_set_param
(
"
train_kwargs
"
,
kwargs
,
default
=
{
"
start
"
:
"
1997-01-01
"
,
"
end
"
:
"
2007-12-31
"
})
self
.
_set_param
(
"
val_kwargs
"
,
kwargs
,
default
=
{
"
start
"
:
"
2008-01-01
"
,
"
end
"
:
"
2009-12-31
"
})
self
.
_set_param
(
"
test_kwargs
"
,
kwargs
,
default
=
{
"
start
"
:
"
2010-01-01
"
,
"
end
"
:
"
2017-12-31
"
})
This diff is collapsed.
Click to expand it.
src/modules.py
+
68
−
6
View file @
468e4383
...
...
@@ -4,7 +4,7 @@ import time
from
src.data_generator
import
DataGenerator
from
src.experiment_setup
import
ExperimentSetup
import
argparse
from
typing
import
Dict
,
List
from
typing
import
Dict
,
List
,
Any
,
Tuple
class
run
(
object
):
...
...
@@ -63,15 +63,77 @@ class PreProcessing(run):
kwargs
=
{
'
start
'
:
'
1997-01-01
'
,
'
end
'
:
'
2017-12-31
'
,
'
limit_nan_fill
'
:
1
,
'
window_history
'
:
13
,
'
window_lead_time
'
:
3
,
'
interpolate_method
'
:
'
linear
'
,
'
statistics_per_var
'
:
self
.
setup
.
var_all_dict
,
}
valid_stations
=
self
.
check_valid_stations
(
self
.
setup
.
__dict__
,
kwargs
,
self
.
setup
.
stations
)
args
=
self
.
setup
.
__dict__
args
[
"
stations
"
]
=
valid_stations
valid_stations
=
self
.
check_valid_stations
(
args
,
kwargs
,
self
.
setup
.
stations
)
args
=
self
.
update_key
(
args
,
"
stations
"
,
valid_stations
)
data_gen
=
DataGenerator
(
**
args
,
**
kwargs
)
train
,
val
,
test
=
self
.
split_train_val_test
()
train
,
val
,
test
=
self
.
split_train_val_test
(
data_gen
,
valid_stations
,
args
,
kwargs
)
# print stats of data
def
split_train_val_test
(
self
,
data
,
stations
,
args
,
kwargs
):
train_index
,
val_index
,
test_index
=
self
.
split_set_indices
(
len
(
stations
),
args
[
"
fraction_of_training
"
])
train
=
self
.
create_set_split
(
stations
,
args
,
kwargs
,
train_index
,
"
train
"
)
val
=
self
.
create_set_split
(
stations
,
args
,
kwargs
,
val_index
,
"
val
"
)
test
=
self
.
create_set_split
(
stations
,
args
,
kwargs
,
test_index
,
"
test
"
)
return
train
,
val
,
test
@staticmethod
def
split_set_indices
(
total_length
:
int
,
fraction
:
float
)
->
Tuple
[
slice
,
slice
,
slice
]:
"""
create the training, validation and test subset slice indices for given total_length. The test data consists on
(1-fraction) of total_length (fraction*len:end). Train and validation data therefore are made from fraction of
total_length (0:fraction*len). Train and validation data is split by the factor 0.8 for train and 0.2 for
validation.
:param total_length: list with all objects to split
:param fraction: ratio between test and union of train/val data
:return: slices for each subset in the order: train, val, test
"""
pos_test_split
=
int
(
total_length
*
fraction
)
train_index
=
slice
(
0
,
int
(
pos_test_split
*
0.8
))
val_index
=
slice
(
int
(
pos_test_split
*
0.8
),
pos_test_split
)
test_index
=
slice
(
pos_test_split
,
total_length
)
return
train_index
,
val_index
,
test_index
def
create_set_split
(
self
,
stations
,
args
,
kwargs
,
index_list
,
set_name
):
if
args
[
"
use_all_stations_on_all_data_sets
"
]:
set_stations
=
stations
else
:
set_stations
=
stations
[
index_list
]
logging
.
debug
(
f
"
{
set_name
.
capitalize
()
}
stations (len=
{
set_stations
}
):
{
set_stations
}
"
)
set_kwargs
=
self
.
update_kwargs
(
args
,
kwargs
,
f
"
{
set_name
}
_kwargs
"
)
set_stations
=
self
.
check_valid_stations
(
args
,
set_kwargs
,
set_stations
)
set_args
=
self
.
update_key
(
args
,
"
stations
"
,
set_stations
)
data_set
=
DataGenerator
(
**
set_args
,
**
set_kwargs
)
return
data_set
@staticmethod
def
split_train_val_test
():
return
None
,
None
,
None
def
update_key
(
orig_dict
:
Dict
,
key
:
str
,
value
:
Any
)
->
Dict
:
"""
create copy of `orig_dict` and update given key by value, returns a copied dict. The original input dict
`orig_dict` is not modified by this function.
:param orig_dict: dictionary with arguments that should be updated
:param key: the key to update
:param value: the update itself for given key
:return: updated dict
"""
updated
=
orig_dict
.
copy
()
updated
.
update
({
key
:
value
})
return
updated
@staticmethod
def
update_kwargs
(
args
:
Dict
,
kwargs
:
Dict
,
kwargs_name
:
str
):
"""
copy kwargs and update kwargs parameters by another dictionary stored in args. Not existing keys in kwargs are
created, existing keys overwritten.
:param args: dict with the new kwargs parameters stored with key `kwargs_name`
:param kwargs: dict to update
:param kwargs_name: key in `args` to find the updates for `kwargs`
:return: updated kwargs dict
"""
kwargs_updated
=
kwargs
.
copy
()
if
kwargs_name
in
args
.
keys
()
and
args
[
kwargs_name
]:
kwargs_updated
.
update
(
args
[
kwargs_name
])
return
kwargs_updated
@staticmethod
def
check_valid_stations
(
args
:
Dict
,
kwargs
:
Dict
,
all_stations
:
List
[
str
]):
...
...
This diff is collapsed.
Click to expand it.
test/test_modules.py
+
52
−
4
View file @
468e4383
...
...
@@ -4,8 +4,10 @@ from src.modules import run, PreProcessing
from
src.helpers
import
TimeTracking
import
src.helpers
from
src.experiment_setup
import
ExperimentSetup
from
src.data_generator
import
DataGenerator
import
re
import
mock
import
numpy
as
np
class
pytest_regex
:
...
...
@@ -29,7 +31,7 @@ class TestRun:
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
'
run started
'
)
assert
isinstance
(
r
.
time
,
TimeTracking
)
r
.
do_stuff
(
0.1
)
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
pytest_regex
(
"
run finished after \d+\.\d+s
"
))
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
pytest_regex
(
r
"
run finished after \d+\.\d+s
"
))
def
test_init_del
(
self
,
caplog
):
caplog
.
set_level
(
logging
.
INFO
)
...
...
@@ -37,7 +39,7 @@ class TestRun:
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
'
run started
'
)
r
.
do_stuff
(
0.2
)
del
r
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
pytest_regex
(
"
run finished after \d+\.\d+s
"
))
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
pytest_regex
(
r
"
run finished after \d+\.\d+s
"
))
class
TestPreProcessing
:
...
...
@@ -49,7 +51,7 @@ class TestPreProcessing:
pre
=
PreProcessing
(
setup
)
assert
caplog
.
record_tuples
[
0
]
==
(
'
root
'
,
20
,
'
PreProcessing started
'
)
assert
caplog
.
record_tuples
[
1
]
==
(
'
root
'
,
20
,
'
check valid stations started
'
)
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
pytest_regex
(
'
run for \d+\.\d+s to check 5 station\(s\)
'
))
assert
caplog
.
record_tuples
[
-
1
]
==
(
'
root
'
,
20
,
pytest_regex
(
r
'
run for \d+\.\d+s to check 5 station\(s\)
'
))
def
test_run
(
self
):
pre_processing
=
object
.
__new__
(
PreProcessing
)
...
...
@@ -73,4 +75,50 @@ class TestPreProcessing:
valids
=
pre
.
check_valid_stations
(
pre
.
setup
.
__dict__
,
kwargs
,
pre
.
setup
.
stations
)
assert
valids
==
pre
.
setup
.
stations
assert
caplog
.
record_tuples
[
0
]
==
(
'
root
'
,
20
,
'
check valid stations started
'
)
assert
caplog
.
record_tuples
[
1
]
==
(
'
root
'
,
20
,
pytest_regex
(
'
run for \d+\.\d+s to check 5 station\(s\)
'
))
assert
caplog
.
record_tuples
[
1
]
==
(
'
root
'
,
20
,
pytest_regex
(
r
'
run for \d+\.\d+s to check 5 station\(s\)
'
))
def
test_update_kwargs
(
self
):
args
=
{
"
testName
"
:
{
"
testAttribute
"
:
"
TestValue
"
,
"
optional
"
:
"
2019-11-21
"
}}
kwargs
=
{
"
testAttribute
"
:
"
DefaultValue
"
,
"
defaultAttribute
"
:
3
}
updated
=
PreProcessing
.
update_kwargs
(
args
,
kwargs
,
"
testName
"
)
assert
updated
==
{
"
testAttribute
"
:
"
TestValue
"
,
"
defaultAttribute
"
:
3
,
"
optional
"
:
"
2019-11-21
"
}
assert
kwargs
==
{
"
testAttribute
"
:
"
DefaultValue
"
,
"
defaultAttribute
"
:
3
}
args
=
{
"
testName
"
:
None
}
updated
=
PreProcessing
.
update_kwargs
(
args
,
kwargs
,
"
testName
"
)
assert
updated
==
{
"
testAttribute
"
:
"
DefaultValue
"
,
"
defaultAttribute
"
:
3
}
args
=
{
"
dummy
"
:
"
notMeaningful
"
}
updated
=
PreProcessing
.
update_kwargs
(
args
,
kwargs
,
"
testName
"
)
assert
updated
==
{
"
testAttribute
"
:
"
DefaultValue
"
,
"
defaultAttribute
"
:
3
}
def
test_update_key
(
self
):
orig_dict
=
{
"
Test1
"
:
3
,
"
Test2
"
:
"
4
"
,
"
test3
"
:
[
1
,
2
,
3
]}
f
=
PreProcessing
.
update_key
assert
f
(
orig_dict
,
"
Test2
"
,
4
)
==
{
"
Test1
"
:
3
,
"
Test2
"
:
4
,
"
test3
"
:
[
1
,
2
,
3
]}
assert
orig_dict
==
{
"
Test1
"
:
3
,
"
Test2
"
:
"
4
"
,
"
test3
"
:
[
1
,
2
,
3
]}
assert
f
(
orig_dict
,
"
Test3
"
,
4
)
==
{
"
Test1
"
:
3
,
"
Test2
"
:
"
4
"
,
"
test3
"
:
[
1
,
2
,
3
],
"
Test3
"
:
4
}
def
test_split_set_indices
(
self
):
dummy_list
=
list
(
range
(
0
,
15
))
train
,
val
,
test
=
PreProcessing
.
split_set_indices
(
len
(
dummy_list
),
0.9
)
assert
dummy_list
[
train
]
==
list
(
range
(
0
,
10
))
assert
dummy_list
[
val
]
==
list
(
range
(
10
,
13
))
assert
dummy_list
[
test
]
==
list
(
range
(
13
,
15
))
@mock.patch
(
"
DataGenerator
"
,
return_value
=
object
.
__new__
(
DataGenerator
))
@mock.patch
(
"
DataGenerator[station]
"
,
return_value
=
(
np
.
ones
(
10
),
np
.
zeros
(
10
)))
def
test_create_set_split
(
self
):
stations
=
[
'
DEBW107
'
,
'
DEBY081
'
,
'
DEBW013
'
,
'
DEBW076
'
,
'
DEBW087
'
]
pre
=
object
.
__new__
(
PreProcessing
)
pre
.
setup
=
ExperimentSetup
({},
stations
=
stations
,
var_all_dict
=
{
'
o3
'
:
'
dma8eu
'
,
'
temp
'
:
'
maximum
'
},
train_kwargs
=
{
"
start
"
:
"
2000-01-01
"
,
"
end
"
:
"
2007-12-31
"
})
kwargs
=
{
'
start
'
:
'
1997-01-01
'
,
'
end
'
:
'
2017-12-31
'
,
'
statistics_per_var
'
:
pre
.
setup
.
var_all_dict
,
}
train
=
pre
.
create_set_split
(
stations
,
pre
.
setup
.
__dict__
,
kwargs
,
slice
(
0
,
3
),
"
train
"
)
# stopped here. It is a mess with all the different kwargs, args etc. Restructure the idea of how to implement
# the data sets. Because there are multiple kwargs declarations and which counts in the end. And there are
# multiple declarations of the DataGenerator class. Why this? Is it somehow possible, to select elements from
# this iterator class. Furthermore the names of the DataPrep class is not distinct, because there is no time
# range provided in file's name. Given the case, that first to total DataGen is called with a short period for
# data loading. But then, for the data split (I don't know why this could happen, but it is very likely because
# osf the current multiple declarations of kwargs arguments) the desired time range exceeds the previou
# mentioned and short time range. But nevertheless, the file with the short period is loaded and used (because
# during DataPrep loading, the available range is checked).
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment