Skip to content
Snippets Groups Projects
Commit 40ebb434 authored by leufen1
Browse files

change data origin to reduce data loading

parent 69ffde01
Branches
Tags
5 merge requests: !480 Merge multiple stats into crps working branch, !470 Develop, !467 Resolve "release v2.2.0", !466 Draft: Resolve "Include CRPS analysis and other ens verif methods or plots", !460 Resolve "TECH: reduce CI running time"
Pipeline #106888 failed
......@@ -30,6 +30,7 @@ class TestPreProcessing:
def obj_with_exp_setup(self):
ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW087', 'DEBW99X'],
statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background",
data_origin={'o3': 'UBA', 'temp': 'UBA'},
data_handler=DefaultDataHandler)
pre = object.__new__(PreProcessing)
super(PreProcessing, pre).__init__()
......@@ -38,7 +39,8 @@ class TestPreProcessing:
def test_init(self, caplog):
ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW087'],
statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
data_origin={'o3': 'UBA', 'temp': 'UBA'})
caplog.clear()
caplog.set_level(logging.INFO)
with PreProcessing():
......@@ -85,13 +87,13 @@ class TestPreProcessing:
def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup):
caplog.set_level(logging.DEBUG)
obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
message = "Awesome stations (len=5): ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW99X']"
message = "Awesome stations (len=5): ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW087', 'DEBW99X']"
assert ('root', 10, message) in caplog.record_tuples
data_store = obj_with_exp_setup.data_store
assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
with pytest.raises(NameNotFoundInScope):
data_store.get("data_collection", "general")
assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076']
assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW087']
@pytest.mark.parametrize("name", (None, "tester"))
def test_validate_station_serial(self, caplog, obj_with_exp_setup, name):
......
......@@ -23,29 +23,6 @@ from mlair.run_modules.run_environment import RunEnvironment
from mlair.run_modules.training import Training
# Build a small Keras model for the training tests: one inception block feeding a
# 'Main' output tail, plus an optional additional 'Minor_1' output tail.
# NOTE(review): this listing comes from a rendered diff in which indentation was
# lost; confirm the exact nesting of the statements below against the repository file.
def my_test_model(activation, window_history_size, channels, output_size, dropout_rate, add_minor_branch=False):
inception_model = InceptionModelBase()
# Two convolution towers that differ only in their kernel, (3, 1) versus (5, 1).
conv_settings_dict1 = {
'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation}, }
# Pooling tower settings; uses the same tower_filter (8 * 2) as the convolution towers.
pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
# Input shape is (window_history_size + 1, 1, channels).
X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
# With add_minor_branch, the output list starts with a 4-neuron 'Minor_1' tail built
# from the inception output; otherwise it starts empty.
if add_minor_branch:
out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
output_activation='linear', reduction_filter=64,
name='Minor_1', dropout_rate=dropout_rate,
)]
else:
out = []
# Dropout is applied before the 'Main' tail, which has output_size neurons.
# NOTE(review): presumably this and the following statements run for both branches
# (i.e. sit after the if/else, not inside the else) - confirm against the original file.
X_in = keras.layers.Dropout(dropout_rate)(X_in)
out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=output_size,
output_activation='linear', reduction_filter=64,
name='Main', dropout_rate=dropout_rate,
))
# The model's outputs are the list `out`: ['Minor_1', 'Main'] tails or just ['Main'].
return keras.Model(inputs=X_input, outputs=out)
class TestTraining:
@pytest.fixture
......@@ -90,15 +67,6 @@ class TestTraining:
RunEnvironment().__del__()
except AssertionError:
pass
# try:
# yield obj
# finally:
# if os.path.exists(path):
# shutil.rmtree(path)
# try:
# RunEnvironment().__del__()
# except AssertionError:
# pass
@pytest.fixture
def learning_rate(self):
......@@ -150,12 +118,16 @@ class TestTraining:
return {'o3': 'dma8eu', 'temp': 'maximum'}
@pytest.fixture
def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var):
data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(path, 'data'),
def data_origin(self):
return {'o3': 'UBA', 'temp': 'UBA'}
@pytest.fixture
def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var, data_origin):
data_prep = DefaultDataHandler.build('DEBW107', data_path=os.path.join(path, 'data'),
experiment_path=os.path.join(path, 'exp_path'),
statistics_per_var=statistics_per_var, station_type="background",
network="AIRBASE", sampling="daily", target_dim="variables",
target_var="o3", time_dim="datetime",
sampling="daily", target_dim="variables",
target_var="o3", time_dim="datetime", data_origin=data_origin,
window_history_size=window_history_size,
window_lead_time=window_lead_time, name_affix="train")
return DataCollection([data_prep])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment