Commit 468e4383 authored by lukas leufen

worked on split methods

parent ae85c9ec
2 merge requests: !17 update to v0.4.0, !15 new feat split subsets
@@ -11,7 +11,7 @@ from src.modules import run, PreProcessing, Training, PostProcessing

 def main():
     with run():
-        exp_setup = ExperimentSetup(args, trainable=True)
+        exp_setup = ExperimentSetup(args, trainable=True, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])
         PreProcessing(exp_setup)
@@ -23,7 +23,7 @@ def main():

 if __name__ == "__main__":
     formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
-    logging.basicConfig(format=formatter, level=logging.DEBUG)
+    logging.basicConfig(format=formatter, level=logging.INFO)
     parser = argparse.ArgumentParser()
     parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
......
@@ -29,6 +29,9 @@ class ExperimentSetup(object):
         self.interpolate_dim = None
         self.target_dim = None
         self.target_var = None
+        self.train_kwargs = None
+        self.val_kwargs = None
+        self.test_kwargs = None
         self.setup_experiment(**kwargs)

     def _set_param(self, param, value, default=None):
@@ -86,3 +89,6 @@ class ExperimentSetup(object):
         self._set_param("interpolate_dim", kwargs, default='datetime')
         self._set_param("target_dim", kwargs, default='variables')
         self._set_param("target_var", kwargs, default="o3")
+        self._set_param("train_kwargs", kwargs, default={"start": "1997-01-01", "end": "2007-12-31"})
+        self._set_param("val_kwargs", kwargs, default={"start": "2008-01-01", "end": "2009-12-31"})
+        self._set_param("test_kwargs", kwargs, default={"start": "2010-01-01", "end": "2017-12-31"})
@@ -4,7 +4,7 @@ import time
 from src.data_generator import DataGenerator
 from src.experiment_setup import ExperimentSetup
 import argparse
-from typing import Dict, List
+from typing import Dict, List, Any, Tuple


 class run(object):
@@ -63,15 +63,77 @@ class PreProcessing(run):
         kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13,
                   'window_lead_time': 3, 'interpolate_method': 'linear',
                   'statistics_per_var': self.setup.var_all_dict, }
-        valid_stations = self.check_valid_stations(self.setup.__dict__, kwargs, self.setup.stations)
         args = self.setup.__dict__
-        args["stations"] = valid_stations
+        valid_stations = self.check_valid_stations(args, kwargs, self.setup.stations)
+        args = self.update_key(args, "stations", valid_stations)
         data_gen = DataGenerator(**args, **kwargs)
-        train, val, test = self.split_train_val_test()
+        train, val, test = self.split_train_val_test(data_gen, valid_stations, args, kwargs)
         # print stats of data
+
+    def split_train_val_test(self, data, stations, args, kwargs):
+        train_index, val_index, test_index = self.split_set_indices(len(stations), args["fraction_of_training"])
+        train = self.create_set_split(stations, args, kwargs, train_index, "train")
+        val = self.create_set_split(stations, args, kwargs, val_index, "val")
+        test = self.create_set_split(stations, args, kwargs, test_index, "test")
+        return train, val, test
+
+    @staticmethod
+    def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice]:
+        """
+        create the training, validation and test subset slice indices for the given total_length. The test set consists
+        of (1 - fraction) of total_length (fraction * len : end). Train and validation data therefore are made from the
+        first fraction of total_length (0 : fraction * len); this part is further split by the factor 0.8 for train
+        and 0.2 for validation.
+        :param total_length: total number of objects to split
+        :param fraction: ratio between the test set and the union of train/val data
+        :return: slices for each subset in the order: train, val, test
+        """
+        pos_test_split = int(total_length * fraction)
+        train_index = slice(0, int(pos_test_split * 0.8))
+        val_index = slice(int(pos_test_split * 0.8), pos_test_split)
+        test_index = slice(pos_test_split, total_length)
+        return train_index, val_index, test_index
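A quick numeric check of these slices (illustration only, not part of the commit): for 100 stations and fraction=0.9, pos_test_split is 90, so train takes the first int(90 * 0.8) = 72 stations, validation the next 18, and test the remaining 10.

    from src.modules import PreProcessing

    train, val, test = PreProcessing.split_set_indices(total_length=100, fraction=0.9)
    assert (train, val, test) == (slice(0, 72), slice(72, 90), slice(90, 100))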
+
+    def create_set_split(self, stations, args, kwargs, index_list, set_name):
+        if args["use_all_stations_on_all_data_sets"]:
+            set_stations = stations
+        else:
+            set_stations = stations[index_list]
+        logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
+        set_kwargs = self.update_kwargs(args, kwargs, f"{set_name}_kwargs")
+        set_stations = self.check_valid_stations(args, set_kwargs, set_stations)
+        set_args = self.update_key(args, "stations", set_stations)
+        data_set = DataGenerator(**set_args, **set_kwargs)
+        return data_set
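Note that index_list is a plain slice over the station list, so the per-set station selection can be checked in isolation (illustration only):

    stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
    train_index = slice(0, 3)
    print(stations[train_index])  # ['DEBW107', 'DEBY081', 'DEBW013']

With use_all_stations_on_all_data_sets enabled, all three subsets instead share the full station list and differ only in the time windows taken from train_kwargs, val_kwargs and test_kwargs.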
+
     @staticmethod
-    def split_train_val_test():
-        return None, None, None
+    def update_key(orig_dict: Dict, key: str, value: Any) -> Dict:
+        """
+        create a copy of `orig_dict` and update the given key with the given value; the original input dict `orig_dict`
+        is not modified by this function.
+        :param orig_dict: dictionary with arguments that should be updated
+        :param key: the key to update
+        :param value: the update itself for the given key
+        :return: updated dict
+        """
+        updated = orig_dict.copy()
+        updated.update({key: value})
+        return updated
+
+    @staticmethod
+    def update_kwargs(args: Dict, kwargs: Dict, kwargs_name: str):
+        """
+        copy kwargs and update the copy with parameters from another dictionary that is stored in args under the key
+        `kwargs_name`. Keys missing in kwargs are created, existing keys are overwritten.
+        :param args: dict that holds the new kwargs parameters under the key `kwargs_name`
+        :param kwargs: dict to update
+        :param kwargs_name: key in `args` to find the updates for `kwargs`
+        :return: updated kwargs dict
+        """
+        kwargs_updated = kwargs.copy()
+        if kwargs_name in args.keys() and args[kwargs_name]:
+            kwargs_updated.update(args[kwargs_name])
+        return kwargs_updated
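Both helpers are non-mutating: they return updated copies and leave their inputs untouched. A small sketch (illustrative values only) of how they cooperate during the set split:

    from src.modules import PreProcessing

    args = {"stations": ['DEBW107'], "train_kwargs": {"start": "2000-01-01"}}
    kwargs = {"start": "1997-01-01", "end": "2017-12-31"}

    train_kwargs = PreProcessing.update_kwargs(args, kwargs, "train_kwargs")
    # {'start': '2000-01-01', 'end': '2017-12-31'} -- the set-specific start wins over the default

    new_args = PreProcessing.update_key(args, "stations", ['DEBW107', 'DEBY081'])
    # args itself is unchanged; new_args carries the extended station list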

     @staticmethod
     def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]):
......
@@ -4,8 +4,10 @@ from src.modules import run, PreProcessing
 from src.helpers import TimeTracking
 import src.helpers
 from src.experiment_setup import ExperimentSetup
+from src.data_generator import DataGenerator
 import re
+import mock
+import numpy as np


 class pytest_regex:
@@ -29,7 +31,7 @@ class TestRun:
         assert caplog.record_tuples[-1] == ('root', 20, 'run started')
         assert isinstance(r.time, TimeTracking)
         r.do_stuff(0.1)
-        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex("run finished after \d+\.\d+s"))
+        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex(r"run finished after \d+\.\d+s"))

     def test_init_del(self, caplog):
         caplog.set_level(logging.INFO)
@@ -37,7 +39,7 @@ class TestRun:
         assert caplog.record_tuples[-1] == ('root', 20, 'run started')
         r.do_stuff(0.2)
         del r
-        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex("run finished after \d+\.\d+s"))
+        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex(r"run finished after \d+\.\d+s"))


 class TestPreProcessing:
@@ -49,7 +51,7 @@ class TestPreProcessing:
         pre = PreProcessing(setup)
         assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
         assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
-        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex('run for \d+\.\d+s to check 5 station\(s\)'))
+        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex(r'run for \d+\.\d+s to check 5 station\(s\)'))

     def test_run(self):
         pre_processing = object.__new__(PreProcessing)
@@ -73,4 +75,50 @@ class TestPreProcessing:
         valids = pre.check_valid_stations(pre.setup.__dict__, kwargs, pre.setup.stations)
         assert valids == pre.setup.stations
         assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started')
-        assert caplog.record_tuples[1] == ('root', 20, pytest_regex('run for \d+\.\d+s to check 5 station\(s\)'))
+        assert caplog.record_tuples[1] == ('root', 20, pytest_regex(r'run for \d+\.\d+s to check 5 station\(s\)'))
+
+    def test_update_kwargs(self):
+        args = {"testName": {"testAttribute": "TestValue", "optional": "2019-11-21"}}
+        kwargs = {"testAttribute": "DefaultValue", "defaultAttribute": 3}
+        updated = PreProcessing.update_kwargs(args, kwargs, "testName")
+        assert updated == {"testAttribute": "TestValue", "defaultAttribute": 3, "optional": "2019-11-21"}
+        assert kwargs == {"testAttribute": "DefaultValue", "defaultAttribute": 3}
+        args = {"testName": None}
+        updated = PreProcessing.update_kwargs(args, kwargs, "testName")
+        assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3}
+        args = {"dummy": "notMeaningful"}
+        updated = PreProcessing.update_kwargs(args, kwargs, "testName")
+        assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3}
+
+    def test_update_key(self):
+        orig_dict = {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]}
+        f = PreProcessing.update_key
+        assert f(orig_dict, "Test2", 4) == {"Test1": 3, "Test2": 4, "test3": [1, 2, 3]}
+        assert orig_dict == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]}
+        assert f(orig_dict, "Test3", 4) == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3], "Test3": 4}
+
+    def test_split_set_indices(self):
+        dummy_list = list(range(0, 15))
+        train, val, test = PreProcessing.split_set_indices(len(dummy_list), 0.9)
+        assert dummy_list[train] == list(range(0, 10))
+        assert dummy_list[val] == list(range(10, 13))
+        assert dummy_list[test] == list(range(13, 15))
@mock.patch("DataGenerator", return_value=object.__new__(DataGenerator))
@mock.patch("DataGenerator[station]", return_value=(np.ones(10), np.zeros(10)))
def test_create_set_split(self):
stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
pre = object.__new__(PreProcessing)
pre.setup = ExperimentSetup({}, stations=stations, var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'},
train_kwargs={"start": "2000-01-01", "end": "2007-12-31"})
kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'statistics_per_var': pre.setup.var_all_dict, }
train = pre.create_set_split(stations, pre.setup.__dict__, kwargs, slice(0, 3), "train")
+        # stopped here. The current mix of args and kwargs is a mess; the idea of how to implement the data sets
+        # should be restructured. There are multiple kwargs declarations and it is unclear which one counts in the
+        # end, and the DataGenerator class is instantiated in multiple places. Why? Is it somehow possible to select
+        # elements from this iterator class? Furthermore, the file names written by the DataPrep class are not
+        # distinct, because no time range is included in the file name. Consider the case that the full DataGenerator
+        # is first called with a short period for data loading, but a later data split (I don't know why this could
+        # happen, but it is very likely because of the current multiple declarations of kwargs arguments) requests a
+        # time range that exceeds this short period. The file with the short period is nevertheless loaded and used,
+        # because the available range is only checked during DataPrep loading.
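One way to resolve the file-name ambiguity described in this note would be to encode the requested period into the stored file name; a hypothetical sketch (`_build_file_name` and its parameters are illustrative and not part of the code base):

    import os

    def _build_file_name(data_path: str, station: str, start: str, end: str) -> str:
        # hypothetical helper: include the requested time range in the file name so that
        # e.g. DEBW107_1997-01-01_2007-12-31.nc and DEBW107_1997-01-01_2017-12-31.nc can
        # coexist and a short-period file is never reused for a longer request
        return os.path.join(data_path, f"{station}_{start}_{end}.nc")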