From 468e4383ed393320132de6a745206307fe3b21d8 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Fri, 22 Nov 2019 13:33:48 +0100 Subject: [PATCH] worked on split methods --- run.py | 4 +-- src/experiment_setup.py | 6 ++++ src/modules.py | 74 +++++++++++++++++++++++++++++++++++++---- test/test_modules.py | 56 ++++++++++++++++++++++++++++--- 4 files changed, 128 insertions(+), 12 deletions(-) diff --git a/run.py b/run.py index 6b115cf2..5e092698 100644 --- a/run.py +++ b/run.py @@ -11,7 +11,7 @@ from src.modules import run, PreProcessing, Training, PostProcessing def main(): with run(): - exp_setup = ExperimentSetup(args, trainable=True) + exp_setup = ExperimentSetup(args, trainable=True, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']) PreProcessing(exp_setup) @@ -23,7 +23,7 @@ def main(): if __name__ == "__main__": formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]' - logging.basicConfig(format=formatter, level=logging.DEBUG) + logging.basicConfig(format=formatter, level=logging.INFO) parser = argparse.ArgumentParser() parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None, diff --git a/src/experiment_setup.py b/src/experiment_setup.py index d8cf04ec..4fc14573 100644 --- a/src/experiment_setup.py +++ b/src/experiment_setup.py @@ -29,6 +29,9 @@ class ExperimentSetup(object): self.interpolate_dim = None self.target_dim = None self.target_var = None + self.train_kwargs = None + self.val_kwargs = None + self.test_kwargs = None self.setup_experiment(**kwargs) def _set_param(self, param, value, default=None): @@ -86,3 +89,6 @@ class ExperimentSetup(object): self._set_param("interpolate_dim", kwargs, default='datetime') self._set_param("target_dim", kwargs, default='variables') self._set_param("target_var", kwargs, default="o3") + self._set_param("train_kwargs", kwargs, default={"start": "1997-01-01", "end": "2007-12-31"}) + 
self._set_param("val_kwargs", kwargs, default={"start": "2008-01-01", "end": "2009-12-31"}) + self._set_param("test_kwargs", kwargs, default={"start": "2010-01-01", "end": "2017-12-31"}) diff --git a/src/modules.py b/src/modules.py index 85c0d413..01f7ed67 100644 --- a/src/modules.py +++ b/src/modules.py @@ -4,7 +4,7 @@ import time from src.data_generator import DataGenerator from src.experiment_setup import ExperimentSetup import argparse -from typing import Dict, List +from typing import Dict, List, Any, Tuple class run(object): @@ -63,15 +63,77 @@ class PreProcessing(run): kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13, 'window_lead_time': 3, 'interpolate_method': 'linear', 'statistics_per_var': self.setup.var_all_dict, } - valid_stations = self.check_valid_stations(self.setup.__dict__, kwargs, self.setup.stations) args = self.setup.__dict__ - args["stations"] = valid_stations + valid_stations = self.check_valid_stations(args, kwargs, self.setup.stations) + args = self.update_key(args, "stations", valid_stations) data_gen = DataGenerator(**args, **kwargs) - train, val, test = self.split_train_val_test() + train, val, test = self.split_train_val_test(data_gen, valid_stations, args, kwargs) + # print stats of data + + def split_train_val_test(self, data, stations, args, kwargs): + train_index, val_index, test_index = self.split_set_indices(len(stations), args["fraction_of_training"]) + train = self.create_set_split(stations, args, kwargs, train_index, "train") + val = self.create_set_split(stations, args, kwargs, val_index, "val") + test = self.create_set_split(stations, args, kwargs, test_index, "test") + return train, val, test + + @staticmethod + def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice]: + """ + create the training, validation and test subset slice indices for given total_length. The test data consists of + (1-fraction) of total_length (fraction*len:end). 
Train and validation data therefore are made from fraction of + total_length (0:fraction*len). Train and validation data is split by the factor 0.8 for train and 0.2 for + validation. + :param total_length: number of objects to split (e.g. the length of the station list) + :param fraction: share of total_length used for the union of train/val data, the remainder is test data + :return: slices for each subset in the order: train, val, test + """ + pos_test_split = int(total_length * fraction) + train_index = slice(0, int(pos_test_split * 0.8)) + val_index = slice(int(pos_test_split * 0.8), pos_test_split) + test_index = slice(pos_test_split, total_length) + return train_index, val_index, test_index + + def create_set_split(self, stations, args, kwargs, index_list, set_name): + if args["use_all_stations_on_all_data_sets"]: + set_stations = stations + else: + set_stations = stations[index_list] + logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") + set_kwargs = self.update_kwargs(args, kwargs, f"{set_name}_kwargs") + set_stations = self.check_valid_stations(args, set_kwargs, set_stations) + set_args = self.update_key(args, "stations", set_stations) + data_set = DataGenerator(**set_args, **set_kwargs) + return data_set @staticmethod - def split_train_val_test(): - return None, None, None + def update_key(orig_dict: Dict, key: str, value: Any) -> Dict: + """ + create copy of `orig_dict` and update given key by value, returns a copied dict. The original input dict + `orig_dict` is not modified by this function. + :param orig_dict: dictionary with arguments that should be updated + :param key: the key to update + :param value: the update itself for given key + :return: updated dict + """ + updated = orig_dict.copy() + updated.update({key: value}) + return updated + + @staticmethod + def update_kwargs(args: Dict, kwargs: Dict, kwargs_name: str): + """ + copy kwargs and update kwargs parameters by another dictionary stored in args. Not existing keys in kwargs are + created, existing keys overwritten. 
+ :param args: dict with the new kwargs parameters stored with key `kwargs_name` + :param kwargs: dict to update + :param kwargs_name: key in `args` to find the updates for `kwargs` + :return: updated kwargs dict + """ + kwargs_updated = kwargs.copy() + if kwargs_name in args.keys() and args[kwargs_name]: + kwargs_updated.update(args[kwargs_name]) + return kwargs_updated @staticmethod def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]): diff --git a/test/test_modules.py b/test/test_modules.py index 02b49b28..3211a8e8 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -4,8 +4,10 @@ from src.modules import run, PreProcessing from src.helpers import TimeTracking import src.helpers from src.experiment_setup import ExperimentSetup +from src.data_generator import DataGenerator import re import mock +import numpy as np class pytest_regex: @@ -29,7 +31,7 @@ class TestRun: assert caplog.record_tuples[-1] == ('root', 20, 'run started') assert isinstance(r.time, TimeTracking) r.do_stuff(0.1) - assert caplog.record_tuples[-1] == ('root', 20, pytest_regex("run finished after \d+\.\d+s")) + assert caplog.record_tuples[-1] == ('root', 20, pytest_regex(r"run finished after \d+\.\d+s")) def test_init_del(self, caplog): caplog.set_level(logging.INFO) @@ -37,7 +39,7 @@ class TestRun: assert caplog.record_tuples[-1] == ('root', 20, 'run started') r.do_stuff(0.2) del r - assert caplog.record_tuples[-1] == ('root', 20, pytest_regex("run finished after \d+\.\d+s")) + assert caplog.record_tuples[-1] == ('root', 20, pytest_regex(r"run finished after \d+\.\d+s")) class TestPreProcessing: @@ -49,7 +51,7 @@ class TestPreProcessing: pre = PreProcessing(setup) assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started') - assert caplog.record_tuples[-1] == ('root', 20, pytest_regex('run for \d+\.\d+s to check 5 station\(s\)')) + assert caplog.record_tuples[-1] == 
('root', 20, pytest_regex(r'run for \d+\.\d+s to check 5 station\(s\)')) def test_run(self): pre_processing = object.__new__(PreProcessing) @@ -73,4 +75,50 @@ class TestPreProcessing: valids = pre.check_valid_stations(pre.setup.__dict__, kwargs, pre.setup.stations) assert valids == pre.setup.stations assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started') - assert caplog.record_tuples[1] == ('root', 20, pytest_regex('run for \d+\.\d+s to check 5 station\(s\)')) + assert caplog.record_tuples[1] == ('root', 20, pytest_regex(r'run for \d+\.\d+s to check 5 station\(s\)')) + + def test_update_kwargs(self): + args = {"testName": {"testAttribute": "TestValue", "optional": "2019-11-21"}} + kwargs = {"testAttribute": "DefaultValue", "defaultAttribute": 3} + updated = PreProcessing.update_kwargs(args, kwargs, "testName") + assert updated == {"testAttribute": "TestValue", "defaultAttribute": 3, "optional": "2019-11-21"} + assert kwargs == {"testAttribute": "DefaultValue", "defaultAttribute": 3} + args = {"testName": None} + updated = PreProcessing.update_kwargs(args, kwargs, "testName") + assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3} + args = {"dummy": "notMeaningful"} + updated = PreProcessing.update_kwargs(args, kwargs, "testName") + assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3} + + def test_update_key(self): + orig_dict = {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]} + f = PreProcessing.update_key + assert f(orig_dict, "Test2", 4) == {"Test1": 3, "Test2": 4, "test3": [1, 2, 3]} + assert orig_dict == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]} + assert f(orig_dict, "Test3", 4) == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3], "Test3": 4} + + def test_split_set_indices(self): + dummy_list = list(range(0, 15)) + train, val, test = PreProcessing.split_set_indices(len(dummy_list), 0.9) + assert dummy_list[train] == list(range(0, 10)) + assert dummy_list[val] == list(range(10, 13)) + assert 
dummy_list[test] == list(range(13, 15)) + + @mock.patch("DataGenerator", return_value=object.__new__(DataGenerator)) + @mock.patch("DataGenerator[station]", return_value=(np.ones(10), np.zeros(10))) + def test_create_set_split(self): + stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] + pre = object.__new__(PreProcessing) + pre.setup = ExperimentSetup({}, stations=stations, var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, + train_kwargs={"start": "2000-01-01", "end": "2007-12-31"}) + kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'statistics_per_var': pre.setup.var_all_dict, } + train = pre.create_set_split(stations, pre.setup.__dict__, kwargs, slice(0, 3), "train") + # Stopped here. It is a mess with all the different kwargs, args etc. Restructure the idea of how to implement + # the data sets, because there are multiple kwargs declarations and it is unclear which one counts in the end. + # There are also multiple declarations of the DataGenerator class. Why? Is it somehow possible to select elements + # from this iterator class? Furthermore, the file names of the DataPrep class are not distinct, because no time + # range is provided in the file's name. Consider the case that the total DataGen is first called with a short + # period for data loading. But then, for the data split (I don't know why this could happen, but it is very likely + # because of the current multiple declarations of kwargs arguments), the desired time range exceeds the previously + # mentioned short time range. Nevertheless, the file with the short period is loaded and used (because + # during DataPrep loading, the available range is checked). -- GitLab