diff --git a/src/modules.py b/src/modules.py index 0e03a352dc7a2f3cd057e7166024f5ace2e6a55c..85c0d413c1128286678ef221716720e2082ff135 100644 --- a/src/modules.py +++ b/src/modules.py @@ -4,6 +4,7 @@ import time from src.data_generator import DataGenerator from src.experiment_setup import ExperimentSetup import argparse +from typing import Dict, List class run(object): @@ -12,52 +13,96 @@ class run(object): after finishing the measurement. The duration result is logged. """ + del_by_exit = False + def __init__(self): + """ + Starts time tracking automatically and logs as info. + """ self.time = TimeTracking() logging.info(f"{self.__class__.__name__} started") def __del__(self): - self.time.stop() - logging.info(f"{self.__class__.__name__} finished after {self.time}") + """ + This is the class finalizer. The code is not executed if already called by exit method to prevent duplicated + logging (__exit__ is always executed before __del__) it this class was used in a with statement. + """ + if not self.del_by_exit: + self.time.stop() + logging.info(f"{self.__class__.__name__} finished after {self.time}") + self.del_by_exit = True def __enter__(self): - pass + return self def __exit__(self, exc_type, exc_val, exc_tb): - pass + self.__del__() - def do_stuff(self): - time.sleep(2) + @staticmethod + def do_stuff(length=2): + time.sleep(length) class PreProcessing(run): - def __init__(self, setup): + """ + Pre-process your data by using this class. It includes time tracking and uses the experiment setup to look for data + and stores it if not already in local disk. Further, it provides this data as a generator and checks for valid + stations (in this context: valid=data available). Finally, it splits the data into valid training, validation and + testing subsets. 
+ """ + + def __init__(self, experiment_setup: ExperimentSetup): super().__init__() - self.setup = setup + self.setup = experiment_setup self.kwargs = None + self.valid_stations = [] self._run() def _run(self): - self.kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit': 1, 'window_history': 13, - 'window_lead_time': 3, 'method': 'linear', - 'statistics_per_var': self.setup.var_all_dict, } - self.check_valid_stations() - - def check_valid_stations(self): - t = TimeTracking - logging.debug("check valid stations started") - window_lead_time = self.kwargs.get("window_lead_time", None) + kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13, + 'window_lead_time': 3, 'interpolate_method': 'linear', + 'statistics_per_var': self.setup.var_all_dict, } + valid_stations = self.check_valid_stations(self.setup.__dict__, kwargs, self.setup.stations) + args = self.setup.__dict__ + args["stations"] = valid_stations + data_gen = DataGenerator(**args, **kwargs) + train, val, test = self.split_train_val_test() + + @staticmethod + def split_train_val_test(): + return None, None, None + + @staticmethod + def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]): + """ + Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the given + time range (is included in `kwargs`). The shape and the loading time are logged in debug mode. + :param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`, + `variables`, `interpolate_dim`, `target_dim`, `target_var`). + :param kwargs: positional parameters for the DataGenerator class (e.g. `start`, `interpolate_method`, + `window_lead_time`). + :param all_stations: All stations to check. + :return: Corrected list containing only valid station IDs. 
+ """ + t_outer = TimeTracking() + t_inner = TimeTracking(start=False) + logging.info("check valid stations started") valid_stations = [] - for s in self.setup.stations: - valid = False - args = self.setup.__dict__ - args["stations"] = s - - h = DataGenerator(**args, **self.kwargs) - da_it = h.get_data_generator(s) - print('hi') + # all required arguments of the DataGenerator can be found in args, positional arguments in args and kwargs + data_gen = DataGenerator(**args, **kwargs) + for station in all_stations: + t_inner.run() + try: + (history, label) = data_gen[station] + valid_stations.append(station) + logging.debug(f"{station}: history_shape = {history.shape}") + logging.debug(f"{station}: loading time = {t_inner}") + except AttributeError: + continue + logging.info(f"run for {t_outer} to check {len(all_stations)} station(s)") + return valid_stations class Training(run): @@ -82,7 +127,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None, help="set experiment date as string") - args = parser.parse_args() + parser_args = parser.parse_args() with run(): - setup = ExperimentSetup(args, test=True) + setup = ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']) PreProcessing(setup) diff --git a/test/test_modules.py b/test/test_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..6048a7d29d7f3f8bcb42a0656d2c610e388c9b87 --- /dev/null +++ b/test/test_modules.py @@ -0,0 +1,71 @@ +import pytest +import logging +from src.modules import run, PreProcessing +from src.helpers import TimeTracking +import src.helpers +from src.experiment_setup import ExperimentSetup +import re +import mock + + +class pytest_regex: + """Assert that a given string meets some expectations.""" + + def __init__(self, pattern, flags=0): + self._regex = re.compile(pattern, flags) + + def __eq__(self, actual): + return 
bool(self._regex.match(actual)) + + def __repr__(self): + return self._regex.pattern + + +class TestRun: + + def test_enter_exit(self, caplog): + caplog.set_level(logging.INFO) + with run() as r: + assert caplog.record_tuples[-1] == ('root', 20, 'run started') + assert isinstance(r.time, TimeTracking) + r.do_stuff(0.1) + assert caplog.record_tuples[-1] == ('root', 20, pytest_regex("run finished after \d+\.\d+s")) + + def test_init_del(self, caplog): + caplog.set_level(logging.INFO) + r = run() + assert caplog.record_tuples[-1] == ('root', 20, 'run started') + r.do_stuff(0.2) + del r + assert caplog.record_tuples[-1] == ('root', 20, pytest_regex("run finished after \d+\.\d+s")) + + +class TestPreProcessing: + + def test_init(self, caplog): + caplog.set_level(logging.INFO) + setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']) + pre = PreProcessing(setup) + assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') + assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started') + assert caplog.record_tuples[2] == ('root', 20, pytest_regex('run for \d+\.\d+s to check 5 station\(s\)')) + + def test_run(self): + pre_processing = object.__new__(PreProcessing) + pre_processing.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']) + assert pre_processing._run() is None + + def test_split_train_val_test(self): + pass + + def test_check_valid_stations(self, caplog): + caplog.set_level(logging.INFO) + pre = object.__new__(PreProcessing) + pre.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']) + kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13, + 'window_lead_time': 3, 'interpolate_method': 'linear', + 'statistics_per_var': pre.setup.var_all_dict, } + valids = pre.check_valid_stations(pre.setup.__dict__, kwargs, pre.setup.stations) + assert valids == pre.setup.stations + assert 
caplog.record_tuples[0] == ('root', 20, 'check valid stations started') + assert caplog.record_tuples[1] == ('root', 20, pytest_regex(r'run for \d+\.\d+s to check 5 station\(s\)'))