From 468e4383ed393320132de6a745206307fe3b21d8 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 22 Nov 2019 13:33:48 +0100
Subject: [PATCH] worked on split methods

---
 run.py                  |  4 +--
 src/experiment_setup.py |  6 ++++
 src/modules.py          | 74 +++++++++++++++++++++++++++++++++++++----
 test/test_modules.py    | 56 ++++++++++++++++++++++++++++---
 4 files changed, 128 insertions(+), 12 deletions(-)

diff --git a/run.py b/run.py
index 6b115cf2..5e092698 100644
--- a/run.py
+++ b/run.py
@@ -11,7 +11,7 @@ from src.modules import run, PreProcessing, Training, PostProcessing
 def main():
 
     with run():
-        exp_setup = ExperimentSetup(args, trainable=True)
+        exp_setup = ExperimentSetup(args, trainable=True, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])
 
         PreProcessing(exp_setup)
 
@@ -23,7 +23,7 @@ def main():
 if __name__ == "__main__":
 
     formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
-    logging.basicConfig(format=formatter, level=logging.DEBUG)
+    logging.basicConfig(format=formatter, level=logging.INFO)
 
     parser = argparse.ArgumentParser()
     parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
diff --git a/src/experiment_setup.py b/src/experiment_setup.py
index d8cf04ec..4fc14573 100644
--- a/src/experiment_setup.py
+++ b/src/experiment_setup.py
@@ -29,6 +29,9 @@ class ExperimentSetup(object):
         self.interpolate_dim = None
         self.target_dim = None
         self.target_var = None
+        self.train_kwargs = None
+        self.val_kwargs = None
+        self.test_kwargs = None
         self.setup_experiment(**kwargs)
 
     def _set_param(self, param, value, default=None):
@@ -86,3 +89,6 @@ class ExperimentSetup(object):
         self._set_param("interpolate_dim", kwargs, default='datetime')
         self._set_param("target_dim", kwargs, default='variables')
         self._set_param("target_var", kwargs, default="o3")
+        self._set_param("train_kwargs", kwargs, default={"start": "1997-01-01", "end": "2007-12-31"})
+        self._set_param("val_kwargs", kwargs, default={"start": "2008-01-01", "end": "2009-12-31"})
+        self._set_param("test_kwargs", kwargs, default={"start": "2010-01-01", "end": "2017-12-31"})
diff --git a/src/modules.py b/src/modules.py
index 85c0d413..01f7ed67 100644
--- a/src/modules.py
+++ b/src/modules.py
@@ -4,7 +4,7 @@ import time
 from src.data_generator import DataGenerator
 from src.experiment_setup import ExperimentSetup
 import argparse
-from typing import Dict, List
+from typing import Dict, List, Any, Tuple
 
 
 class run(object):
@@ -63,15 +63,77 @@ class PreProcessing(run):
         kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13,
                   'window_lead_time': 3, 'interpolate_method': 'linear',
                   'statistics_per_var': self.setup.var_all_dict, }
-        valid_stations = self.check_valid_stations(self.setup.__dict__, kwargs, self.setup.stations)
         args = self.setup.__dict__
-        args["stations"] = valid_stations
+        valid_stations = self.check_valid_stations(args, kwargs, self.setup.stations)
+        args = self.update_key(args, "stations", valid_stations)
         data_gen = DataGenerator(**args, **kwargs)
