Skip to content
Snippets Groups Projects

Resolve "new tests for advanced data handler"

2 files
+ 102
117
Compare changes
  • Side-by-side
  • Inline
Files
2
__author__ = 'Lukas Leufen'
__date__ = '2020-07-08'
import numpy as np
import xarray as xr
import os
import pandas as pd
import datetime as dt
from mlair.data_handler import AbstractDataHandler
from typing import Union, List, Tuple, Dict
import logging
from functools import reduce
from mlair.helpers.join import EmptyQueryResult
from mlair.helpers import TimeTracking
# Type aliases for annotations: `number` is a single float or int, `num_or_list` is either one
# such number or a list of them.
number = Union[float, int]
num_or_list = Union[number, List[number]]
def run_data_prep():
    """
    Build a DataHandlerNeighbors from dummy handlers and request its data without upsampling.

    A stand-alone DummyDataHandler is created first and its X / Y accessors are called; the
    returned values are not used. Nothing is returned.
    """
    from .data_handler_neighbors import DataHandlerNeighbors

    # Stand-alone dummy handler: only its accessors are exercised, results are discarded.
    standalone = DummyDataHandler("main_class")
    standalone.get_X()
    standalone.get_Y()

    # "testdata" folder located next to this file.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    testdata_dir = os.path.join(module_dir, "testdata")

    # Handlers are created in the same order as before (main first, then the neighbors) so
    # that the random data generation happens in the same sequence.
    center = DummyDataHandler("main_class")
    surrounding = [DummyDataHandler(label) for label in ("neighbor1", "neighbor2")]
    neighbor_handler = DataHandlerNeighbors(center, testdata_dir, neighbors=surrounding,
                                            extreme_values=[1., 1.2])
    neighbor_handler.get_data(upsampling=False)
def create_data_prep():
    """
    Build three DataHandlerNeighbors objects from the stations DEBW011, DEBW013 and DEBW034.

    Every station acts once as the central station, with the two remaining stations passed as
    its neighbors. All stations are loaded from the "testdata" folder next to this file.

    :return: list with three DataHandlerNeighbors, ordered by central station DEBW011, DEBW013, DEBW034
    """
    from .data_handler_neighbors import DataHandlerNeighbors
    # Imported locally so the function also works when this module is imported by other code:
    # at module level the name is only bound when the file is executed as a script, which made
    # any other caller fail with a NameError.
    from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
    station_type = None
    network = 'UBA'
    sampling = 'daily'
    target_dim = 'variables'
    target_var = 'o3'
    interpolation_dim = 'datetime'
    window_history_size = 7
    window_lead_time = 3

    def single_station(station, variable_spec):
        """Create one single-station handler; only station id and variable spec differ per station."""
        # A fresh empty dict is passed as fourth argument for every station (never shared).
        return DataHandlerSingleStation(station, path, variable_spec, {}, station_type, network, sampling,
                                        target_dim, target_var, interpolation_dim, window_history_size,
                                        window_lead_time)

    central_station = single_station("DEBW011", {'o3': 'dma8eu', 'temp': 'maximum'})
    # NOTE(review): this station uses 'temp-rea-miub' where the others use 'temp' - kept as is, confirm intent.
    neighbor1 = single_station("DEBW013", {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'})
    neighbor2 = single_station("DEBW034", {'o3': 'dma8eu', 'temp': 'maximum'})

    return [
        DataHandlerNeighbors(central_station, path, neighbors=[neighbor1, neighbor2]),
        DataHandlerNeighbors(neighbor1, path, neighbors=[central_station, neighbor2]),
        DataHandlerNeighbors(neighbor2, path, neighbors=[neighbor1, central_station]),
    ]
class DummyDataHandler(AbstractDataHandler):
    """Data handler that serves randomly generated inputs and targets, created once on construction."""

    def __init__(self, name, number_of_samples=None):
        """
        Store the name and generate the random X and Y data.

        :param name: label of this handler, returned by ``str()``
        :param number_of_samples: number of samples to generate; if None, a random integer is drawn
            with ``np.random.randint(100, 150)``
        """
        super().__init__()
        self.name = name
        if number_of_samples is None:
            number_of_samples = np.random.randint(100, 150)
        self.number_of_samples = number_of_samples
        # X is generated before Y, so the random number generator is consumed in a fixed order.
        self._X = self.create_X()
        self._Y = self.create_Y()

    def create_X(self):
        """Create the inputs: random integers from 0 to 9 with dims (datetime, window=14, variables=5)."""
        n_samples = self.number_of_samples
        values = np.random.randint(0, 10, size=(n_samples, 14, 5))
        # Hourly time axis starting at today's date, one entry per sample.
        first_day = dt.datetime.today().date()
        timestamps = pd.date_range(first_day, periods=n_samples, freq="H").tolist()
        coordinates = {"datetime": timestamps, "window": range(14), "variables": range(5)}
        return xr.DataArray(values, dims=['datetime', 'window', 'variables'], coords=coordinates)

    def create_Y(self):
        """
        Create the targets: standard normal random numbers scaled by 0.5 and rounded to one decimal,
        with dims (datetime, window=5, variables=1).
        """
        n_samples = self.number_of_samples
        values = np.round(0.5 * np.random.randn(n_samples, 5, 1), 1)
        # Hourly time axis starting at today's date, one entry per sample.
        first_day = dt.datetime.today().date()
        timestamps = pd.date_range(first_day, periods=n_samples, freq="H").tolist()
        coordinates = {"datetime": timestamps, "window": range(5), "variables": range(1)}
        return xr.DataArray(values, dims=['datetime', 'window', 'variables'], coords=coordinates)

    def get_X(self, upsampling=False, as_numpy=False):
        """
        Return the inputs.

        :param upsampling: accepted but ignored
        :param as_numpy: if exactly True, return a numpy copy; otherwise return the stored DataArray itself
        """
        if as_numpy is True:
            return np.copy(self._X)
        return self._X

    def get_Y(self, upsampling=False, as_numpy=False):
        """
        Return the targets.

        :param upsampling: accepted but ignored
        :param as_numpy: if exactly True, return a numpy copy; otherwise return the stored DataArray itself
        """
        if as_numpy is True:
            return np.copy(self._Y)
        return self._Y

    def __str__(self):
        """Return the name given on construction."""
        return self.name
if __name__ == "__main__":
    # Script entry point: build the neighbor handlers, print each of them and access one element
    # of a KerasIterator built on top of them.

    # DataHandlerSingleStation is not used directly in this block; the import is kept because it
    # binds the name at module level.
    from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
    from mlair.data_handler.iterator import KerasIterator, DataCollection

    collection = DataCollection(create_data_prep())
    for handler in collection:
        print(handler)

    # "testdata/keras" folder located next to this file.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    keras_dir = os.path.join(module_dir, "testdata", "keras")
    iterator = KerasIterator(collection, 100, keras_dir, upsampling=True)
    iterator[2]  # index access only; the returned value is not used
Loading