# data_generator.py
# Module authorship metadata.
__author__ = 'Felix Kleinert, Lukas Leufen'
__date__ = '2019-11-07'
import os
from typing import Union, List, Tuple, Any, Dict
import keras
import xarray as xr
import pickle
import logging
from src import helpers
from src.data_handling.data_preparation import DataPrep
from src.join import EmptyQueryResult
class DataGenerator(keras.utils.Sequence):
    """
    This class is a generator to handle large arrays for machine learning. This class can be used with keras'
    fit_generator and predict_generator. Individual stations are the iterables. This class uses class DataPrep and
    returns X, y when an item is called.

    Item can be called manually by position (integer) or station id (string). Methods also accept lists with exactly
    one entry of integer or string.
    """

    def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
                 interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
                 interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
                 window_lead_time: int = 4, transformation: Dict = None, **kwargs):
        """
        Set up the generator: store all preprocessing parameters, create the local tmp folder for pickled
        station data, and resolve the transformation setup (may trigger a full pass over all stations).

        :param data_path: path to the raw station data
        :param network: measurement network the stations belong to
        :param stations: one station id or a list of station ids
        :param variables: variables to load for each station
        :param interpolate_dim: dimension along which to interpolate (e.g. "datetime")
        :param target_dim: dimension holding the target variable
        :param target_var: name of the target variable
        :param station_type: optional station type filter, forwarded to DataPrep
        :param interpolate_method: interpolation method, forwarded to DataPrep.interpolate
        :param limit_nan_fill: maximum gap length to fill during interpolation
        :param window_history_size: number of history time steps per sample
        :param window_lead_time: number of lead time steps per label
        :param transformation: optional dict with keys "scope", "method", "mean", "std"; see setup_transformation
        :param kwargs: additional arguments forwarded to DataPrep (e.g. start/end dates)
        """
        self.data_path = os.path.abspath(data_path)
        self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
        if not os.path.exists(self.data_path_tmp):
            os.makedirs(self.data_path_tmp)
        self.network = network
        self.stations = helpers.to_list(stations)
        self.variables = variables
        self.interpolate_dim = interpolate_dim
        self.target_dim = target_dim
        self.target_var = target_var
        self.station_type = station_type
        self.interpolate_method = interpolate_method
        self.limit_nan_fill = limit_nan_fill
        self.window_history_size = window_history_size
        self.window_lead_time = window_lead_time
        self.kwargs = kwargs
        # must run last: may iterate over all stations and uses the attributes set above
        self.transformation = self.setup_transformation(transformation)

    def __repr__(self):
        """
        Display all relevant class attributes.
        """
        return f"DataGenerator(path='{self.data_path}', network='{self.network}', stations={self.stations}, " \
               f"variables={self.variables}, station_type={self.station_type}, " \
               f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \
               f"target_var='{self.target_var}', **{self.kwargs})"

    def __len__(self):
        """
        Return the number of stations.
        """
        return len(self.stations)

    def __iter__(self) -> "DataGenerator":
        """
        Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private
        attribute `_iterator` to 0.

        :return: this generator, ready for iteration
        """
        self._iterator = 0
        return self

    def __next__(self) -> Tuple[xr.DataArray, xr.DataArray]:
        """
        Implementation of the __next__ method of the iterator protocol. Get the data generator, and return
        the history and label data of this generator. Stations without history/label data are skipped.

        :return: tuple of history and label data arrays
        """
        if self._iterator < self.__len__():
            data = self.get_data_generator()
            self._iterator += 1
            if data.history is not None and data.label is not None:  # pragma: no branch
                return data.history.transpose("datetime", "window", "Stations", "variables"), \
                    data.label.squeeze("Stations").transpose("datetime", "window")
            else:
                # bug fix: the recursive result must be returned, otherwise None is yielded for skipped stations
                return self.__next__()  # pragma: no cover
        else:
            raise StopIteration

    def __getitem__(self, item: Union[str, int]) -> Tuple[xr.DataArray, xr.DataArray]:
        """
        Defines the get item method for this generator. Retrieve data from generator and return history and labels.

        :param item: station key to choose the data generator.
        :return: The generator's time series of history data and its labels
        """
        data = self.get_data_generator(key=item)
        return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window")

    def setup_transformation(self, transformation: Dict) -> Union[Dict, None]:
        """
        Resolve the transformation setup. For scope "data", the placeholder mean values "accurate" and
        "estimate" are replaced by the corresponding computed statistics; otherwise the given values are kept.

        :param transformation: dict with optional keys "scope", "method", "mean", "std" or None
        :return: the (mutated) transformation dict with resolved "mean"/"std", or None if no dict was given
        """
        if transformation is None:
            return None
        scope = transformation.get("scope", "station")
        method = transformation.get("method", "standardise")
        mean = transformation.get("mean", None)
        std = transformation.get("std", None)
        if scope == "data":
            if mean == "accurate":
                mean, std = self.calculate_accurate_transformation(method)
            elif mean == "estimate":
                mean, std = self.calculate_estimated_transformation(method)
            # any other value of mean is used as-is
        transformation["mean"] = mean
        transformation["std"] = std
        return transformation

    def calculate_accurate_transformation(self, method: str) -> Tuple[Any, Any]:
        """
        Placeholder for an exact transformation calculation over all data (not implemented yet).

        :param method: transformation method (unused so far)
        :return: (None, None)
        """
        mean = None
        std = None
        return mean, std

    def calculate_estimated_transformation(self, method: str) -> Tuple[xr.DataArray, Union[xr.DataArray, None]]:
        """
        Estimate the transformation statistics by averaging the per-station mean and std over all stations.
        Stations that raise an EmptyQueryResult are skipped. Each station's data is transformed and inversely
        transformed again, so the underlying data is left unchanged.

        :param method: transformation method forwarded to DataPrep.transform
        :return: mean and std averaged over all stations, or None for each statistic that could not be computed
        """
        coords = {"variables": self.variables, "Stations": range(0)}
        mean = xr.DataArray([[]] * len(self.variables), coords=coords, dims=["variables", "Stations"])
        std = xr.DataArray([[]] * len(self.variables), coords=coords, dims=["variables", "Stations"])
        for station in self.stations:
            try:
                data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
                                **self.kwargs)
                data.transform("datetime", method=method)
                mean = mean.combine_first(data.mean)
                std = std.combine_first(data.std)
                data.transform("datetime", method=method, inverse=True)
            except EmptyQueryResult:
                continue
        # bug fix: previously returned the debug string "hi" instead of None when no station contributed
        mean = mean.mean("Stations") if mean.shape[1] > 0 else None
        std = std.mean("Stations") if std.shape[1] > 0 else None
        return mean, std

    def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
        """
        Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
        remove nans.

        :param key: station key to choose the data generator.
        :param local_tmp_storage: say if data should be processed from scratch or loaded as already processed data
            from tmp pickle file to save computational time (but of course more disk space required).
        :return: preprocessed data as a DataPrep instance
        """
        station = self.get_station_key(key)
        try:
            if not local_tmp_storage:
                raise FileNotFoundError
            data = self._load_pickle_data(station, self.variables)
        except FileNotFoundError:
            logging.info(f"could not load pickle data for {station}, preprocess from scratch")
            data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
                            **self.kwargs)
            data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
            # guard: transformation is optional (default None) and must not be applied in that case
            if self.transformation is not None:
                data.transform("datetime", **helpers.dict_pop(self.transformation, "scope"))
            data.make_history_window(self.interpolate_dim, self.window_history_size)
            data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
            data.history_label_nan_remove(self.interpolate_dim)
            self._save_pickle_data(data)
        return data

    def _save_pickle_data(self, data: Any):
        """
        Save given data locally as .pickle in self.data_path_tmp with name '<station>_<var1>_<var2>_..._<varX>.pickle'

        :param data: any data, that should be saved
        """
        date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}"
        variables_tag = '_'.join(sorted(data.variables))  # renamed: avoid shadowing the builtin `vars`
        station = ''.join(data.station)
        file = os.path.join(self.data_path_tmp, f"{station}_{variables_tag}_{date}_.pickle")
        with open(file, "wb") as f:
            pickle.dump(data, f)
        logging.debug(f"save pickle data to {file}")

    def _load_pickle_data(self, station: Union[str, List[str]], variables: List[str]) -> Any:
        """
        Load locally saved data from self.data_path_tmp and name '<station>_<var1>_<var2>_..._<varX>.pickle'.

        :param station: station to load
        :param variables: list of variables to load
        :return: loaded data
        """
        date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}"
        variables_tag = '_'.join(sorted(variables))  # renamed: avoid shadowing the builtin `vars`
        station = ''.join(station)
        file = os.path.join(self.data_path_tmp, f"{station}_{variables_tag}_{date}_.pickle")
        with open(file, "rb") as f:
            data = pickle.load(f)
        logging.debug(f"load pickle data from {file}")
        return data

    def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str:
        """
        Return a valid station key or raise KeyError if this wasn't possible.

        :param key: station key to choose the data generator.
        :return: station key (id from database)
        """
        # extract value if given as list
        if isinstance(key, list):
            if len(key) == 1:
                key = key[0]
            else:
                raise KeyError(f"More than one key was given: {key}")
        # return station name either from key or the recent element from iterator
        if key is None:
            return self.stations[self._iterator]
        else:
            if isinstance(key, int):
                if key < self.__len__():
                    return self.stations[key]
                else:
                    raise KeyError(f"{key} is not in range(0, {self.__len__()})")
            elif isinstance(key, str):
                if key in self.stations:
                    return key
                else:
                    raise KeyError(f"{key} is not in stations")
            else:
                raise KeyError(f"Key has to be from Union[str, int]. Given was {key} ({type(key)})")