Commit caf82558 authored by Lukas Leufen

can load data during pre-processing. /close #13

parent e1144c96
2 merge requests: !17 update to v0.4.0, !15 new feat split subsets
src/modules.py:

@@ -4,6 +4,7 @@ import time
 from src.data_generator import DataGenerator
 from src.experiment_setup import ExperimentSetup
 import argparse
+from typing import Dict, List

 class run(object):
@@ -12,52 +13,96 @@ class run(object):
     after finishing the measurement. The duration result is logged.
     """

+    del_by_exit = False

     def __init__(self):
+        """
+        Starts time tracking automatically and logs as info.
+        """
         self.time = TimeTracking()
         logging.info(f"{self.__class__.__name__} started")

     def __del__(self):
"""
This is the class finalizer. The code is not executed if already called by exit method to prevent duplicated
logging (__exit__ is always executed before __del__) it this class was used in a with statement.
"""
if not self.del_by_exit:
self.time.stop() self.time.stop()
logging.info(f"{self.__class__.__name__} finished after {self.time}") logging.info(f"{self.__class__.__name__} finished after {self.time}")
self.del_by_exit = True

     def __enter__(self):
-        pass
+        return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
+        self.__del__()

-    def do_stuff(self):
-        time.sleep(2)
+    @staticmethod
+    def do_stuff(length=2):
+        time.sleep(length)
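Taken together, __enter__/__exit__ and __del__ let the class work both as a context manager and as a plain object whose finalizer logs the elapsed time; the new del_by_exit flag keeps the finish message from being logged twice when __del__ fires after __exit__. A minimal usage sketch (the body is only a placeholder, not part of the commit):

    with run() as stage:      # logs "run started" and starts TimeTracking
        stage.do_stuff(0.5)   # stand-in for real work
    # on exit: "run finished after ...s" is logged once; the later __del__ is then a no-op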

 class PreProcessing(run):

-    def __init__(self, setup):
+    """
+    Pre-process your data by using this class. It includes time tracking and uses the experiment setup to look for data
+    and stores it if not already on local disk. Further, it provides this data as a generator and checks for valid
+    stations (in this context: valid = data available). Finally, it splits the data into valid training, validation and
+    testing subsets.
+    """
+
+    def __init__(self, experiment_setup: ExperimentSetup):
         super().__init__()
-        self.setup = setup
+        self.setup = experiment_setup
         self.kwargs = None
+        self.valid_stations = []
         self._run()
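Constructing the class immediately runs the whole pre-processing chain via _run(); the tests added below exercise it in exactly this way (station IDs taken from the test setup):

    setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])
    PreProcessing(setup)   # logs "PreProcessing started", checks stations, splits subsets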

     def _run(self):
-        self.kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit': 1, 'window_history': 13,
-                       'window_lead_time': 3, 'method': 'linear',
-                       'statistics_per_var': self.setup.var_all_dict, }
-        self.check_valid_stations()
-
-    def check_valid_stations(self):
-        t = TimeTracking
-        logging.debug("check valid stations started")
-        window_lead_time = self.kwargs.get("window_lead_time", None)
-        valid_stations = []
-        for s in self.setup.stations:
-            valid = False
-            args = self.setup.__dict__
-            args["stations"] = s
-            h = DataGenerator(**args, **self.kwargs)
-            da_it = h.get_data_generator(s)
-            print('hi')
+        kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13,
+                  'window_lead_time': 3, 'interpolate_method': 'linear',
+                  'statistics_per_var': self.setup.var_all_dict, }
+        valid_stations = self.check_valid_stations(self.setup.__dict__, kwargs, self.setup.stations)
+        args = self.setup.__dict__
+        args["stations"] = valid_stations
+        data_gen = DataGenerator(**args, **kwargs)
+        train, val, test = self.split_train_val_test()
+
+    @staticmethod
+    def split_train_val_test():
+        return None, None, None
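split_train_val_test() is still a stub in this commit (the "new feat split subsets" merge request picks it up). A hypothetical sketch, purely for illustration, of how the valid stations could later be divided by fixed fractions; split_by_fraction is not part of the commit:

    def split_by_fraction(stations, train=0.6, val=0.2):
        # hypothetical helper: slice the station list into train/val/test portions
        n = len(stations)
        n_train, n_val = int(n * train), int(n * val)
        return stations[:n_train], stations[n_train:n_train + n_val], stations[n_train + n_val:]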
+    @staticmethod
+    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]):
+        """
+        Check if all given stations in `all_stations` are valid. Valid means that data is available for the given
+        time range (included in `kwargs`). The shape and the loading time are logged in debug mode.
+
+        :param args: Dictionary with required parameters for the DataGenerator class (`data_path`, `network`, `stations`,
+            `variables`, `interpolate_dim`, `target_dim`, `target_var`).
+        :param kwargs: Keyword parameters for the DataGenerator class (e.g. `start`, `interpolate_method`,
+            `window_lead_time`).
+        :param all_stations: All stations to check.
+        :return: Corrected list containing only valid station IDs.
+        """
+        t_outer = TimeTracking()
+        t_inner = TimeTracking(start=False)
+        logging.info("check valid stations started")
+        valid_stations = []
+        # all required arguments of the DataGenerator can be found in args, the remaining keyword arguments in kwargs
+        data_gen = DataGenerator(**args, **kwargs)
+        for station in all_stations:
+            t_inner.run()
+            try:
+                (history, label) = data_gen[station]
+                valid_stations.append(station)
+                logging.debug(f"{station}: history_shape = {history.shape}")
+                logging.debug(f"{station}: loading time = {t_inner}")
+            except AttributeError:
+                continue
+        logging.info(f"run for {t_outer} to check {len(all_stations)} station(s)")
+        return valid_stations
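The validity check hinges on indexing the generator: data_gen[station] returns a (history, label) pair when the data could be loaded and raises (here an AttributeError) otherwise, so the except branch simply skips the station. A stripped-down, standalone analogue of that pattern (the loader object and helper name are hypothetical):

    def keep_loadable(loader, stations):
        available = []
        for station in stations:
            try:
                history, label = loader[station]   # a failed load surfaces as an exception
                available.append(station)
            except AttributeError:
                continue
        return available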

 class Training(run):

@@ -82,7 +127,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
                         help="set experiment date as string")
-    args = parser.parse_args()
+    parser_args = parser.parse_args()

     with run():
-        setup = ExperimentSetup(args, test=True)
+        setup = ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])
         PreProcessing(setup)
The commit also introduces the following tests (the imports indicate they target src.modules):

import pytest
import logging
from src.modules import run, PreProcessing
from src.helpers import TimeTracking
import src.helpers
from src.experiment_setup import ExperimentSetup
import re
import mock


class pytest_regex:
    """Assert that a given string meets some expectations."""

    def __init__(self, pattern, flags=0):
        self._regex = re.compile(pattern, flags)

    def __eq__(self, actual):
        return bool(self._regex.match(actual))

    def __repr__(self):
        return self._regex.pattern
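pytest_regex turns an equality check into a regex match, which lets the log assertions below compare messages against patterns such as elapsed times. Two illustrative checks that would pass:

    assert "run finished after 0.21s" == pytest_regex(r"run finished after \d+\.\d+s")
    assert pytest_regex(r"PreProcessing \w+") == "PreProcessing started"   # operand order does not matter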
class TestRun:

    def test_enter_exit(self, caplog):
        caplog.set_level(logging.INFO)
        with run() as r:
            assert caplog.record_tuples[-1] == ('root', 20, 'run started')
            assert isinstance(r.time, TimeTracking)
            r.do_stuff(0.1)
        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex("run finished after \d+\.\d+s"))

    def test_init_del(self, caplog):
        caplog.set_level(logging.INFO)
        r = run()
        assert caplog.record_tuples[-1] == ('root', 20, 'run started')
        r.do_stuff(0.2)
        del r
        assert caplog.record_tuples[-1] == ('root', 20, pytest_regex("run finished after \d+\.\d+s"))
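caplog.record_tuples yields (logger_name, level, message) triples, so the literal 20 in these assertions is simply logging.INFO; the two tests pin the start/finish messages for the context-manager path and for the plain construct/delete path. An equivalent, slightly more readable form of the same check:

    assert caplog.record_tuples[-1] == ('root', logging.INFO, 'run started')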
class TestPreProcessing:

    def test_init(self, caplog):
        caplog.set_level(logging.INFO)
        setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])
        pre = PreProcessing(setup)
        assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
        assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
        assert caplog.record_tuples[2] == ('root', 20, pytest_regex('run for \d+\.\d+s to check 5 station\(s\)'))

    def test_run(self):
        pre_processing = object.__new__(PreProcessing)
        pre_processing.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])
        assert pre_processing._run() is None

    def test_split_train_val_test(self):
        pass

    def test_check_valid_stations(self, caplog):
        caplog.set_level(logging.INFO)
        pre = object.__new__(PreProcessing)
        pre.setup = ExperimentSetup({}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])
        kwargs = {'start': '1997-01-01', 'end': '2017-12-31', 'limit_nan_fill': 1, 'window_history': 13,
                  'window_lead_time': 3, 'interpolate_method': 'linear',
                  'statistics_per_var': pre.setup.var_all_dict, }
        valids = pre.check_valid_stations(pre.setup.__dict__, kwargs, pre.setup.stations)
        assert valids == pre.setup.stations
        assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started')
        assert caplog.record_tuples[1] == ('root', 20, pytest_regex('run for \d+\.\d+s to check 5 station\(s\)'))
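test_run and test_check_valid_stations build the object with object.__new__(PreProcessing) instead of calling the class, which allocates the instance without running __init__, so _run() is not triggered during setup and each method can be tested in isolation. The pattern in short (shorter station list only for illustration):

    pre = object.__new__(PreProcessing)   # allocate only; __init__ (and therefore _run) is skipped
    pre.setup = ExperimentSetup({}, stations=['DEBW107'])   # inject just the state the method under test needs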