-        train, val, test = self.split_train_val_test()
+        train, val, test = self.split_train_val_test(data_gen, valid_stations, args, kwargs)
+        # print stats of data
+
+    def split_train_val_test(self, data, stations, args, kwargs):
+        train_index, val_index, test_index = self.split_set_indices(len(stations), args["fraction_of_training"])
+        train = self.create_set_split(stations, args, kwargs, train_index, "train")
+        val = self.create_set_split(stations, args, kwargs, val_index, "val")
+        test = self.create_set_split(stations, args, kwargs, test_index, "test")
+        return train, val, test
+
+    @staticmethod
+    def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice]:
+        """
+        Create the training, validation and test subset slice indices for given total_length. The test data consists of
+        (1-fraction) of total_length (fraction*len:end). Train and validation data therefore are made from fraction of
+        total_length (0:fraction*len). Train and validation data is split by the factor 0.8 for train and 0.2 for
+        validation.
+        :param total_length: list with all objects to split
+        :param fraction: ratio between test and union of train/val data
+        :return: slices for each subset in the order: train, val, test
+        """
+        pos_test_split = int(total_length * fraction)
+        train_index = slice(0, int(pos_test_split * 0.8))
+        val_index = slice(int(pos_test_split * 0.8), pos_test_split)
+        test_index = slice(pos_test_split, total_length)
+        return train_index, val_index, test_index
+
+    def create_set_split(self, stations, args, kwargs, index_list, set_name):
+        if args["use_all_stations_on_all_data_sets"]:
+            set_stations = stations
+        else:
+            set_stations = stations[index_list]
+        logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
+        set_kwargs = self.update_kwargs(args, kwargs, f"{set_name}_kwargs")
+        set_stations = self.check_valid_stations(args, set_kwargs, set_stations)
+        set_args = self.update_key(args, "stations", set_stations)
+        data_set = DataGenerator(**set_args, **set_kwargs)
+        return data_set
 
     @staticmethod
-    def split_train_val_test():
-        return None, None, None
+    def update_key(orig_dict: Dict, key: str, value: Any) -> Dict:
+        """
+        create copy of `orig_dict` and update given key by value, returns a copied dict. The original input dict
+        `orig_dict` is not modified by this function.
+        :param orig_dict: dictionary with arguments that should be updated
+        :param key: the key to update
+        :param value: the update itself for given key
+        :return: updated dict
+        """
+        updated = orig_dict.copy()
+        updated.update({key: value})
+        return updated
+
+    @staticmethod
+    def update_kwargs(args: Dict, kwargs: Dict, kwargs_name: str):
+        """
+        copy kwargs and update kwargs parameters by another dictionary stored in args. Not existing keys in kwargs are
+        created, existing keys overwritten.
+        :param args: dict with the new kwargs parameters stored with key `kwargs_name`
+        :param kwargs: dict to update
+        :param kwargs_name: key in `args` to find the updates for `kwargs`
+        :return: updated kwargs dict
+        """
+        kwargs_updated = kwargs.copy()
+        if kwargs_name in args.keys() and args[kwargs_name]:
+            kwargs_updated.update(args[kwargs_name])
+        return kwargs_updated
 
     @staticmethod
     def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]):
diff --git a/test/test_modules.py b/test/test_modules.py
index 02b49b28..3211a8e8 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -4,8 +4,10 @@ from src.modules import run, PreProcessing
 from src.helpers import TimeTracking
 import src.helpers
 from src.experiment_setup import ExperimentSetup
+from src.data_generator import DataGenerator
 import re
 import mock
+import numpy as np
 
 
 class pytest_regex:
@@ -29,7 +31,7 @@ class TestRun:
             assert caplog.record_tuples[-1] == ('root', 20, 'run started')
             assert isinstance(r.time, TimeTracking)
             r.do_stuff(0.1)
-        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex("run finished after \d+\.\d+s"))
+        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex(r"run finished after \d+\.\d+s"))
 
     def test_init_del(self, caplog):
         caplog.set_level(logging.INFO)
@@ -37,7 +39,7 @@ class TestRun:
         assert caplog.record_tuples[-1] == ('root', 20, 'run started')
         r.do_stuff(0.2)
         del r
-        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex("run finished after \d+\.\d+s"))
+        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex(r"run finished after \d+\.\d+s"))
 
 
 class TestPreProcessing:
