Skip to content
Snippets Groups Projects
Select Git revision
  • 9ec5d18db36105ac380fc08b507fed2ed91a7f5f
  • 2023 default protected
2 results

mpi4py-3.0.3-ipsmpi-2020-Python-3.8.5.eb

Blame
  • data_generator.py NaN GiB
    __author__ = 'Felix Kleinert, Lukas Leufen'
    __date__ = '2019-11-07'
    
    import os
    from typing import Union, List, Tuple, Any, Dict
    
    import keras
    import xarray as xr
    import pickle
    import logging
    
    from src import helpers
    from src.data_handling.data_preparation import DataPrep
    from src.join import EmptyQueryResult
    
    
    class DataGenerator(keras.utils.Sequence):
        """
        This class is a generator to handle large arrays for machine learning. This class can be used with keras'
        fit_generator and predict_generator. Individual stations are the iterables. This class uses class DataPrep and
        returns X, y when an item is called.
        Item can be called manually by position (integer) or station id (string). Methods also accept lists with exactly
        one entry of integer or string.
        """

        def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
                     interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
                     interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
                     window_lead_time: int = 4, transformation: Dict = None, **kwargs):
            """
            Store the data settings, create the local tmp directory for pickled data and set up the transformation.

            :param data_path: root data directory; preprocessed data is pickled into its "tmp" subdirectory
            :param network: network name, passed on to DataPrep
            :param stations: a single station id or a list of station ids; each station is one item of this generator
            :param variables: variables to load for each station, passed on to DataPrep
            :param interpolate_dim: dimension used for interpolation, history windows, labels and nan removal
            :param target_dim: dimension holding the target variable, passed to DataPrep.make_labels
            :param target_var: name of the target variable, passed to DataPrep.make_labels
            :param station_type: optional station type, passed on to DataPrep
            :param interpolate_method: method passed to DataPrep.interpolate
            :param limit_nan_fill: passed to DataPrep.interpolate as `limit`
            :param window_history_size: passed to DataPrep.make_history_window
            :param window_lead_time: passed to DataPrep.make_labels
            :param transformation: optional transformation settings, completed by setup_transformation
            :param kwargs: further keyword arguments passed on to DataPrep; 'start' and 'end' are also used to name the
                tmp pickle files
            """
            self.data_path = os.path.abspath(data_path)
            self.data_path_tmp = os.path.join(self.data_path, "tmp")
            # exist_ok avoids the race between checking for the directory and creating it
            os.makedirs(self.data_path_tmp, exist_ok=True)
            self.network = network
            self.stations = helpers.to_list(stations)
            self.variables = variables
            self.interpolate_dim = interpolate_dim
            self.target_dim = target_dim
            self.target_var = target_var
            self.station_type = station_type
            self.interpolate_method = interpolate_method
            self.limit_nan_fill = limit_nan_fill
            self.window_history_size = window_history_size
            self.window_lead_time = window_lead_time
            self.kwargs = kwargs
            # iterator position; initialised here so that get_data_generator() without a key also works before __iter__
            self._iterator = 0
            # set up last: estimating the transformation needs stations, variables and kwargs to be set already
            self.transformation = self.setup_transformation(transformation)

        def __repr__(self):
            """
            Return a string representation showing the main settings of this generator (path, network, stations,
            variables, station type, interpolation and target settings, and the extra kwargs).
            """
            return f"DataGenerator(path='{self.data_path}', network='{self.network}', stations={self.stations}, " \
                   f"variables={self.variables}, station_type={self.station_type}, " \
                   f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \
                   f"target_var='{self.target_var}', **{self.kwargs})"

        def __len__(self):
            """
            Return the number of stations handled by this generator.
            """
            return len(self.stations)

        def __iter__(self) -> "DataGenerator":
            """
            Define the __iter__ part of the iterator protocol to iterate through this generator. Resets the private
            attribute `_iterator` to 0.

            :return: this generator itself
            """
            self._iterator = 0
            return self

        def __next__(self) -> Tuple[xr.DataArray, xr.DataArray]:
            """
            Implement the __next__ method of the iterator protocol. Load the data of the station at the current iterator
            position and return its history and label data. Stations whose history or label is None are skipped.

            :return: history transposed to (datetime, window, Stations, variables) and label squeezed along Stations and
                transposed to (datetime, window)
            :raises StopIteration: when all stations have been visited
            """
            # loop (rather than recurse) so that skipping a station still returns the next valid station's data
            while self._iterator < len(self):
                data = self.get_data_generator()
                self._iterator += 1
                if data.history is not None and data.label is not None:  # pragma: no branch
                    return data.history.transpose("datetime", "window", "Stations", "variables"), \
                        data.label.squeeze("Stations").transpose("datetime", "window")
            raise StopIteration

        def __getitem__(self, item: Union[str, int]) -> Tuple[xr.DataArray, xr.DataArray]:
            """
            Define the get item method for this generator. Retrieve data from the generator and return history and labels.

            :param item: station key to choose the data generator (position, station id, or a list with one of these).
            :return: The generator's time series of history data and its labels
            """
            data = self.get_data_generator(key=item)
            return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window")

        def setup_transformation(self, transformation: Union[Dict, None]) -> Union[Dict, None]:
            """
            Complete the transformation settings.

            If scope is "data" and mean is "accurate" or "estimate", mean and std are calculated over all stations of
            this generator. Otherwise the given mean and std (None if missing) are kept as they are.

            :param transformation: transformation settings with optional keys "scope" (default "station"), "method"
                (default "standardise"), "mean" and "std"; None means no transformation
            :return: a new dict with keys "mean" and "std" set, or None if transformation is None. The given dict is not
                modified.
            """
            if transformation is None:
                return None
            # work on a shallow copy so the caller's dict is not modified in place
            transformation = dict(transformation)
            scope = transformation.get("scope", "station")
            method = transformation.get("method", "standardise")
            mean = transformation.get("mean", None)
            std = transformation.get("std", None)
            if scope == "data":
                if mean == "accurate":
                    mean, std = self.calculate_accurate_transformation(method)
                elif mean == "estimate":
                    mean, std = self.calculate_estimated_transformation(method)
            transformation["mean"] = mean
            transformation["std"] = std
            return transformation

        def calculate_accurate_transformation(self, method):
            """
            Placeholder for an accurate (pooled) calculation of transformation statistics. Not implemented yet.

            :param method: transformation method (currently unused)
            :return: tuple (None, None)
            """
            mean = None
            std = None
            return mean, std

        def calculate_estimated_transformation(self, method):
            """
            Estimate transformation statistics by averaging the per-station statistics over all stations.

            For each station, a DataPrep object is created and transformed along "datetime" with the given method, so that
            its mean and std are available. The data is then transformed back. Stations raising EmptyQueryResult are
            skipped.

            :param method: transformation method passed to DataPrep.transform
            :return: tuple (mean, std), each the average over the "Stations" dimension, or None if no station
                delivered data
            """
            mean = xr.DataArray([[]] * len(self.variables), coords={"variables": self.variables, "Stations": range(0)},
                                dims=["variables", "Stations"])
            std = xr.DataArray([[]] * len(self.variables), coords={"variables": self.variables, "Stations": range(0)},
                               dims=["variables", "Stations"])
            for station in self.stations:
                try:
                    data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
                                    **self.kwargs)
                    data.transform("datetime", method=method)
                    mean = mean.combine_first(data.mean)
                    std = std.combine_first(data.std)
                    data.transform("datetime", method=method, inverse=True)
                except EmptyQueryResult:
                    continue
            # look the dimension up by name: combine_first is not guaranteed to keep the positional dim order
            mean = mean.mean("Stations") if mean.sizes["Stations"] > 0 else None
            std = std.mean("Stations") if std.sizes["Stations"] > 0 else None
            return mean, std

        def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
            """
            Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
            remove nans.

            :param key: station key to choose the data generator; None uses the station at the current iterator position.
            :param local_tmp_storage: if True, try to load already processed data from the tmp pickle file first to save
                computational time (at the cost of disk space). If False, always process from scratch. Freshly
                processed data is pickled in both cases.
            :return: preprocessed data as a DataPrep instance
            """
            station = self.get_station_key(key)
            if local_tmp_storage:
                try:
                    return self._load_pickle_data(station, self.variables)
                except FileNotFoundError:
                    pass  # no cached data yet, process from scratch below
            logging.info(f"load not pickle data for {station}")
            data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
                            **self.kwargs)
            data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
            # NOTE(review): if self.transformation is None, this relies on helpers.dict_pop accepting None - confirm
            data.transform("datetime", **helpers.dict_pop(self.transformation, "scope"))
            data.make_history_window(self.interpolate_dim, self.window_history_size)
            data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
            data.history_label_nan_remove(self.interpolate_dim)
            self._save_pickle_data(data)
            return data

        def _pickle_file_path(self, station: Union[str, List[str]], variables: List[str]) -> str:
            """
            Build the tmp pickle path '<data_path_tmp>/<station>_<var1>_..._<varX>_<start>_<end>_.pickle'.

            Variables are sorted so that the name does not depend on their order. 'start' and 'end' are taken from
            self.kwargs and are "None" if not given.

            :param station: station id, or a list of station ids whose entries are concatenated
            :param variables: list of variables
            :return: absolute file path of the pickle file
            """
            date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}"
            joined_variables = '_'.join(sorted(variables))
            station = ''.join(station)
            return os.path.join(self.data_path_tmp, f"{station}_{joined_variables}_{date}_.pickle")

        def _save_pickle_data(self, data: Any):
            """
            Save given data locally as a pickle file in self.data_path_tmp, named
            '<station>_<var1>_<var2>_..._<varX>_<start>_<end>_.pickle'.

            :param data: data object providing `station` and `variables` attributes (e.g. a DataPrep instance)
            """
            file = self._pickle_file_path(data.station, data.variables)
            with open(file, "wb") as f:
                pickle.dump(data, f)
            logging.debug(f"save pickle data to {file}")

        def _load_pickle_data(self, station: Union[str, List[str]], variables: List[str]) -> Any:
            """
            Load locally saved data from self.data_path_tmp, named '<station>_<var1>_<var2>_..._<varX>_<start>_<end>_.pickle'.

            Only files written by _save_pickle_data are expected here; unpickling executes code, so the tmp directory
            must not contain untrusted files.

            :param station: station to load
            :param variables: list of variables to load
            :return: loaded data
            :raises FileNotFoundError: if no pickle file exists for this station, variables and period
            """
            file = self._pickle_file_path(station, variables)
            with open(file, "rb") as f:
                data = pickle.load(f)
            logging.debug(f"load pickle data from {file}")
            return data

        def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str:
            """
            Return a valid station key or raise KeyError if this wasn't possible.

            :param key: station key to choose the data generator: None (station at the current iterator position), a
                position in range(0, len(self)), a station id, or a list containing exactly one of these.
            :return: station key (id from database)
            :raises KeyError: for lists not of length one, out-of-range positions, unknown station ids or other key types
            """
            # extract value if given as list
            if isinstance(key, list):
                if len(key) == 1:
                    key = key[0]
                else:
                    raise KeyError(f"More than one key was given: {key}")
            # no key given: return the station at the current iterator position
            if key is None:
                return self.stations[self._iterator]
            if isinstance(key, int):
                # reject negative positions explicitly; plain list indexing would silently wrap around
                if 0 <= key < len(self):
                    return self.stations[key]
                raise KeyError(f"{key} is not in range(0, {self.__len__()})")
            if isinstance(key, str):
                if key in self.stations:
                    return key
                raise KeyError(f"{key} is not in stations")
            raise KeyError(f"Key has to be from Union[str, int]. Given was {key} ({type(key)})")