Skip to content
Snippets Groups Projects
Commit 0e36df1e authored by lukas leufen's avatar lukas leufen
Browse files

apply refac, /close #61

parents e98d67dd 0e86d647
No related branches found
No related tags found
1 merge request!59Develop
Pipeline #31116 passed
......@@ -184,7 +184,7 @@ class DataGenerator(keras.utils.Sequence):
if self.transformation is not None:
data.transform("datetime", **helpers.dict_pop(self.transformation, "scope"))
data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
data.make_history_window(self.interpolate_dim, self.window_history_size)
data.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim)
data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
data.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
data.remove_nan(self.interpolate_dim)
......
......@@ -27,6 +27,7 @@ class TestDataPrep:
d.network = 'UBA'
d.station = ['DEBW107']
d.variables = ['o3', 'temp']
d.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
d.station_type = "background"
d.sampling = "daily"
d.kwargs = None
......@@ -125,6 +126,7 @@ class TestDataPrep:
d.path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "data")
d.station = 'TESTSTATION'
d.variables = ['a', 'bc']
d.statistics_per_var = {'a': 'dma8eu', 'bc': 'maximum'}
assert d._set_file_name() == os.path.join(os.path.abspath(os.path.dirname(__file__)),
"data/TESTSTATION_a_bc.nc")
assert d._set_meta_file_name() == os.path.join(os.path.abspath(os.path.dirname(__file__)),
......
......@@ -54,11 +54,10 @@ class TestExperimentSetup:
assert data_store.get("experiment_name", "general") == "TestExperiment"
path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment"))
assert data_store.get("experiment_path", "general") == path
default_var_all_dict = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
'pblheight': 'maximum'}
default_statistics_per_var = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum',
'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu',
'cloudcover': 'average_values', 'pblheight': 'maximum'}
# setup for data
assert data_store.get("var_all_dict", "general") == default_var_all_dict
default_stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022',
'DEBY004', 'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039',
'DEBW038', 'DEBW081', 'DEBY075', 'DEBW040', 'DEBY053', 'DEBW059', 'DEBW027', 'DEBY072',
......@@ -69,8 +68,8 @@ class TestExperimentSetup:
assert data_store.get("stations", "general") == default_stations
assert data_store.get("network", "general") == "AIRBASE"
assert data_store.get("station_type", "general") is None
assert data_store.get("variables", "general") == list(default_var_all_dict.keys())
assert data_store.get("statistics_per_var", "general") == default_var_all_dict
assert data_store.get("variables", "general") == list(default_statistics_per_var.keys())
assert data_store.get("statistics_per_var", "general") == default_statistics_per_var
assert data_store.get("start", "general") == "1997-01-01"
assert data_store.get("end", "general") == "2017-12-31"
assert data_store.get("window_history_size", "general") == 13
......@@ -98,11 +97,10 @@ class TestExperimentSetup:
def test_init_no_default(self):
experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
kwargs = dict(parser_args={"experiment_date": "TODAY"},
var_all_dict={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'},
statistics_per_var={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'},
stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", station_type="background",
variables=["o3", "temp"],
statistics_per_var=None, start="1999-01-01", end="2001-01-01", window_history_size=4,
target_var="temp", target_dim="target", window_lead_time=10, dimensions="dim1",
variables=["o3", "temp"], start="1999-01-01", end="2001-01-01", window_history_size=4,
target_var="relhum", target_dim="target", window_lead_time=10, dimensions="dim1",
interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01",
train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05",
test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False,
......@@ -120,8 +118,6 @@ class TestExperimentSetup:
"TODAY_network"))
assert data_store.get("experiment_path", "general") == path
# setup for data
assert data_store.get("var_all_dict", "general") == {'o3': 'dma8eu', 'relhum': 'average_values',
'temp': 'maximum'}
assert data_store.get("stations", "general") == ['DEBY053', 'DEBW059', 'DEBW027']
assert data_store.get("network", "general") == "INTERNET"
assert data_store.get("station_type", "general") == "background"
......@@ -132,7 +128,7 @@ class TestExperimentSetup:
assert data_store.get("end", "general") == "2001-01-01"
assert data_store.get("window_history_size", "general") == 4
# target
assert data_store.get("target_var", "general") == "temp"
assert data_store.get("target_var", "general") == "relhum"
assert data_store.get("target_dim", "general") == "target"
assert data_store.get("window_lead_time", "general") == 10
# interpolation
......@@ -173,8 +169,8 @@ class TestExperimentSetup:
def test_compare_variables_and_statistics(self):
experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
kwargs = dict(parser_args={"experiment_date": "TODAY"},
var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'},
stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"], statistics_per_var=None,
statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"],
experiment_path=experiment_path)
with pytest.raises(ValueError) as e:
ExperimentSetup(**kwargs)
......
......@@ -27,7 +27,7 @@ class TestPreProcessing:
@pytest.fixture
def obj_with_exp_setup(self):
ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background")
statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background")
pre = object.__new__(PreProcessing)
super(PreProcessing, pre).__init__()
yield pre
......@@ -35,7 +35,7 @@ class TestPreProcessing:
def test_init(self, caplog):
ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
caplog.set_level(logging.INFO)
with PreProcessing():
assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment