diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index b49e4b90ebff7dc1456f424bbc94f9ada5bc5ad2..d5933f193018efb1529db2c026981e8c4d7936d2 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -7,11 +7,13 @@ import datetime as dt import logging import os from functools import reduce -from typing import Union, List, Iterable, Tuple +from typing import Union, List, Iterable, Tuple, Dict +from src.helpers.join import EmptyQueryResult import numpy as np import pandas as pd import xarray as xr +import dask.array as da from src.configuration import check_path_and_create from src import helpers @@ -26,12 +28,63 @@ data_or_none = Union[xr.DataArray, None] class AbstractStationPrep(): - def __init__(self, path, station, statistics_per_var, **kwargs): + def __init__(self): #, path, station, statistics_per_var, transformation, **kwargs): + pass # passed parameters + # self.path = os.path.abspath(path) + # self.station = helpers.to_list(station) + # self.statistics_per_var = statistics_per_var + # # self.target_dim = 'variable' + # self.transformation = self.setup_transformation(transformation) + # self.kwargs = kwargs + # + # # internal + # self.data = None + # self.meta = None + # self.variables = kwargs.get('variables', list(statistics_per_var.keys())) + # self.history = None + # self.label = None + # self.observation = None + + + def get_X(self): + raise NotImplementedError + + def get_Y(self): + raise NotImplementedError + + # def load_data(self): + # try: + # self.read_data_from_disk() + # except FileNotFoundError: + # self.download_data() + # self.load_data() + # + # def read_data_from_disk(self): + # raise NotImplementedError + # + # def download_data(self): + # raise NotImplementedError + +class StationPrep(AbstractStationPrep): + + def __init__(self, path, station, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var, + interpolate_dim, window_history_size, window_lead_time, **kwargs): + super().__init__() # path, station, statistics_per_var, transformation, **kwargs) + self.station_type = station_type + self.network = network + self.sampling = sampling + self.target_dim = target_dim + self.target_var = target_var + self.interpolate_dim = interpolate_dim + self.window_history_size = window_history_size + self.window_lead_time = window_lead_time + self.path = os.path.abspath(path) self.station = helpers.to_list(station) self.statistics_per_var = statistics_per_var # self.target_dim = 'variable' + self.transformation = self.setup_transformation(transformation) self.kwargs = kwargs # internal @@ -51,12 +104,123 @@ class AbstractStationPrep(): except FileNotFoundError: self.download_data() self.load_data() + self.make_samples() - def read_data_from_disk(self): - raise NotImplementedError + def __repr__(self): + return f"StationPrep(path='{self.path}', station={self.station}, statistics_per_var={self.statistics_per_var}, " \ + f"transformation={self.transformation}, station_type='{self.station_type}', network='{self.network}', " \ + f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \ + f"interpolate_dim='{self.interpolate_dim}', window_history_size={self.window_history_size}, " \ + f"window_lead_time={self.window_lead_time}, **{self.kwargs})" - def download_data(self): - raise NotImplementedError + def get_transposed_history(self) -> xr.DataArray: + """Return history. 
+
+        :return: history with dimensions datetime, window, Stations, variables.
+        """
+        return self.history.transpose("datetime", "window", "Stations", "variables").copy()
+
+    def get_transposed_label(self) -> xr.DataArray:
+        """Return label.
+
+        :return: label with dimensions datetime, window (the Stations dimension is squeezed out).
+        """
+        return self.label.squeeze("Stations").transpose("datetime", "window").copy()
+
+    def get_X(self):
+        return self.get_transposed_history()
+
+    def get_Y(self):
+        return self.get_transposed_label()
+
+    def make_samples(self):
+        self.load_data()
+        self.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim)
+        self.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
+        self.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
+        self.remove_nan(self.interpolate_dim)
+
+    def read_data_from_disk(self, source_name=""):
+        """
+        Load data and meta data either from local disk (preferred) or download new data using a custom download method.
+
+        Data is downloaded if no local data is available or if the parameter overwrite_local_data is true. In both
+        cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
+        set, it is assumed that data should be saved locally.
+        """
+        source_name = source_name if len(source_name) == 0 else f" from {source_name}"
+        check_path_and_create(self.path)
+        file_name = self._set_file_name()
+        meta_file = self._set_meta_file_name()
+        if self.kwargs.get('overwrite_local_data', False):
+            logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}")
+            if os.path.exists(file_name):
+                os.remove(file_name)
+            if os.path.exists(meta_file):
+                os.remove(meta_file)
+            data, self.meta = self.download_data(file_name, meta_file)
+            logging.debug(f"loaded new data{source_name}")
+        else:
+            try:
+                logging.debug(f"try to load local data from: {file_name}")
+                data = xr.open_dataarray(file_name)
+                self.meta = pd.read_csv(meta_file, index_col=0)
+                self.check_station_meta()
+                logging.debug("loading finished")
+            except FileNotFoundError as e:
+                logging.debug(e)
+                logging.debug(f"load new data{source_name}")
+                data, self.meta = self.download_data(file_name, meta_file)
+                logging.debug("loading finished")
+        # create slices and check for negative concentrations
+        data = self._slice_prep(data)
+        self.data = self.check_for_negative_concentrations(data)
+
+    def download_data_from_join(self, file_name: str, meta_file: str) -> Tuple[xr.DataArray, pd.DataFrame]:
+        """
+        Download data from the TOAR database using the JOIN interface.
+
+        Data is transformed into an xarray dataset. If the class attribute store_data_locally is true, data is
+        additionally stored locally using the given names for the data file and the meta file.
+
+        :param file_name: name of file to save data to (containing full path)
+        :param meta_file: name of the meta data file (also containing full path)
+
+        :return: downloaded data and its meta data
+        """
+        df_all = {}
+        df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var,
+                                      station_type=self.station_type, network_name=self.network, sampling=self.sampling)
+        df_all[self.station[0]] = df
+        # convert df_all to xarray
+        xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
+        xarr = xr.Dataset(xarr).to_array(dim='Stations')
+        if self.kwargs.get('store_data_locally', True):
+            # save locally as nc/csv file
+            xarr.to_netcdf(path=file_name)
+            meta.to_csv(meta_file)
+        return xarr, meta
+
+    def download_data(self, file_name, meta_file):
+        data, meta = self.download_data_from_join(file_name, meta_file)
+        return data, meta
+
+    def check_station_meta(self):
+        """
+        Search for the entries in the meta data and compare each value with the requested value.
+
+        Will raise a FileNotFoundError if the values mismatch.
+        """
+        if self.station_type is not None:
+            check_dict = {"station_type": self.station_type, "network_name": self.network}
+            for (k, v) in check_dict.items():
+                if v is None:
+                    continue
+                if self.meta.at[k, self.station[0]] != v:
+                    logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
+                                  f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
+                                  f"grabbing from web.")
+                    raise FileNotFoundError
 
     def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
         """
@@ -284,130 +448,124 @@ class AbstractStationPrep():
         data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
         return data
 
+    def setup_transformation(self, transformation: Dict):
+        """
+        Set up transformation by extracting all relevant information.
 
-class StationPrep(AbstractStationPrep):
+        Extract all information from the transformation dictionary. Possible keys are scope, method, mean, and std.
+        Scope can either be station or data. Station scope means that the data transformation is performed for each
+        station independently (similar to batch normalisation), whereas data scope means a transformation applied on
+        the entire data set.
 
-    def __init__(self, path, station, statistics_per_var, station_type, network, sampling, target_dim, target_var,
-                 interpolate_dim, window_history_size, window_lead_time, **kwargs):
-        super().__init__(path, station, statistics_per_var, **kwargs)
-        self.station_type = station_type
-        self.network = network
-        self.sampling = sampling
-        self.target_dim = target_dim
-        self.target_var = target_var
-        self.interpolate_dim = interpolate_dim
-        self.window_history_size = window_history_size
-        self.window_lead_time = window_lead_time
-        self.make_samples()
+        * If using data scope, mean and standard deviation (each only if required by the transformation method) can
+        either be calculated accurately or as an estimate (faster implementation). This must be set in the dictionary
+        either as "mean": "accurate" or "mean": "estimate". In both cases, the required statistics are calculated and
+        saved. After these calculations, the mean key is overwritten with the actual values to use.
+        * If using station scope, no additional information is required.
+        * If a transformation should be applied on the basis of existing values, these need to be provided in the
+        respective keys "mean" and "std" (again only if required for the given method).
-    def get_transposed_history(self) -> xr.DataArray:
-        """Return history.
+        :param transformation: the transformation dictionary as described above.
 
-        :return: history with dimensions datetime, window, Stations, variables.
+        :return: updated transformation dictionary
         """
-        return self.history.transpose("datetime", "window", "Stations", "variables").copy()
-
-    def get_transposed_label(self) -> xr.DataArray:
-        """Return label.
+        if transformation is None:
+            return
+        transformation = transformation.copy()
+        scope = transformation.get("scope", "station")
+        method = transformation.get("method", "standardise")
+        mean = transformation.get("mean", None)
+        std = transformation.get("std", None)
+        if scope == "data":
+            if isinstance(mean, str):
+                if mean == "accurate":
+                    mean, std = self.calculate_accurate_transformation(method)
+                elif mean == "estimate":
+                    mean, std = self.calculate_estimated_transformation(method)
+                else:
+                    raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' "
+                                     f"or be an array with already calculated means. Given was: {mean}")
+        elif scope == "station":
+            mean, std = None, None
+        else:
+            raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}")
+        transformation["method"] = method
+        transformation["mean"] = mean
+        transformation["std"] = std
+        return transformation
 
-        :return: label with dimensions datetime, window, Stations, variables.
+    def calculate_accurate_transformation(self, method: str) -> Tuple[data_or_none, data_or_none]:
         """
-        return self.label.squeeze("Stations").transpose("datetime", "window").copy()
+        Calculate accurate transformation statistics.
 
-    def get_X(self):
-        return self.get_transposed_history()
-
-    def get_Y(self):
-        return self.get_transposed_label()
+        Use all stations of this generator and calculate mean and standard deviation on the entire data set using
+        dask. Because there can be a large amount of data, this can take a while.
 
-    def make_samples(self):
-        self.load_data()
-        self.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim)
-        self.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
-        self.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
-        self.remove_nan(self.interpolate_dim)
+        :param method: name of transformation method
 
-    def read_data_from_disk(self, source_name=""):
+        :return: accurately calculated mean and std (depending on transformation)
         """
-        Load data and meta data either from local disk (preferred) or download new data by using a custom download method.
-
-        Data is either downloaded, if no local data is available or parameter overwrite_local_data is true. In both
-        cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
-        set, it is assumed, that data should be saved locally.
-        """
-        source_name = source_name if len(source_name) == 0 else f" from {source_name}"
-        check_path_and_create(self.path)
-        file_name = self._set_file_name()
-        meta_file = self._set_meta_file_name()
-        if self.kwargs.get('overwrite_local_data', False):
-            logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}")
-            if os.path.exists(file_name):
-                os.remove(file_name)
-            if os.path.exists(meta_file):
-                os.remove(meta_file)
-            data, self.meta = self.download_data(file_name, meta_file)
-            logging.debug(f"loaded new data{source_name}")
-        else:
+        tmp = []
+        mean = None
+        std = None
+        for station in self.stations:
             try:
-                logging.debug(f"try to load local data from: {file_name}")
-                data = xr.open_dataarray(file_name)
-                self.meta = pd.read_csv(meta_file, index_col=0)
-                self.check_station_meta()
-                logging.debug("loading finished")
-            except FileNotFoundError as e:
-                logging.debug(e)
-                logging.debug(f"load new data{source_name}")
-                data, self.meta = self.download_data(file_name, meta_file)
-                logging.debug("loading finished")
-        # create slices and check for negative concentration.
-        data = self._slice_prep(data)
-        self.data = self.check_for_negative_concentrations(data)
+                data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
+                                     **self.kwargs)
+                chunks = (1, 100, data.data.shape[2])
+                tmp.append(da.from_array(data.data.data, chunks=chunks))
+            except EmptyQueryResult:
+                continue
+        tmp = da.concatenate(tmp, axis=1)
+        if method in ["standardise", "centre"]:
+            mean = da.nanmean(tmp, axis=1).compute()
+            mean = xr.DataArray(mean.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
+            if method == "standardise":
+                std = da.nanstd(tmp, axis=1).compute()
+                std = xr.DataArray(std.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
+        else:
+            raise NotImplementedError
+        return mean, std
 
-    def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]:
+    def calculate_estimated_transformation(self, method):
         """
-        Download data from TOAR database using the JOIN interface.
+        Calculate estimated transformation statistics.
 
-        Data is transformed to a xarray dataset. If class attribute store_data_locally is true, data is additionally
-        stored locally using given names for file and meta file.
+        Use all stations of this generator and calculate mean and standard deviation first for each station
+        separately. Afterwards, calculate the average mean and standard deviation as estimated statistics. Because
+        this method does not consider the length of each data set, the estimated mean differs from the real data
+        mean. Furthermore, the estimated standard deviation is assumed to be the (unweighted) mean of all standard
+        deviations. This is mathematically not correct, but it is still a rough and fast estimate of the true
+        standard deviation. Do not use this method for further statistical calculations. However, in the scope of
+        data preparation for machine learning, this approach is adequate ("it is just scaling").
 
-        :param file_name: name of file to save data to (containing full path)
-        :param meta_file: name of the meta data file (also containing full path)
+        :param method: name of transformation method
 
-        :return: downloaded data and its meta data
+        :return: estimated mean and std (depending on transformation)
         """
-        df_all = {}
-        df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var,
-                                      station_type=self.station_type, network_name=self.network, sampling=self.sampling)
-        df_all[self.station[0]] = df
-        # convert df_all to xarray
-        xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
-        xarr = xr.Dataset(xarr).to_array(dim='Stations')
-        if self.kwargs.get('store_data_locally', True):
-            # save locally as nc/csv file
-            xarr.to_netcdf(path=file_name)
-            meta.to_csv(meta_file)
-        return xarr, meta
-
-    def download_data(self, file_name, meta_file):
-        data, meta = self.download_data_from_join(file_name, meta_file)
-        return data, meta
+        data = [[]] * len(self.variables)
+        coords = {"variables": self.variables, "Stations": range(0)}
+        mean = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
+        std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
+        for station in self.stations:
+            try:
+                data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
+                                     **self.kwargs)
+                data.transform("datetime", method=method)
+                mean = mean.combine_first(data.mean)
+                std = std.combine_first(data.std)
+                data.transform("datetime", method=method, inverse=True)
+            except EmptyQueryResult:
+                continue
+        return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None
 
-    def check_station_meta(self):
-        """
-        Search for the entries in meta data and compare the value with the requested values.
+    def load_data(self):
+        try:
+            self.read_data_from_disk()
+        except FileNotFoundError:
+            # data is missing locally: download and store it, then retry loading from disk
+            self.download_data(self._set_file_name(), self._set_meta_file_name())
+            self.load_data()
 
-        Will raise a FileNotFoundError if the values mismatch.
-        """
-        if self.station_type is not None:
-            check_dict = {"station_type": self.station_type, "network_name": self.network}
-            for (k, v) in check_dict.items():
-                if v is None:
-                    continue
-                if self.meta.at[k, self.station[0]] != v:
-                    logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
-                                  f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
-                                  f"grapping from web.")
-                    raise FileNotFoundError
 
 class AbstractDataPrep(object):
     """
@@ -942,9 +1100,9 @@ if __name__ == "__main__":
     # print(dp)
     statistics_per_var = {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}
    sp = StationPrep(path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122',
-                     statistics_per_var=statistics_per_var, station_type='background',
+                     statistics_per_var=statistics_per_var, transformation={}, station_type='background',
                      network='UBA', sampling='daily', target_dim='variables', target_var='o3',
                      interpolate_dim='datetime', window_history_size=7, window_lead_time=3)
-    sp.load_data()
-    sp.download_data('newfile.nc', 'new_meta.csv')
+    sp.get_X()
+    sp.get_Y()
    print(sp)
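
For reviewers, a minimal usage sketch of the refactored StationPrep API introduced by this patch. The data path is a placeholder, and the transformation dictionary follows the scope/method/mean/std keys described in the setup_transformation docstring; station and statistics values are taken from the __main__ example above.

# usage sketch, not part of the patch; path is a hypothetical placeholder
from src.data_handling.data_preparation import StationPrep

statistics_per_var = {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}
# station-scope standardisation: no precomputed mean/std required
transformation = {'scope': 'station', 'method': 'standardise'}

sp = StationPrep(path='/path/to/data/', station='DEBY122',
                 statistics_per_var=statistics_per_var,
                 transformation=transformation, station_type='background',
                 network='UBA', sampling='daily', target_dim='variables',
                 target_var='o3', interpolate_dim='datetime',
                 window_history_size=7, window_lead_time=3)
sp.make_samples()  # load data, then build history windows, labels, observations
X = sp.get_X()     # xr.DataArray with dims (datetime, window, Stations, variables)
Y = sp.get_Y()     # xr.DataArray with dims (datetime, window)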