diff --git a/src/datastore.py b/src/datastore.py index b357fd345381bab98455e487422a14118e072e95..bb8474a04b503b1ff50fcea1b7e5f8bbd1d9ebea 100644 --- a/src/datastore.py +++ b/src/datastore.py @@ -86,6 +86,9 @@ class AbstractDataStore(ABC): """ pass + def clear_data_store(self) -> None: + self._store = {} + class DataStoreByVariable(AbstractDataStore): diff --git a/src/modules/experiment_setup.py b/src/modules/experiment_setup.py index f81d2a5b7ff2c7ab477454ee34d77f2c15381dd4..a76fe60b34b679b5702ec85a11f95002c3c6fe34 100644 --- a/src/modules/experiment_setup.py +++ b/src/modules/experiment_setup.py @@ -40,7 +40,7 @@ class ExperimentSetup(RunEnvironment): # experiment setup self._set_param("data_path", helpers.prepare_host()) self._set_param("trainable", trainable, default=False) - self._set_param("fraction_of_train", fraction_of_train, default=0.8) + self._set_param("fraction_of_training", fraction_of_train, default=0.8) # set experiment name exp_date = self._get_parser_args(parser_args).get("experiment_date") diff --git a/src/modules/pre_processing.py b/src/modules/pre_processing.py index aeed05fccab27a4787f23019f7eec391a6564297..d999217e9f903d3d67a24179c9f3654fee3e60d4 100644 --- a/src/modules/pre_processing.py +++ b/src/modules/pre_processing.py @@ -4,11 +4,11 @@ from typing import Any, Tuple, Dict, List from src.data_generator import DataGenerator from src.helpers import TimeTracking from src.modules.run_environment import RunEnvironment -from src.datastore import NameNotFoundInDataStore +from src.datastore import NameNotFoundInDataStore, NameNotFoundInScope DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] -DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history", "window_lead_time"] +DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history", "window_lead_time", "statistics_per_var"] class PreProcessing(RunEnvironment): @@ -33,16 +33,15 @@ class PreProcessing(RunEnvironment): for arg in arg_list: try: args[arg] = self.data_store.get(arg, scope) - except NameNotFoundInDataStore: + except (NameNotFoundInDataStore, NameNotFoundInScope): pass return args def _run(self): - args = self._create_args_dict(DEFAULT_ARGS_LIST) kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST) valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general")) - self.data_store.put("stations", valid_stations) + self.data_store.put("stations", valid_stations, "general") self.split_train_val_test() def split_train_val_test(self): @@ -51,9 +50,6 @@ class PreProcessing(RunEnvironment): train_index, val_index, test_index = self.split_set_indices(len(stations), fraction_of_training) for (ind, scope) in zip([train_index, val_index, test_index], ["train", "val", "test"]): self.create_set_split(ind, scope) - # self.create_set_split(train_index, "train") - # self.create_set_split(val_index, "val") - # self.create_set_split(test_index, "test") @staticmethod def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice]: @@ -77,11 +73,11 @@ class PreProcessing(RunEnvironment): args = self._create_args_dict(DEFAULT_ARGS_LIST, scope) kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST, scope) stations = args["stations"] - if args["use_all_stations_on_all_data_sets"]: + if self.data_store.get("use_all_stations_on_all_data_sets", scope): set_stations = stations else: set_stations = stations[index_list] - logging.debug(f"{set_name.capitalize()} stations (len={set_stations}): {set_stations}") + logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") set_stations = self.check_valid_stations(args, kwargs, set_stations) self.data_store.put("stations", set_stations, scope) set_args = self._create_args_dict(DEFAULT_ARGS_LIST, scope) diff --git a/src/modules/run_environment.py b/src/modules/run_environment.py index b0aa77d50fc4fbc10b3b9e4debfe2ae5173d2a22..56c017290eea4d11881b9b131378d8c5995f0b29 100644 --- a/src/modules/run_environment.py +++ b/src/modules/run_environment.py @@ -33,6 +33,8 @@ class RunEnvironment(object): self.time.stop() logging.info(f"{self.__class__.__name__} finished after {self.time}") self.del_by_exit = True + if self.__class__.__name__ == "RunEnvironment": + self.data_store.clear_data_store() def __enter__(self): return self diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index a1a1aa454fc788aabe6280d618009834dc9f26bf..bc121885ddb8ee20b0f571e7f0250845c6e99e6a 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -1,87 +1,110 @@ import logging +import pytest from src.helpers import PyTestRegex, TimeTracking from src.modules.experiment_setup import ExperimentSetup -from src.modules.pre_processing import PreProcessing +from src.modules.pre_processing import PreProcessing, DEFAULT_ARGS_LIST, DEFAULT_KWARGS_LIST +from src.data_generator import DataGenerator +from src.datastore import NameNotFoundInScope +from src.modules.run_environment import RunEnvironment class TestPreProcessing: + @pytest.fixture + def obj_no_init(self): + return object.__new__(PreProcessing) + + @pytest.fixture + def obj_super_init(self): + obj = object.__new__(PreProcessing) + super(PreProcessing, obj).__init__() + obj.data_store.put("NAME1", 1, "general") + obj.data_store.put("NAME2", 2, "general") + obj.data_store.put("NAME3", 3, "general") + obj.data_store.put("NAME1", 10, "general.sub") + obj.data_store.put("NAME4", 4, "general.sub.sub") + yield obj + RunEnvironment().__del__() + + @pytest.fixture + def obj_with_exp_setup(self): + ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], + var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) + pre = object.__new__(PreProcessing) + super(PreProcessing, pre).__init__() + yield pre + RunEnvironment().__del__() + def test_init(self, caplog): + ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], + var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) caplog.set_level(logging.INFO) - setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], - var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) - pre = PreProcessing(setup) + PreProcessing() assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started') - assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)')) + assert caplog.record_tuples[-2] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)')) + RunEnvironment().__del__() + + def test_run(self, obj_with_exp_setup): + assert obj_with_exp_setup.data_store.search_name("generator") == [] + assert obj_with_exp_setup._run() is None + assert obj_with_exp_setup.data_store.search_name("generator") == sorted(["general.train", "general.val", + "general.test"]) - def test_run(self): - pre_processing = object.__new__(PreProcessing) - pre_processing.time = TimeTracking() - pre_processing.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], - var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) - assert pre_processing._run() is None + def test_split_train_val_test(self, obj_with_exp_setup): + assert obj_with_exp_setup.data_store.search_name("generator") == [] + obj_with_exp_setup.split_train_val_test() + data_store = obj_with_exp_setup.data_store + assert data_store.search_scope("general.train") == sorted(["generator", "start", "end", "stations"]) + assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test"]) - def test_split_train_val_test(self): - pass + def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup): + caplog.set_level(logging.DEBUG) + obj_with_exp_setup.data_store.put("use_all_stations_on_all_data_sets", False, "general.awesome") + obj_with_exp_setup.create_set_split(slice(0, 2), "awesome") + assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") + data_store = obj_with_exp_setup.data_store + assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator) + with pytest.raises(NameNotFoundInScope): + data_store.get("generator", "general") + assert data_store.get("stations", "general.awesome") == ["DEBW107", "DEBY081"] - def test_check_valid_stations(self, caplog): + def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup): + caplog.set_level(logging.DEBUG) + obj_with_exp_setup.create_set_split(slice(0, 2), "awesome") + assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=5): ['DEBW107', 'DEBY081', 'DEBW013', " + "'DEBW076', 'DEBW087']") + data_store = obj_with_exp_setup.data_store + assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator) + with pytest.raises(NameNotFoundInScope): + data_store.get("generator", "general") + assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] + + def test_check_valid_stations(self, caplog, obj_with_exp_setup): + pre = obj_with_exp_setup caplog.set_level(logging.INFO) - pre = object.__new__(PreProcessing) - pre.time = TimeTracking() - pre.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], - var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) - kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13, - 'window_lead_time': 3, 'interpolate_method': 'linear', - 'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'} } - valids = pre.check_valid_stations(pre.setup.__dict__, kwargs, pre.setup.stations) - assert valids == pre.setup.stations + args = pre._create_args_dict(DEFAULT_ARGS_LIST) + kwargs = pre._create_args_dict(DEFAULT_KWARGS_LIST) + stations = pre.data_store.get("stations", "general") + valid_stations = pre.check_valid_stations(args, kwargs, stations) + assert valid_stations == stations assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started') - assert caplog.record_tuples[1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)')) - - def test_update_kwargs(self): - args = {"testName": {"testAttribute": "TestValue", "optional": "2019-11-21"}} - kwargs = {"testAttribute": "DefaultValue", "defaultAttribute": 3} - updated = PreProcessing.update_kwargs(args, kwargs, "testName") - assert updated == {"testAttribute": "TestValue", "defaultAttribute": 3, "optional": "2019-11-21"} - assert kwargs == {"testAttribute": "DefaultValue", "defaultAttribute": 3} - args = {"testName": None} - updated = PreProcessing.update_kwargs(args, kwargs, "testName") - assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3} - args = {"dummy": "notMeaningful"} - updated = PreProcessing.update_kwargs(args, kwargs, "testName") - assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3} - - def test_update_key(self): - orig_dict = {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]} - f = PreProcessing.update_key - assert f(orig_dict, "Test2", 4) == {"Test1": 3, "Test2": 4, "test3": [1, 2, 3]} - assert orig_dict == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]} - assert f(orig_dict, "Test3", 4) == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3], "Test3": 4} - - def test_split_set_indices(self): + assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)')) + + def test_split_set_indices(self, obj_no_init): dummy_list = list(range(0, 15)) - train, val, test = PreProcessing.split_set_indices(len(dummy_list), 0.9) + train, val, test = obj_no_init.split_set_indices(len(dummy_list), 0.9) assert dummy_list[train] == list(range(0, 10)) assert dummy_list[val] == list(range(10, 13)) assert dummy_list[test] == list(range(13, 15)) - # @mock.patch("DataGenerator", return_value=object.__new__(DataGenerator)) - # @mock.patch("DataGenerator[station]", return_value=(np.ones(10), np.zeros(10))) - # def test_create_set_split(self): - # stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] - # pre = object.__new__(PreProcessing) - # pre.setup = ExperimentSetup({}, stations=stations, var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, - # train_kwargs={"start": "2000-01-01", "end": "2007-12-31"}) - # kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'statistics_per_var': pre.setup.var_all_dict, } - # train = pre.create_set_split(stations, pre.setup.__dict__, kwargs, slice(0, 3), "train") - # # stopped here. It is a mess with all the different kwargs, args etc. Restructure the idea of how to implement - # # the data sets. Because there are multiple kwargs declarations and which counts in the end. And there are - # # multiple declarations of the DataGenerator class. Why this? Is it somehow possible, to select elements from - # # this iterator class. Furthermore the names of the DataPrep class is not distinct, because there is no time - # # range provided in file's name. Given the case, that first to total DataGen is called with a short period for - # # data loading. But then, for the data split (I don't know why this could happen, but it is very likely because - # # osf the current multiple declarations of kwargs arguments) the desired time range exceeds the previou - # # mentioned and short time range. But nevertheless, the file with the short period is loaded and used (because - # # during DataPrep loading, the available range is checked). \ No newline at end of file + def test_create_args_dict_default_scope(self, obj_super_init): + assert obj_super_init._create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2} + + def test_create_args_dict_given_scope(self, obj_super_init): + assert obj_super_init._create_args_dict(["NAME1", "NAME2"], scope="general.sub") == {"NAME1": 10, "NAME2": 2} + + def test_create_args_dict_missing_entry(self, obj_super_init): + assert obj_super_init._create_args_dict(["NAME5", "NAME2"]) == {"NAME2": 2} + assert obj_super_init._create_args_dict(["NAME4", "NAME2"]) == {"NAME2": 2}