@@ -49,7 +51,7 @@ class TestPreProcessing:
         pre = PreProcessing(setup)
         assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
         assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
-        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex('run for \d+\.\d+s to check 5 station\(s\)'))
+        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex(r'run for \d+\.\d+s to check 5 station\(s\)'))
 
     def test_run(self):
         pre_processing = object.__new__(PreProcessing)
@@ -73,4 +75,50 @@ class TestPreProcessing:
         valids = pre.check_valid_stations(pre.setup.__dict__, kwargs, pre.setup.stations)
         assert valids == pre.setup.stations
         assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started')
-        assert caplog.record_tuples[1] == ('root', 20, pytest_regex('run for \d+\.\d+s to check 5 station\(s\)'))
+        assert caplog.record_tuples[1] == ('root', 20, pytest_regex(r'run for \d+\.\d+s to check 5 station\(s\)'))
+
+    def test_update_kwargs(self):
+        args = {"testName": {"testAttribute": "TestValue", "optional": "2019-11-21"}}
+        kwargs = {"testAttribute": "DefaultValue", "defaultAttribute": 3}
+        updated = PreProcessing.update_kwargs(args, kwargs, "testName")
+        assert updated == {"testAttribute": "TestValue", "defaultAttribute": 3, "optional": "2019-11-21"}
+        assert kwargs == {"testAttribute": "DefaultValue", "defaultAttribute": 3}
+        args = {"testName": None}
+        updated = PreProcessing.update_kwargs(args, kwargs, "testName")
+        assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3}
+        args = {"dummy": "notMeaningful"}
+        updated = PreProcessing.update_kwargs(args, kwargs, "testName")
+        assert updated == {"testAttribute": "DefaultValue", "defaultAttribute": 3}
+
+    def test_update_key(self):
+        orig_dict = {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]}
+        f = PreProcessing.update_key
+        assert f(orig_dict, "Test2", 4) == {"Test1": 3, "Test2": 4, "test3": [1, 2, 3]}
+        assert orig_dict == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3]}
+        assert f(orig_dict, "Test3", 4) == {"Test1": 3, "Test2": "4", "test3": [1, 2, 3], "Test3": 4}
+
+    def test_split_set_indices(self):
+        dummy_list = list(range(0, 15))
+        train, val, test = PreProcessing.split_set_indices(len(dummy_list), 0.9)
+        assert dummy_list[train] == list(range(0, 10))
+        assert dummy_list[val] == list(range(10, 13))
+        assert dummy_list[test] == list(range(13, 15))
+
+    @mock.patch("DataGenerator", return_value=object.__new__(DataGenerator))
+    @mock.patch("DataGenerator[station]", return_value=(np.ones(10), np.zeros(10)))
+    def test_create_set_split(self):
+        stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
+        pre = object.__new__(PreProcessing)
+        pre.setup = ExperimentSetup({}, stations=stations, var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'},
+                                    train_kwargs={"start": "2000-01-01", "end": "2007-12-31"})
+        kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'statistics_per_var': pre.setup.var_all_dict, }
+        train = pre.create_set_split(stations, pre.setup.__dict__, kwargs, slice(0, 3), "train")
+        # stopped here. It is a mess with all the different kwargs, args etc. Restructure the idea of how to implement
+        # the data sets. Because there are multiple kwargs declarations and which counts in the end. And there are
+        # multiple declarations of the DataGenerator class. Why is this? Is it somehow possible to select elements from
+        # this iterator class. Furthermore, the names of the DataPrep class are not distinct, because there is no time
+        # range provided in the file's name. Consider the case that first the total DataGen is called with a short period for
+        # data loading. But then, for the data split (I don't know why this could happen, but it is very likely because
+        # of the current multiple declarations of kwargs arguments) the desired time range exceeds the previously
+        # mentioned and short time range. But nevertheless, the file with the short period is loaded and used (because
+        # during DataPrep loading, the available range is checked).
-- 
GitLab