diff --git a/src/modules/pre_processing.py b/src/modules/pre_processing.py index 141d9abec1757dce57f808c133dd653e8c249287..aeed05fccab27a4787f23019f7eec391a6564297 100644 --- a/src/modules/pre_processing.py +++ b/src/modules/pre_processing.py @@ -4,6 +4,11 @@ from typing import Any, Tuple, Dict, List from src.data_generator import DataGenerator from src.helpers import TimeTracking from src.modules.run_environment import RunEnvironment +from src.datastore import NameNotFoundInDataStore + + +DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] +DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history", "window_lead_time"] class PreProcessing(RunEnvironment): @@ -15,30 +20,40 @@ class PreProcessing(RunEnvironment): testing subsets. """ - def __init__(self, experiment_setup: Any): + def __init__(self): + + # create run framework super().__init__() - self.setup = experiment_setup - self.kwargs = None - self.valid_stations = [] + + # self._run() + def _create_args_dict(self, arg_list, scope="general"): + args = {} + for arg in arg_list: + try: + args[arg] = self.data_store.get(arg, scope) + except NameNotFoundInDataStore: + pass + return args + def _run(self): - kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13, - 'window_lead_time': 3, 'interpolate_method': 'linear', - 'statistics_per_var': self.setup.var_all_dict, } - args = self.setup.__dict__ - valid_stations = self.check_valid_stations(args, kwargs, self.setup.stations) - args = self.update_key(args, "stations", valid_stations) - data_gen = DataGenerator(**args, **kwargs) - train, val, test = self.split_train_val_test(data_gen, valid_stations, args, kwargs) - # print stats of data - def split_train_val_test(self, data, stations, args, kwargs): - train_index, val_index, test_index = self.split_set_indices(len(stations), args["fraction_of_training"]) - train = self.create_set_split(stations, args, kwargs, train_index, "train") - val = self.create_set_split(stations, args, kwargs, val_index, "val") - test = self.create_set_split(stations, args, kwargs, test_index, "test") - return train, val, test + args = self._create_args_dict(DEFAULT_ARGS_LIST) + kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST) + valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general")) + self.data_store.put("stations", valid_stations) + self.split_train_val_test() + + def split_train_val_test(self): + fraction_of_training = self.data_store.get("fraction_of_training", "general") + stations = self.data_store.get("stations", "general") + train_index, val_index, test_index = self.split_set_indices(len(stations), fraction_of_training) + for (ind, scope) in zip([train_index, val_index, test_index], ["train", "val", "test"]): + self.create_set_split(ind, scope) + # self.create_set_split(train_index, "train") + # self.create_set_split(val_index, "val") + # self.create_set_split(test_index, "test") @staticmethod def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice]: @@ -57,46 +72,21 @@ class PreProcessing(RunEnvironment): test_index = slice(pos_test_split, total_length) return train_index, val_index, test_index - def create_set_split(self, stations, args, kwargs, index_list, set_name): + def create_set_split(self, index_list, set_name): + scope = f"general.{set_name}" + args = self._create_args_dict(DEFAULT_ARGS_LIST, scope) + kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST, scope) + stations = args["stations"] if args["use_all_stations_on_all_data_sets"]: set_stations = stations else: set_stations = stations[index_list] logging.debug(f"{set_name.capitalize()} stations (len={set_stations}): {set_stations}") - set_kwargs = self.update_kwargs(args, kwargs, f"{set_name}_kwargs") - set_stations = self.check_valid_stations(args, set_kwargs, set_stations) - set_args = self.update_key(args, "stations", set_stations) - data_set = DataGenerator(**set_args, **set_kwargs) - return data_set - - @staticmethod - def update_key(orig_dict: Dict, key: str, value: Any) -> Dict: - """ - create copy of `orig_dict` and update given key by value, returns a copied dict. The original input dict - `orig_dict` is not modified by this function. - :param orig_dict: dictionary with arguments that should be updated - :param key: the key to update - :param value: the update itself for given key - :return: updated dict - """ - updated = orig_dict.copy() - updated.update({key: value}) - return updated - - @staticmethod - def update_kwargs(args: Dict, kwargs: Dict, kwargs_name: str): - """ - copy kwargs and update kwargs parameters by another dictionary stored in args. Not existing keys in kwargs are - created, existing keys overwritten. - :param args: dict with the new kwargs parameters stored with key `kwargs_name` - :param kwargs: dict to update - :param kwargs_name: key in `args` to find the updates for `kwargs` - :return: updated kwargs dict - """ - kwargs_updated = kwargs.copy() - if kwargs_name in args.keys() and args[kwargs_name]: - kwargs_updated.update(args[kwargs_name]) - return kwargs_updated + set_stations = self.check_valid_stations(args, kwargs, set_stations) + self.data_store.put("stations", set_stations, scope) + set_args = self._create_args_dict(DEFAULT_ARGS_LIST, scope) + data_set = DataGenerator(**set_args, **kwargs) + self.data_store.put("generator", data_set, scope) @staticmethod def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]): @@ -127,4 +117,4 @@ class PreProcessing(RunEnvironment): except AttributeError: continue logging.info(f"run for {t_outer} to check {len(all_stations)} station(s)") - return valid_stations \ No newline at end of file + return valid_stations diff --git a/test/test_modules.py b/test/test_modules.py index ce519e40984d20cb16ed5a4f970f8dd721cb0945..b28b04f643122b019e912540f228c8ed20be9eeb 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -1,89 +1,3 @@ -import logging -from src.modules.pre_processing import PreProcessing -from src.helpers import TimeTracking, PyTestRegex -from src.modules.experiment_setup import ExperimentSetup -from src.data_generator import DataGenerator -import mock -import numpy as np -class TestPreProcessing: - def test_init(self, caplog): - caplog.set_level(logging.INFO) - setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], - var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) - pre = PreProcessing(setup) - assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') - assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started') - assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)')) - - def test_run(self): - pre_processing = object.__new__(PreProcessing) - pre_processing.time = TimeTracking() - pre_processing.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], - var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) - assert pre_processing._run() is None - - def test_split_train_val_test(self): - pass - - def test_check_valid_stations(self, caplog): - caplog.set_level(logging.INFO) - pre = object.__new__(PreProcessing) - pre.time = TimeTracking() - pre.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], - var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) - kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13, - 'window_lead_time': 3, 'interpolate_method': 'linear', - 'statistics_per_var': pre.setup.var_all_dict, } - valids = pre.check_valid_stations(pre.setup.__dict__, kwargs, pre.setup.stations) - assert valids == pre.setup.stations - assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started') - assert caplog.record_tuples[1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)')) - - def test_update_kwargs(self): - args = {"testName": {"testAttribute": "TestValue", "optional": "2019-11-21"}} - kwargs = {"testAttribute": "DefaultValue", "defaultAttribute": 3} - updated = PreProcessing.update_kwargs(args, kwargs, "testName") - assert updated == {"testAttribute": "TestValue", "defaultAttribute": 3, "optional": "2019-11-21"} - assert kwargs == {"testAttribute": "DefaultValue", "defaultAttribute": 3} - args = {"testName": None} - updated = PreProcessing.update_kwargs(args, kwargs, "testName") - assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3} - args = {"dummy": "notMeaningful"} - updated = PreProcessing.update_kwargs(args, kwargs, "testName") - assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3} - - def test_update_key(self): - orig_dict = {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]} - f = PreProcessing.update_key - assert f(orig_dict, "Test2", 4) == {"Test1": 3, "Test2": 4, "test3": [1, 2, 3]} - assert orig_dict == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]} - assert f(orig_dict, "Test3", 4) == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3], "Test3": 4} - - def test_split_set_indices(self): - dummy_list = list(range(0, 15)) - train, val, test = PreProcessing.split_set_indices(len(dummy_list), 0.9) - assert dummy_list[train] == list(range(0, 10)) - assert dummy_list[val] == list(range(10, 13)) - assert dummy_list[test] == list(range(13, 15)) - - # @mock.patch("DataGenerator", return_value=object.__new__(DataGenerator)) - # @mock.patch("DataGenerator[station]", return_value=(np.ones(10), np.zeros(10))) - # def test_create_set_split(self): - # stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] - # pre = object.__new__(PreProcessing) - # pre.setup = ExperimentSetup({}, stations=stations, var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, - # train_kwargs={"start": "2000-01-01", "end": "2007-12-31"}) - # kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'statistics_per_var': pre.setup.var_all_dict, } - # train = pre.create_set_split(stations, pre.setup.__dict__, kwargs, slice(0, 3), "train") - # # stopped here. It is a mess with all the different kwargs, args etc. Restructure the idea of how to implement - # # the data sets. Because there are multiple kwargs declarations and which counts in the end. And there are - # # multiple declarations of the DataGenerator class. Why this? Is it somehow possible, to select elements from - # # this iterator class. Furthermore the names of the DataPrep class is not distinct, because there is no time - # # range provided in file's name. Given the case, that first to total DataGen is called with a short period for - # # data loading. But then, for the data split (I don't know why this could happen, but it is very likely because - # # osf the current multiple declarations of kwargs arguments) the desired time range exceeds the previou - # # mentioned and short time range. But nevertheless, the file with the short period is loaded and used (because - # # during DataPrep loading, the available range is checked). diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a1aa454fc788aabe6280d618009834dc9f26bf --- /dev/null +++ b/test/test_modules/test_pre_processing.py @@ -0,0 +1,87 @@ +import logging + +from src.helpers import PyTestRegex, TimeTracking +from src.modules.experiment_setup import ExperimentSetup +from src.modules.pre_processing import PreProcessing + + +class TestPreProcessing: + + def test_init(self, caplog): + caplog.set_level(logging.INFO) + setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], + var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) + pre = PreProcessing(setup) + assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') + assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started') + assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)')) + + def test_run(self): + pre_processing = object.__new__(PreProcessing) + pre_processing.time = TimeTracking() + pre_processing.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], + var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) + assert pre_processing._run() is None + + def test_split_train_val_test(self): + pass + + def test_check_valid_stations(self, caplog): + caplog.set_level(logging.INFO) + pre = object.__new__(PreProcessing) + pre.time = TimeTracking() + pre.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], + var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) + kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13, + 'window_lead_time': 3, 'interpolate_method': 'linear', + 'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'} } + valids = pre.check_valid_stations(pre.setup.__dict__, kwargs, pre.setup.stations) + assert valids == pre.setup.stations + assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started') + assert caplog.record_tuples[1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)')) + + def test_update_kwargs(self): + args = {"testName": {"testAttribute": "TestValue", "optional": "2019-11-21"}} + kwargs = {"testAttribute": "DefaultValue", "defaultAttribute": 3} + updated = PreProcessing.update_kwargs(args, kwargs, "testName") + assert updated == {"testAttribute": "TestValue", "defaultAttribute": 3, "optional": "2019-11-21"} + assert kwargs == {"testAttribute": "DefaultValue", "defaultAttribute": 3} + args = {"testName": None} + updated = PreProcessing.update_kwargs(args, kwargs, "testName") + assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3} + args = {"dummy": "notMeaningful"} + updated = PreProcessing.update_kwargs(args, kwargs, "testName") + assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3} + + def test_update_key(self): + orig_dict = {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]} + f = PreProcessing.update_key + assert f(orig_dict, "Test2", 4) == {"Test1": 3, "Test2": 4, "test3": [1, 2, 3]} + assert orig_dict == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]} + assert f(orig_dict, "Test3", 4) == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3], "Test3": 4} + + def test_split_set_indices(self): + dummy_list = list(range(0, 15)) + train, val, test = PreProcessing.split_set_indices(len(dummy_list), 0.9) + assert dummy_list[train] == list(range(0, 10)) + assert dummy_list[val] == list(range(10, 13)) + assert dummy_list[test] == list(range(13, 15)) + + # @mock.patch("DataGenerator", return_value=object.__new__(DataGenerator)) + # @mock.patch("DataGenerator[station]", return_value=(np.ones(10), np.zeros(10))) + # def test_create_set_split(self): + # stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] + # pre = object.__new__(PreProcessing) + # pre.setup = ExperimentSetup({}, stations=stations, var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, + # train_kwargs={"start": "2000-01-01", "end": "2007-12-31"}) + # kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'statistics_per_var': pre.setup.var_all_dict, } + # train = pre.create_set_split(stations, pre.setup.__dict__, kwargs, slice(0, 3), "train") + # # stopped here. It is a mess with all the different kwargs, args etc. Restructure the idea of how to implement + # # the data sets. Because there are multiple kwargs declarations and which counts in the end. And there are + # # multiple declarations of the DataGenerator class. Why this? Is it somehow possible, to select elements from + # # this iterator class. Furthermore the names of the DataPrep class is not distinct, because there is no time + # # range provided in file's name. Given the case, that first to total DataGen is called with a short period for + # # data loading. But then, for the data split (I don't know why this could happen, but it is very likely because + # # osf the current multiple declarations of kwargs arguments) the desired time range exceeds the previou + # # mentioned and short time range. But nevertheless, the file with the short period is loaded and used (because + # # during DataPrep loading, the available range is checked). \ No newline at end of file