Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
MLAir
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
MLAir
Commits
1faf7c2d
Commit
1faf7c2d
authored
5 years ago
by
lukas leufen
Browse files
Options
Downloads
Patches
Plain Diff
modified experiment setup
parent
134288fa
Branches
Branches containing commit
Tags
Tags containing commit
2 merge requests
!17
update to v0.4.0
,
!15
new feat split subsets
Pipeline
#26462
passed
5 years ago
Stage: test
Stage: pages
Stage: deploy
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/modules/experiment_setup.py
+16
-9
16 additions, 9 deletions
src/modules/experiment_setup.py
test/test_modules/test_experiment_setup.py
+41
-8
41 additions, 8 deletions
test/test_modules/test_experiment_setup.py
with
57 additions
and
17 deletions
src/modules/experiment_setup.py
+
16
−
9
View file @
1faf7c2d
...
@@ -27,10 +27,12 @@ class ExperimentSetup(RunEnvironment):
...
@@ -27,10 +27,12 @@ class ExperimentSetup(RunEnvironment):
trainable: Train new model if true, otherwise try to load existing model
trainable: Train new model if true, otherwise try to load existing model
"""
"""
def
__init__
(
self
,
parser_args
=
None
,
var_all_dict
=
None
,
stations
=
None
,
network
=
None
,
variables
=
None
,
target_var
=
"
o3
"
,
def
__init__
(
self
,
parser_args
=
None
,
var_all_dict
=
None
,
stations
=
None
,
network
=
None
,
variables
=
None
,
target_dim
=
None
,
dimensions
=
None
,
interpolate_dim
=
None
,
train_start
=
None
,
train_end
=
None
,
statistics_per_var
=
None
,
start
=
None
,
end
=
None
,
window_history
=
None
,
target_var
=
"
o3
"
,
target_dim
=
None
,
val_start
=
None
,
val_end
=
None
,
test_start
=
None
,
test_end
=
None
,
use_all_stations_on_all_data_sets
=
True
,
window_lead_time
=
None
,
dimensions
=
None
,
interpolate_dim
=
None
,
interpolate_method
=
None
,
trainable
=
False
,
fraction_of_train
=
None
,
experiment_path
=
None
):
limit_nan_fill
=
None
,
train_start
=
None
,
train_end
=
None
,
val_start
=
None
,
val_end
=
None
,
test_start
=
None
,
test_end
=
None
,
use_all_stations_on_all_data_sets
=
True
,
trainable
=
False
,
fraction_of_train
=
None
,
experiment_path
=
None
):
# create run framework
# create run framework
super
().
__init__
()
super
().
__init__
()
...
@@ -52,14 +54,21 @@ class ExperimentSetup(RunEnvironment):
...
@@ -52,14 +54,21 @@ class ExperimentSetup(RunEnvironment):
self
.
_set_param
(
"
stations
"
,
stations
,
default
=
DEFAULT_STATIONS
)
self
.
_set_param
(
"
stations
"
,
stations
,
default
=
DEFAULT_STATIONS
)
self
.
_set_param
(
"
network
"
,
network
,
default
=
"
AIRBASE
"
)
self
.
_set_param
(
"
network
"
,
network
,
default
=
"
AIRBASE
"
)
self
.
_set_param
(
"
variables
"
,
variables
,
default
=
list
(
self
.
data_store
.
get
(
"
var_all_dict
"
,
"
general
"
).
keys
()))
self
.
_set_param
(
"
variables
"
,
variables
,
default
=
list
(
self
.
data_store
.
get
(
"
var_all_dict
"
,
"
general
"
).
keys
()))
self
.
_set_param
(
"
statistics_per_var
"
,
statistics_per_var
,
default
=
self
.
data_store
.
get
(
"
var_all_dict
"
,
"
general
"
))
self
.
_set_param
(
"
start
"
,
start
,
default
=
"
1997-01-01
"
,
scope
=
"
general
"
)
self
.
_set_param
(
"
end
"
,
end
,
default
=
"
2017-12-31
"
,
scope
=
"
general
"
)
self
.
_set_param
(
"
window_history
"
,
window_history
,
default
=
13
)
# target
# target
self
.
_set_param
(
"
target_var
"
,
target_var
,
default
=
"
o3
"
)
self
.
_set_param
(
"
target_var
"
,
target_var
,
default
=
"
o3
"
)
self
.
_set_param
(
"
target_dim
"
,
target_dim
,
default
=
'
variables
'
)
self
.
_set_param
(
"
target_dim
"
,
target_dim
,
default
=
'
variables
'
)
self
.
_set_param
(
"
window_lead_time
"
,
window_lead_time
,
default
=
3
)
# interpolation
# interpolation
self
.
_set_param
(
"
dimensions
"
,
dimensions
,
default
=
{
'
new_index
'
:
[
'
datetime
'
,
'
Stations
'
]})
self
.
_set_param
(
"
dimensions
"
,
dimensions
,
default
=
{
'
new_index
'
:
[
'
datetime
'
,
'
Stations
'
]})
self
.
_set_param
(
"
interpolate_dim
"
,
interpolate_dim
,
default
=
'
datetime
'
)
self
.
_set_param
(
"
interpolate_dim
"
,
interpolate_dim
,
default
=
'
datetime
'
)
self
.
_set_param
(
"
interpolate_method
"
,
interpolate_method
,
default
=
'
linear
'
)
self
.
_set_param
(
"
limit_nan_fill
"
,
limit_nan_fill
,
default
=
1
)
# train parameters
# train parameters
self
.
_set_param
(
"
start
"
,
train_start
,
default
=
"
1997-01-01
"
,
scope
=
"
general.train
"
)
self
.
_set_param
(
"
start
"
,
train_start
,
default
=
"
1997-01-01
"
,
scope
=
"
general.train
"
)
...
@@ -69,7 +78,7 @@ class ExperimentSetup(RunEnvironment):
...
@@ -69,7 +78,7 @@ class ExperimentSetup(RunEnvironment):
self
.
_set_param
(
"
start
"
,
val_start
,
default
=
"
2008-01-01
"
,
scope
=
"
general.val
"
)
self
.
_set_param
(
"
start
"
,
val_start
,
default
=
"
2008-01-01
"
,
scope
=
"
general.val
"
)
self
.
_set_param
(
"
end
"
,
val_end
,
default
=
"
2009-12-31
"
,
scope
=
"
general.val
"
)
self
.
_set_param
(
"
end
"
,
val_end
,
default
=
"
2009-12-31
"
,
scope
=
"
general.val
"
)
#
validation
parameters
#
test
parameters
self
.
_set_param
(
"
start
"
,
test_start
,
default
=
"
2010-01-01
"
,
scope
=
"
general.test
"
)
self
.
_set_param
(
"
start
"
,
test_start
,
default
=
"
2010-01-01
"
,
scope
=
"
general.test
"
)
self
.
_set_param
(
"
end
"
,
test_end
,
default
=
"
2017-12-31
"
,
scope
=
"
general.test
"
)
self
.
_set_param
(
"
end
"
,
test_end
,
default
=
"
2017-12-31
"
,
scope
=
"
general.test
"
)
...
@@ -83,15 +92,13 @@ class ExperimentSetup(RunEnvironment):
...
@@ -83,15 +92,13 @@ class ExperimentSetup(RunEnvironment):
logging
.
debug
(
f
"
set experiment attribute:
{
param
}
(
{
scope
}
)=
{
value
}
"
)
logging
.
debug
(
f
"
set experiment attribute:
{
param
}
(
{
scope
}
)=
{
value
}
"
)
@staticmethod
@staticmethod
def
_get_parser_args
(
args
:
Union
[
Dict
,
argparse
.
Namespace
,
argparse
.
ArgumentParser
])
->
Dict
:
def
_get_parser_args
(
args
:
Union
[
Dict
,
argparse
.
Namespace
])
->
Dict
:
"""
"""
Transform args to dict if given as argparse.Namespace
Transform args to dict if given as argparse.Namespace
:param args: either a dictionary or an argument parser instance
:param args: either a dictionary or an argument parser instance
:return: dictionary with all arguments
:return: dictionary with all arguments
"""
"""
if
isinstance
(
args
,
argparse
.
ArgumentParser
):
if
isinstance
(
args
,
argparse
.
Namespace
):
return
args
.
parse_args
().
__dict__
elif
isinstance
(
args
,
argparse
.
Namespace
):
return
args
.
__dict__
return
args
.
__dict__
elif
isinstance
(
args
,
dict
):
elif
isinstance
(
args
,
dict
):
return
args
return
args
...
...
This diff is collapsed.
Click to expand it.
test/test_modules/test_experiment_setup.py
+
41
−
8
View file @
1faf7c2d
...
@@ -45,15 +45,18 @@ class TestExperimentSetup:
...
@@ -45,15 +45,18 @@ class TestExperimentSetup:
def
test_init_default
(
self
):
def
test_init_default
(
self
):
exp_setup
=
ExperimentSetup
()
exp_setup
=
ExperimentSetup
()
data_store
=
exp_setup
.
data_store
data_store
=
exp_setup
.
data_store
# experiment setup
assert
data_store
.
get
(
"
data_path
"
,
"
general
"
)
==
prepare_host
()
assert
data_store
.
get
(
"
data_path
"
,
"
general
"
)
==
prepare_host
()
assert
data_store
.
get
(
"
trainable
"
,
"
general
"
)
is
False
assert
data_store
.
get
(
"
trainable
"
,
"
general
"
)
is
False
assert
data_store
.
get
(
"
fraction_of_train
"
,
"
general
"
)
==
0.8
assert
data_store
.
get
(
"
fraction_of_train
"
,
"
general
"
)
==
0.8
# set experiment name
assert
data_store
.
get
(
"
experiment_name
"
,
"
general
"
)
==
"
TestExperiment
"
assert
data_store
.
get
(
"
experiment_name
"
,
"
general
"
)
==
"
TestExperiment
"
path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"
..
"
,
"
..
"
,
"
TestExperiment
"
))
path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"
..
"
,
"
..
"
,
"
TestExperiment
"
))
assert
data_store
.
get
(
"
experiment_path
"
,
"
general
"
)
==
path
assert
data_store
.
get
(
"
experiment_path
"
,
"
general
"
)
==
path
default_var_all_dict
=
{
'
o3
'
:
'
dma8eu
'
,
'
relhum
'
:
'
average_values
'
,
'
temp
'
:
'
maximum
'
,
'
u
'
:
'
average_values
'
,
default_var_all_dict
=
{
'
o3
'
:
'
dma8eu
'
,
'
relhum
'
:
'
average_values
'
,
'
temp
'
:
'
maximum
'
,
'
u
'
:
'
average_values
'
,
'
v
'
:
'
average_values
'
,
'
no
'
:
'
dma8eu
'
,
'
no2
'
:
'
dma8eu
'
,
'
cloudcover
'
:
'
average_values
'
,
'
v
'
:
'
average_values
'
,
'
no
'
:
'
dma8eu
'
,
'
no2
'
:
'
dma8eu
'
,
'
cloudcover
'
:
'
average_values
'
,
'
pblheight
'
:
'
maximum
'
}
'
pblheight
'
:
'
maximum
'
}
# setup for data
assert
data_store
.
get
(
"
var_all_dict
"
,
"
general
"
)
==
default_var_all_dict
assert
data_store
.
get
(
"
var_all_dict
"
,
"
general
"
)
==
default_var_all_dict
default_stations
=
[
'
DEBW107
'
,
'
DEBY081
'
,
'
DEBW013
'
,
'
DEBW076
'
,
'
DEBW087
'
,
'
DEBY052
'
,
'
DEBY032
'
,
'
DEBW022
'
,
default_stations
=
[
'
DEBW107
'
,
'
DEBY081
'
,
'
DEBW013
'
,
'
DEBW076
'
,
'
DEBW087
'
,
'
DEBY052
'
,
'
DEBY032
'
,
'
DEBW022
'
,
'
DEBY004
'
,
'
DEBY020
'
,
'
DEBW030
'
,
'
DEBW037
'
,
'
DEBW031
'
,
'
DEBW015
'
,
'
DEBW073
'
,
'
DEBY039
'
,
'
DEBY004
'
,
'
DEBY020
'
,
'
DEBW030
'
,
'
DEBW037
'
,
'
DEBW031
'
,
'
DEBW015
'
,
'
DEBW073
'
,
'
DEBY039
'
,
...
@@ -65,50 +68,80 @@ class TestExperimentSetup:
...
@@ -65,50 +68,80 @@ class TestExperimentSetup:
assert
data_store
.
get
(
"
stations
"
,
"
general
"
)
==
default_stations
assert
data_store
.
get
(
"
stations
"
,
"
general
"
)
==
default_stations
assert
data_store
.
get
(
"
network
"
,
"
general
"
)
==
"
AIRBASE
"
assert
data_store
.
get
(
"
network
"
,
"
general
"
)
==
"
AIRBASE
"
assert
data_store
.
get
(
"
variables
"
,
"
general
"
)
==
list
(
default_var_all_dict
.
keys
())
assert
data_store
.
get
(
"
variables
"
,
"
general
"
)
==
list
(
default_var_all_dict
.
keys
())
assert
data_store
.
get
(
"
statistics_per_var
"
,
"
general
"
)
==
default_var_all_dict
assert
data_store
.
get
(
"
start
"
,
"
general
"
)
==
"
1997-01-01
"
assert
data_store
.
get
(
"
end
"
,
"
general
"
)
==
"
2017-12-31
"
assert
data_store
.
get
(
"
window_history
"
,
"
general
"
)
==
13
# target
assert
data_store
.
get
(
"
target_var
"
,
"
general
"
)
==
"
o3
"
assert
data_store
.
get
(
"
target_var
"
,
"
general
"
)
==
"
o3
"
assert
data_store
.
get
(
"
target_dim
"
,
"
general
"
)
==
"
variables
"
assert
data_store
.
get
(
"
target_dim
"
,
"
general
"
)
==
"
variables
"
assert
data_store
.
get
(
"
window_lead_time
"
,
"
general
"
)
==
3
# interpolation
assert
data_store
.
get
(
"
dimensions
"
,
"
general
"
)
==
{
'
new_index
'
:
[
'
datetime
'
,
'
Stations
'
]}
assert
data_store
.
get
(
"
dimensions
"
,
"
general
"
)
==
{
'
new_index
'
:
[
'
datetime
'
,
'
Stations
'
]}
assert
data_store
.
get
(
"
interpolate_dim
"
,
"
general
"
)
==
"
datetime
"
assert
data_store
.
get
(
"
interpolate_dim
"
,
"
general
"
)
==
"
datetime
"
with
pytest
.
raises
(
NameNotFoundInScope
):
assert
data_store
.
get
(
"
interpolate_method
"
,
"
general
"
)
==
"
linear
"
data_store
.
get
(
"
start
"
,
"
general
"
)
assert
data_store
.
get
(
"
limit_nan_fill
"
,
"
general
"
)
==
1
with
pytest
.
raises
(
NameNotFoundInScope
):
# train parameters
data_store
.
get
(
"
end
"
,
"
general
"
)
assert
data_store
.
get
(
"
start
"
,
"
general.train
"
)
==
"
1997-01-01
"
assert
data_store
.
get
(
"
start
"
,
"
general.train
"
)
==
"
1997-01-01
"
assert
data_store
.
get
(
"
end
"
,
"
general.train
"
)
==
"
2007-12-31
"
assert
data_store
.
get
(
"
end
"
,
"
general.train
"
)
==
"
2007-12-31
"
# validation parameters
assert
data_store
.
get
(
"
start
"
,
"
general.val
"
)
==
"
2008-01-01
"
assert
data_store
.
get
(
"
start
"
,
"
general.val
"
)
==
"
2008-01-01
"
assert
data_store
.
get
(
"
end
"
,
"
general.val
"
)
==
"
2009-12-31
"
assert
data_store
.
get
(
"
end
"
,
"
general.val
"
)
==
"
2009-12-31
"
# test parameters
assert
data_store
.
get
(
"
start
"
,
"
general.test
"
)
==
"
2010-01-01
"
assert
data_store
.
get
(
"
start
"
,
"
general.test
"
)
==
"
2010-01-01
"
assert
data_store
.
get
(
"
end
"
,
"
general.test
"
)
==
"
2017-12-31
"
assert
data_store
.
get
(
"
end
"
,
"
general.test
"
)
==
"
2017-12-31
"
# use all stations on all data sets (train, val, test)
assert
data_store
.
get
(
"
use_all_stations_on_all_data_sets
"
,
"
general
"
)
is
True
def
test_init_no_default
(
self
):
def
test_init_no_default
(
self
):
experiment_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"
..
"
,
"
data
"
,
"
testExperimentFolder
"
))
experiment_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"
..
"
,
"
data
"
,
"
testExperimentFolder
"
))
kwargs
=
dict
(
parser_args
=
{
"
experiment_date
"
:
"
TODAY
"
},
kwargs
=
dict
(
parser_args
=
{
"
experiment_date
"
:
"
TODAY
"
},
var_all_dict
=
{
'
o3
'
:
'
dma8eu
'
,
'
relhum
'
:
'
average_values
'
,
'
temp
'
:
'
maximum
'
},
var_all_dict
=
{
'
o3
'
:
'
dma8eu
'
,
'
relhum
'
:
'
average_values
'
,
'
temp
'
:
'
maximum
'
},
stations
=
[
'
DEBY053
'
,
'
DEBW059
'
,
'
DEBW027
'
],
network
=
"
INTERNET
"
,
variables
=
[
"
o3
"
,
"
temp
"
],
stations
=
[
'
DEBY053
'
,
'
DEBW059
'
,
'
DEBW027
'
],
network
=
"
INTERNET
"
,
variables
=
[
"
o3
"
,
"
temp
"
],
target_var
=
"
temp
"
,
target_dim
=
"
target
"
,
dimensions
=
"
dim1
"
,
interpolate_dim
=
"
int_dim
"
,
statistics_per_var
=
None
,
start
=
"
1999-01-01
"
,
end
=
"
2001-01-01
"
,
window_history
=
4
,
train_start
=
"
2000-01-01
"
,
train_end
=
"
2000-01-02
"
,
val_start
=
"
2000-01-03
"
,
val_end
=
"
2000-01-04
"
,
target_var
=
"
temp
"
,
target_dim
=
"
target
"
,
window_lead_time
=
10
,
dimensions
=
"
dim1
"
,
test_start
=
"
2000-01-05
"
,
test_end
=
"
2000-01-06
"
,
use_all_stations_on_all_data_sets
=
False
,
interpolate_dim
=
"
int_dim
"
,
interpolate_method
=
"
cubic
"
,
limit_nan_fill
=
5
,
train_start
=
"
2000-01-01
"
,
trainable
=
True
,
fraction_of_train
=
0.5
,
experiment_path
=
experiment_path
)
train_end
=
"
2000-01-02
"
,
val_start
=
"
2000-01-03
"
,
val_end
=
"
2000-01-04
"
,
test_start
=
"
2000-01-05
"
,
test_end
=
"
2000-01-06
"
,
use_all_stations_on_all_data_sets
=
False
,
trainable
=
True
,
fraction_of_train
=
0.5
,
experiment_path
=
experiment_path
)
exp_setup
=
ExperimentSetup
(
**
kwargs
)
exp_setup
=
ExperimentSetup
(
**
kwargs
)
data_store
=
exp_setup
.
data_store
data_store
=
exp_setup
.
data_store
# experiment setup
assert
data_store
.
get
(
"
data_path
"
,
"
general
"
)
==
prepare_host
()
assert
data_store
.
get
(
"
data_path
"
,
"
general
"
)
==
prepare_host
()
assert
data_store
.
get
(
"
trainable
"
,
"
general
"
)
is
True
assert
data_store
.
get
(
"
trainable
"
,
"
general
"
)
is
True
assert
data_store
.
get
(
"
fraction_of_train
"
,
"
general
"
)
==
0.5
assert
data_store
.
get
(
"
fraction_of_train
"
,
"
general
"
)
==
0.5
# set experiment name
assert
data_store
.
get
(
"
experiment_name
"
,
"
general
"
)
==
"
TODAY_network/
"
assert
data_store
.
get
(
"
experiment_name
"
,
"
general
"
)
==
"
TODAY_network/
"
path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"
..
"
,
"
data
"
,
"
testExperimentFolder
"
))
path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"
..
"
,
"
data
"
,
"
testExperimentFolder
"
))
assert
data_store
.
get
(
"
experiment_path
"
,
"
general
"
)
==
path
assert
data_store
.
get
(
"
experiment_path
"
,
"
general
"
)
==
path
# setup for data
assert
data_store
.
get
(
"
var_all_dict
"
,
"
general
"
)
==
{
'
o3
'
:
'
dma8eu
'
,
'
relhum
'
:
'
average_values
'
,
assert
data_store
.
get
(
"
var_all_dict
"
,
"
general
"
)
==
{
'
o3
'
:
'
dma8eu
'
,
'
relhum
'
:
'
average_values
'
,
'
temp
'
:
'
maximum
'
}
'
temp
'
:
'
maximum
'
}
assert
data_store
.
get
(
"
stations
"
,
"
general
"
)
==
[
'
DEBY053
'
,
'
DEBW059
'
,
'
DEBW027
'
]
assert
data_store
.
get
(
"
stations
"
,
"
general
"
)
==
[
'
DEBY053
'
,
'
DEBW059
'
,
'
DEBW027
'
]
assert
data_store
.
get
(
"
network
"
,
"
general
"
)
==
"
INTERNET
"
assert
data_store
.
get
(
"
network
"
,
"
general
"
)
==
"
INTERNET
"
assert
data_store
.
get
(
"
variables
"
,
"
general
"
)
==
[
"
o3
"
,
"
temp
"
]
assert
data_store
.
get
(
"
variables
"
,
"
general
"
)
==
[
"
o3
"
,
"
temp
"
]
assert
data_store
.
get
(
"
statistics_per_var
"
,
"
general
"
)
==
{
'
o3
'
:
'
dma8eu
'
,
'
relhum
'
:
'
average_values
'
,
'
temp
'
:
'
maximum
'
}
assert
data_store
.
get
(
"
start
"
,
"
general
"
)
==
"
1999-01-01
"
assert
data_store
.
get
(
"
end
"
,
"
general
"
)
==
"
2001-01-01
"
assert
data_store
.
get
(
"
window_history
"
,
"
general
"
)
==
4
# target
assert
data_store
.
get
(
"
target_var
"
,
"
general
"
)
==
"
temp
"
assert
data_store
.
get
(
"
target_var
"
,
"
general
"
)
==
"
temp
"
assert
data_store
.
get
(
"
target_dim
"
,
"
general
"
)
==
"
target
"
assert
data_store
.
get
(
"
target_dim
"
,
"
general
"
)
==
"
target
"
assert
data_store
.
get
(
"
window_lead_time
"
,
"
general
"
)
==
10
# interpolation
assert
data_store
.
get
(
"
dimensions
"
,
"
general
"
)
==
"
dim1
"
assert
data_store
.
get
(
"
dimensions
"
,
"
general
"
)
==
"
dim1
"
assert
data_store
.
get
(
"
interpolate_dim
"
,
"
general
"
)
==
"
int_dim
"
assert
data_store
.
get
(
"
interpolate_dim
"
,
"
general
"
)
==
"
int_dim
"
assert
data_store
.
get
(
"
interpolate_method
"
,
"
general
"
)
==
"
cubic
"
assert
data_store
.
get
(
"
limit_nan_fill
"
,
"
general
"
)
==
5
# train parameters
assert
data_store
.
get
(
"
start
"
,
"
general.train
"
)
==
"
2000-01-01
"
assert
data_store
.
get
(
"
start
"
,
"
general.train
"
)
==
"
2000-01-01
"
assert
data_store
.
get
(
"
end
"
,
"
general.train
"
)
==
"
2000-01-02
"
assert
data_store
.
get
(
"
end
"
,
"
general.train
"
)
==
"
2000-01-02
"
# validation parameters
assert
data_store
.
get
(
"
start
"
,
"
general.val
"
)
==
"
2000-01-03
"
assert
data_store
.
get
(
"
start
"
,
"
general.val
"
)
==
"
2000-01-03
"
assert
data_store
.
get
(
"
end
"
,
"
general.val
"
)
==
"
2000-01-04
"
assert
data_store
.
get
(
"
end
"
,
"
general.val
"
)
==
"
2000-01-04
"
# test parameters
assert
data_store
.
get
(
"
start
"
,
"
general.test
"
)
==
"
2000-01-05
"
assert
data_store
.
get
(
"
start
"
,
"
general.test
"
)
==
"
2000-01-05
"
assert
data_store
.
get
(
"
end
"
,
"
general.test
"
)
==
"
2000-01-06
"
assert
data_store
.
get
(
"
end
"
,
"
general.test
"
)
==
"
2000-01-06
"
# use all stations on all data sets (train, val, test)
assert
data_store
.
get
(
"
use_all_stations_on_all_data_sets
"
,
"
general.test
"
)
is
False
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment