Skip to content
Snippets Groups Projects
Commit fbcb630b authored by lukas leufen's avatar lukas leufen
Browse files

introduced data store to preprocessing module

parent 1faf7c2d
No related branches found
No related tags found
2 merge requests: !17 "update to v0.4.0", !15 "new feat split subsets"
Pipeline #26468 passed
......@@ -4,6 +4,11 @@ from typing import Any, Tuple, Dict, List
from src.data_generator import DataGenerator
from src.helpers import TimeTracking
from src.modules.run_environment import RunEnvironment
from src.datastore import NameNotFoundInDataStore
# Names looked up in the data store and expanded into DataGenerator(**args, ...)
DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
# Optional names looked up in the data store and passed on as keyword arguments
DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history", "window_lead_time"]
class PreProcessing(RunEnvironment):
......@@ -15,30 +20,40 @@ class PreProcessing(RunEnvironment):
testing subsets.
"""
def __init__(self):
    """
    Create the run framework and immediately execute the preprocessing stage.

    All configuration is pulled from the shared data store provided by
    RunEnvironment, so no constructor arguments are required.
    (The diff residue showed both the old one-argument and the new
    zero-argument signature; this is the data-store based version.)
    """
    # create run framework
    super().__init__()
    self.kwargs = None
    self.valid_stations = []
    # trigger the whole preprocessing pipeline on construction
    self._run()
def _create_args_dict(self, arg_list, scope="general"):
    """
    Collect the requested names from the data store into a plain dict.

    Names that are not present in the store for the given scope are
    silently skipped, so the result may contain fewer keys than requested.

    :param arg_list: names to look up in the data store
    :param scope: data store scope to query (defaults to "general")
    :return: mapping of the found names to their stored values
    """
    collected = {}
    for name in arg_list:
        try:
            value = self.data_store.get(name, scope)
        except NameNotFoundInDataStore:
            continue  # optional entry -> leave it out of the result
        collected[name] = value
    return collected
def _run(self):
    """
    Execute the preprocessing steps.

    Resolves the default argument and keyword sets from the data store,
    filters the configured stations down to the valid ones via
    check_valid_stations, writes the surviving station list back to the
    data store, and finally splits the stations into train/val/test
    subsets. (The superseded pre-data-store implementation that built its
    arguments from an ExperimentSetup instance has been removed.)
    """
    args = self._create_args_dict(DEFAULT_ARGS_LIST)
    kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST)
    valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"))
    self.data_store.put("stations", valid_stations)
    self.split_train_val_test()
def split_train_val_test(self):
    """
    Split the valid stations into train/validation/test subsets.

    Reads the training fraction and the station list from the "general"
    scope of the data store, computes index slices with split_set_indices,
    and creates one sub-generator per subset via create_set_split.
    (Removed the commented-out per-subset calls that duplicated the loop.)
    """
    fraction_of_training = self.data_store.get("fraction_of_training", "general")
    stations = self.data_store.get("stations", "general")
    train_index, val_index, test_index = self.split_set_indices(len(stations), fraction_of_training)
    for ind, scope in zip([train_index, val_index, test_index], ["train", "val", "test"]):
        self.create_set_split(ind, scope)
@staticmethod
def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice]:
......@@ -57,46 +72,21 @@ class PreProcessing(RunEnvironment):
test_index = slice(pos_test_split, total_length)
return train_index, val_index, test_index
def create_set_split(self, index_list, set_name):
    """
    Build the data generator for a single subset (train/val/test).

    Resolves the subset-scoped arguments from the data store, selects the
    subset's stations (all stations when "use_all_stations_on_all_data_sets"
    is set), validates them, stores the surviving list under the subset
    scope, and registers the resulting DataGenerator as "generator" in
    that scope. (Reconstructed from the diff: superseded pre-data-store
    tail and the removed update_key/update_kwargs helpers dropped.)

    :param index_list: slice selecting this subset's stations
    :param set_name: subset name, one of "train", "val", "test"
    """
    scope = f"general.{set_name}"
    args = self._create_args_dict(DEFAULT_ARGS_LIST, scope)
    kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST, scope)
    stations = args["stations"]
    # NOTE(review): "use_all_stations_on_all_data_sets" is not part of
    # DEFAULT_ARGS_LIST, so this lookup can raise KeyError unless the store
    # always provides it -- confirm against the data store defaults.
    if args["use_all_stations_on_all_data_sets"]:
        set_stations = stations
    else:
        set_stations = stations[index_list]
    # fixed: log the station count, not the list, as "len"
    logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
    set_stations = self.check_valid_stations(args, kwargs, set_stations)
    self.data_store.put("stations", set_stations, scope)
    # re-resolve args so the generator sees the validated station list
    set_args = self._create_args_dict(DEFAULT_ARGS_LIST, scope)
    data_set = DataGenerator(**set_args, **kwargs)
    self.data_store.put("generator", data_set, scope)
@staticmethod
def update_key(orig_dict: Dict, key: str, value: Any) -> Dict:
"""
create copy of `orig_dict` and update given key by value, returns a copied dict. The original input dict
`orig_dict` is not modified by this function.
:param orig_dict: dictionary with arguments that should be updated
:param key: the key to update
:param value: the update itself for given key
:return: updated dict
"""
updated = orig_dict.copy()
updated.update({key: value})
return updated
@staticmethod
def update_kwargs(args: Dict, kwargs: Dict, kwargs_name: str):
"""
copy kwargs and update kwargs parameters by another dictionary stored in args. Not existing keys in kwargs are
created, existing keys overwritten.
:param args: dict with the new kwargs parameters stored with key `kwargs_name`
:param kwargs: dict to update
:param kwargs_name: key in `args` to find the updates for `kwargs`
:return: updated kwargs dict
"""
kwargs_updated = kwargs.copy()
if kwargs_name in args.keys() and args[kwargs_name]:
kwargs_updated.update(args[kwargs_name])
return kwargs_updated
# NOTE(review): the five lines below are the tail of the new create_set_split
# implementation, shown here separately by the diff rendering. They re-validate
# the subset's stations, persist the valid list under the subset scope,
# re-resolve the scoped arguments, and register the built DataGenerator as
# "generator" in that scope.
set_stations = self.check_valid_stations(args, kwargs, set_stations)
self.data_store.put("stations", set_stations, scope)
set_args = self._create_args_dict(DEFAULT_ARGS_LIST, scope)
data_set = DataGenerator(**set_args, **kwargs)
self.data_store.put("generator", data_set, scope)
@staticmethod
def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]):
......
import logging
from src.modules.pre_processing import PreProcessing
from src.helpers import TimeTracking, PyTestRegex
from src.modules.experiment_setup import ExperimentSetup
from src.data_generator import DataGenerator
import mock
import numpy as np
class TestPreProcessing:
    """Unit tests for the PreProcessing run module (pre data-store API)."""

    def test_init(self, caplog):
        # constructing PreProcessing triggers the whole run; verify the log trail
        caplog.set_level(logging.INFO)
        exp_setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
                                    var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
        PreProcessing(exp_setup)
        assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
        assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)'))

    def test_run(self):
        # bypass __init__ so _run can be exercised in isolation
        preproc = object.__new__(PreProcessing)
        preproc.time = TimeTracking()
        preproc.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
                                        var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
        assert preproc._run() is None

    def test_split_train_val_test(self):
        pass  # TODO: not implemented yet

    def test_check_valid_stations(self, caplog):
        caplog.set_level(logging.INFO)
        preproc = object.__new__(PreProcessing)
        preproc.time = TimeTracking()
        preproc.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
                                        var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
        kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13,
                  'window_lead_time': 3, 'interpolate_method': 'linear',
                  'statistics_per_var': preproc.setup.var_all_dict, }
        valids = preproc.check_valid_stations(preproc.setup.__dict__, kwargs, preproc.setup.stations)
        # all five stations are expected to survive the validity check
        assert valids == preproc.setup.stations
        assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started')
        assert caplog.record_tuples[1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)'))

    def test_update_kwargs(self):
        kwargs = {"testAttribute": "DefaultValue", "defaultAttribute": 3}
        # matching key present: its mapping is merged in; input kwargs untouched
        args = {"testName": {"testAttribute": "TestValue", "optional": "2019-11-21"}}
        updated = PreProcessing.update_kwargs(args, kwargs, "testName")
        assert updated == {"testAttribute": "TestValue", "defaultAttribute": 3, "optional": "2019-11-21"}
        assert kwargs == {"testAttribute": "DefaultValue", "defaultAttribute": 3}
        # falsy entry: nothing merged
        updated = PreProcessing.update_kwargs({"testName": None}, kwargs, "testName")
        assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3}
        # missing key: nothing merged
        updated = PreProcessing.update_kwargs({"dummy": "notMeaningful"}, kwargs, "testName")
        assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3}

    def test_update_key(self):
        orig_dict = {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]}
        f = PreProcessing.update_key
        # existing key overwritten in the copy, original untouched
        assert f(orig_dict, "Test2", 4) == {"Test1": 3, "Test2": 4, "test3": [1, 2, 3]}
        assert orig_dict == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]}
        # new key added in the copy
        assert f(orig_dict, "Test3", 4) == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3], "Test3": 4}

    def test_split_set_indices(self):
        dummy_list = list(range(0, 15))
        train, val, test = PreProcessing.split_set_indices(len(dummy_list), 0.9)
        assert dummy_list[train] == list(range(0, 10))
        assert dummy_list[val] == list(range(10, 13))
        assert dummy_list[test] == list(range(13, 15))

    # TODO(review): test_create_set_split is disabled. The interplay of competing
    # args/kwargs declarations and the repeated DataGenerator construction needs
    # restructuring before the test can be mocked meaningfully; in addition,
    # DataPrep file names do not encode the time range, so a generator built for
    # a short period can silently be reused for a request covering a longer one.
import logging
from src.helpers import PyTestRegex, TimeTracking
from src.modules.experiment_setup import ExperimentSetup
from src.modules.pre_processing import PreProcessing
class TestPreProcessing:
    """Unit tests for the PreProcessing run module.

    NOTE(review): these tests still exercise the pre-refactoring API
    (PreProcessing(setup), update_kwargs, update_key) although the data-store
    refactoring changes or removes those entry points -- confirm against
    src.modules.pre_processing before relying on them.
    """

    def test_init(self, caplog):
        # constructing PreProcessing triggers the whole run; verify the log trail
        caplog.set_level(logging.INFO)
        exp_setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
                                    var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
        PreProcessing(exp_setup)
        assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
        assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)'))

    def test_run(self):
        # bypass __init__ so _run can be exercised in isolation
        preproc = object.__new__(PreProcessing)
        preproc.time = TimeTracking()
        preproc.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
                                        var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
        assert preproc._run() is None

    def test_split_train_val_test(self):
        pass  # TODO: not implemented yet

    def test_check_valid_stations(self, caplog):
        caplog.set_level(logging.INFO)
        preproc = object.__new__(PreProcessing)
        preproc.time = TimeTracking()
        preproc.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
                                        var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
        kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13,
                  'window_lead_time': 3, 'interpolate_method': 'linear',
                  'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'} }
        valids = preproc.check_valid_stations(preproc.setup.__dict__, kwargs, preproc.setup.stations)
        # all five stations are expected to survive the validity check
        assert valids == preproc.setup.stations
        assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started')
        assert caplog.record_tuples[1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)'))

    def test_update_kwargs(self):
        kwargs = {"testAttribute": "DefaultValue", "defaultAttribute": 3}
        # matching key present: its mapping is merged in; input kwargs untouched
        args = {"testName": {"testAttribute": "TestValue", "optional": "2019-11-21"}}
        updated = PreProcessing.update_kwargs(args, kwargs, "testName")
        assert updated == {"testAttribute": "TestValue", "defaultAttribute": 3, "optional": "2019-11-21"}
        assert kwargs == {"testAttribute": "DefaultValue", "defaultAttribute": 3}
        # falsy entry: nothing merged
        updated = PreProcessing.update_kwargs({"testName": None}, kwargs, "testName")
        assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3}
        # missing key: nothing merged
        updated = PreProcessing.update_kwargs({"dummy": "notMeaningful"}, kwargs, "testName")
        assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3}

    def test_update_key(self):
        orig_dict = {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]}
        f = PreProcessing.update_key
        # existing key overwritten in the copy, original untouched
        assert f(orig_dict, "Test2", 4) == {"Test1": 3, "Test2": 4, "test3": [1, 2, 3]}
        assert orig_dict == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]}
        # new key added in the copy
        assert f(orig_dict, "Test3", 4) == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3], "Test3": 4}

    def test_split_set_indices(self):
        dummy_list = list(range(0, 15))
        train, val, test = PreProcessing.split_set_indices(len(dummy_list), 0.9)
        assert dummy_list[train] == list(range(0, 10))
        assert dummy_list[val] == list(range(10, 13))
        assert dummy_list[test] == list(range(13, 15))

    # TODO(review): test_create_set_split is disabled. The interplay of competing
    # args/kwargs declarations and the repeated DataGenerator construction needs
    # restructuring before the test can be mocked meaningfully; in addition,
    # DataPrep file names do not encode the time range, so a generator built for
    # a short period can silently be reused for a request covering a longer one.
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment