Skip to content
Snippets Groups Projects
Commit 15b95a11 authored by lukas leufen's avatar lukas leufen
Browse files

implemented first methods and their tests

parent 430cc664
Branches
Tags
2 merge requests: !9 new version v0.2.0, !8 data generator
__author__ = 'Felix Kleinert, Lukas Leufen'
__date__ = '2019-11-07'
import keras
from src import helpers
import os
from typing import Union, List
import decimal
import numpy as np
class DataGenerator(keras.utils.Sequence):
    """
    This class is a generator to handle large arrays for machine learning. This class can be used with keras'
    fit_generator and predict_generator. Individual stations are the iterables. The intended design is that an item
    call returns X, y for a single station, addressed either by position (integer) or by station id (string), where
    lists with exactly one entry of integer or string are accepted as well.

    Note: item access and iteration are not implemented yet; __getitem__ and __next__ raise NotImplementedError.
    """

    def __init__(self, path: str, network: str, stations: Union[str, List[str]], variables: List[str], dim: str,
                 target_dim: str, target_var: str, **kwargs):
        """
        store all settings and derive the threshold list
        :param path: data directory, stored as absolute path
        :param network: name of the measurement network
        :param stations: single station id or list of station ids, always stored as list
        :param variables: names of the variables to use
        :param dim: name of the dimension
        :param target_dim: name of the target dimension
        :param target_var: name of the target variable
        :param kwargs: additional settings; thr_min, thr_max and thr_number_of_steps are read by threshold_setup
        """
        # Initialise the keras base class explicitly. Called without arguments on purpose: **kwargs carries the
        # settings of this class (e.g. thr_min) and must not be forwarded to keras.
        super().__init__()
        self.path = os.path.abspath(path)
        self.network = network
        self.stations = helpers.to_list(stations)
        self.variables = variables
        self.dim = dim
        self.target_dim = target_dim
        self.target_var = target_var
        self.kwargs = kwargs
        # must be called after self.kwargs is set, because threshold_setup reads its settings from there
        self.threshold = self.threshold_setup()

    def __repr__(self):
        """
        display all class attributes
        """
        return f"DataGenerator(path='{self.path}', network='{self.network}', stations={self.stations}, "\
               f"variables={self.variables}, dim='{self.dim}', target_dim='{self.target_dim}', target_var='" \
               f"{self.target_var}', **{self.kwargs})"

    def __len__(self):
        """
        display the number of stations
        """
        return len(self.stations)

    def __iter__(self):
        """
        reset the iteration counter and return the generator itself as iterator
        """
        self.iterator = 0
        return self

    def __next__(self):
        """
        not implemented yet
        """
        raise NotImplementedError

    def __getitem__(self, item):
        """
        not implemented yet
        """
        raise NotImplementedError

    def threshold_setup(self) -> List[str]:
        """
        set threshold for given min/max and number of steps. defaults are [0, 100] with n=200 steps. The settings are
        read from kwargs (thr_min, thr_max, thr_number_of_steps).
        :return: evenly spaced threshold values, each formatted with 4 decimal places and returned as string
        """
        thr_min = self.kwargs.get('thr_min', 0)
        thr_max = self.kwargs.get('thr_max', 100)
        thr_number_of_steps = self.kwargs.get('thr_number_of_steps', 200)
        return [str(decimal.Decimal("%.4f" % e)) for e in np.linspace(thr_min, thr_max, thr_number_of_steps)]
import pytest
import os
from src.data_generator import DataGenerator
import logging
import numpy as np
import xarray as xr
import datetime as dt
import pandas as pd
from operator import itemgetter
class TestDataGenerator:
    """
    Tests for the constructor, __repr__, __len__ and threshold_setup of DataGenerator.
    """

    @pytest.fixture
    def gen(self):
        """
        DataGenerator with a single station and a relative data path
        """
        return DataGenerator('data', 'UBA', 'DEBW107', ['o3', 'temp'], 'datetime', 'datetime', 'o3')

    def test_init(self, gen):
        assert gen.path == os.path.abspath('data')
        assert gen.network == 'UBA'
        assert gen.stations == ['DEBW107']
        assert gen.variables == ['o3', 'temp']
        assert gen.dim == 'datetime'
        assert gen.target_dim == 'datetime'
        assert gen.target_var == 'o3'
        assert gen.threshold is not None

    def test_repr(self, gen):
        # The fixture passes the relative path 'data', which is resolved against the current working directory.
        # Build the expectation the same way (as test_init does) so the test does not depend on where pytest is
        # started from.
        path = os.path.abspath('data')
        expected = f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], " \
                   f"variables=['o3', 'temp'], dim='datetime', target_dim='datetime', " \
                   f"target_var='o3', **{{}})"
        assert repr(gen) == expected

    def test_len(self, gen):
        assert len(gen) == 1
        gen.stations = ['station1', 'station2', 'station3']
        assert len(gen) == 3

    def test_threshold_setup(self, gen):
        def res(arg=None, val=None):
            # optionally override one setting, then return the thresholds as floats
            if arg is not None:
                gen.kwargs[arg] = val
            return list(map(float, gen.threshold_setup()))

        # assert_array_almost_equal raises AssertionError on mismatch, no additional assert is required
        compare = np.testing.assert_array_almost_equal
        compare(res(), np.linspace(0, 100, 200), decimal=3)
        # overrides accumulate in gen.kwargs, therefore each expectation includes all previous settings
        compare(res('thr_min', 10), np.linspace(10, 100, 200), decimal=3)
        compare(res('thr_max', 40), np.linspace(10, 40, 200), decimal=3)
        compare(res('thr_number_of_steps', 10), np.linspace(10, 40, 10), decimal=3)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment