# data_generator.py
__author__ = 'Felix Kleinert, Lukas Leufen'
__date__ = '2019-11-07'
import os
from typing import Union, List, Tuple
import keras
import xarray as xr
from src import helpers
from src.data_handling.data_preparation import DataPrep
class DataGenerator(keras.utils.Sequence):
    """
    Generator that serves per-station data for machine learning.

    This class can be used with keras' fit_generator and predict_generator. Individual stations are the iterables:
    for every station a DataPrep object is created and preprocessed, and its history (X) and label (y) arrays are
    returned when an item is requested.

    Items can be requested manually by position (integer) or station id (string). Methods also accept lists with
    exactly one entry of integer or string.
    """

    def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
                 interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
                 interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
                 window_lead_time: int = 4, transform_method: str = "standardise", **kwargs):
        """
        Store the configuration used to build a DataPrep object for each station.

        :param data_path: path to the data; stored as absolute path
        :param network: network name handed to DataPrep
        :param stations: single station id or list of station ids; always stored as list
        :param variables: variables handed to DataPrep
        :param interpolate_dim: dimension used for interpolation, history window, labels and nan removal
        :param target_dim: dimension handed to DataPrep.make_labels
        :param target_var: target variable handed to DataPrep.make_labels
        :param station_type: station type handed to DataPrep
        :param interpolate_method: method handed to DataPrep.interpolate
        :param limit_nan_fill: limit handed to DataPrep.interpolate
        :param window_history_size: window size handed to DataPrep.make_history_window
        :param window_lead_time: window size handed to DataPrep.make_labels
        :param transform_method: method handed to DataPrep.transform
        :param kwargs: additional keyword arguments forwarded unchanged to DataPrep
        """
        self.data_path = os.path.abspath(data_path)
        self.network = network
        self.stations = helpers.to_list(stations)
        self.variables = variables
        self.interpolate_dim = interpolate_dim
        self.target_dim = target_dim
        self.target_var = target_var
        self.station_type = station_type
        self.interpolate_method = interpolate_method
        self.limit_nan_fill = limit_nan_fill
        self.window_history_size = window_history_size
        self.window_lead_time = window_lead_time
        self.transform_method = transform_method
        self.kwargs = kwargs
        # Start at the first station so that next() and key-less lookups work even if iter() was never called.
        self._iterator = 0

    def __repr__(self):
        """
        display all class attributes
        """
        return f"DataGenerator(path='{self.data_path}', network='{self.network}', stations={self.stations}, " \
               f"variables={self.variables}, station_type={self.station_type}, " \
               f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \
               f"target_var='{self.target_var}', **{self.kwargs})"

    def __len__(self):
        """
        display the number of stations
        """
        return len(self.stations)

    def __iter__(self) -> "DataGenerator":
        """
        Define the __iter__ part of the iterator protocol to iterate through this generator. Resets the private
        attribute `_iterator` to 0.

        :return: this generator itself
        """
        self._iterator = 0
        return self

    def __next__(self) -> Tuple[xr.DataArray, xr.DataArray]:
        """
        Implementation of the __next__ method of the iterator protocol. Advance to the next station that provides
        both history and label data and return them. Stations without history or label are skipped.

        :return: history and label data of the next station with valid data
        :raises StopIteration: if no station with valid data is left
        """
        # Loop instead of recursing: the former recursive call dropped its return value, so a station without data
        # made this method return None instead of the next valid station's data.
        while self._iterator < len(self):
            data = self.get_data_generator()
            self._iterator += 1
            if data.history is not None and data.label is not None:
                return self._history_and_label(data)
        raise StopIteration

    def __getitem__(self, item: Union[str, int]) -> Tuple[xr.DataArray, xr.DataArray]:
        """
        Defines the get item method for this generator. Retrieve data from generator and return history and labels.

        :param item: station key to choose the data generator.
        :return: The generator's time series of history data and its labels
        """
        return self._history_and_label(self.get_data_generator(key=item))

    @staticmethod
    def _history_and_label(data: DataPrep) -> Tuple[xr.DataArray, xr.DataArray]:
        """
        Bring history and label of a DataPrep object into the dimension order returned by this generator.

        :param data: preprocessed DataPrep object with history and label set
        :return: history ordered as (datetime, window, Stations, variables) and label with the Stations dimension
            squeezed out, ordered as (datetime, window)
        """
        history = data.history.transpose("datetime", "window", "Stations", "variables")
        label = data.label.squeeze("Stations").transpose("datetime", "window")
        return history, label

    def get_data_generator(self, key: Union[str, int] = None) -> DataPrep:
        """
        Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
        remove nans.

        :param key: station key to choose the data generator. If None, the station at the current iterator position
            is used.
        :return: preprocessed data as a DataPrep instance
        """
        station = self.get_station_key(key)
        data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
                        **self.kwargs)
        data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
        data.transform("datetime", method=self.transform_method)
        data.make_history_window(self.interpolate_dim, self.window_history_size)
        data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
        data.history_label_nan_remove(self.interpolate_dim)
        return data

    def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str:
        """
        Return a valid station key or raise KeyError if this wasn't possible

        :param key: station key to choose the data generator. Either None (use the current iterator position), a
            position (int), a station id (str) or a list with exactly one of these.
        :return: station key (id from database)
        :raises KeyError: if a list without exactly one entry is given, if an integer key is not smaller than the
            number of stations, if a string key is not a known station, or if the key has an unsupported type
        """
        # extract value if given as list
        if isinstance(key, list):
            if len(key) != 1:
                raise KeyError(f"More than one key was given: {key}")
            key = key[0]
        # return station name either from key or the recent element from iterator
        if key is None:
            return self.stations[self._iterator]
        if isinstance(key, int):
            if key < len(self):
                return self.stations[key]
            raise KeyError(f"{key} is not in range(0, {self.__len__()})")
        if isinstance(key, str):
            if key in self.stations:
                return key
            raise KeyError(f"{key} is not in stations")
        raise KeyError(f"Key has to be from Union[str, int]. Given was {key} ({type(key)})")