advanced_data_handling.py
    
    __author__ = 'Lukas Leufen'
    __date__ = '2020-07-08'
    
    
    from src.helpers import to_list, remove_items
    import numpy as np
    import xarray as xr
    import pickle
    import os
    import pandas as pd
    import datetime as dt
    import shutil
    
    from typing import Union, List, Tuple
    import logging
    from functools import reduce
    
    number = Union[float, int]
    num_or_list = Union[number, List[number]]
    
    
    class DummyDataSingleStation:  # pragma: no cover
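        """Generate random hourly dummy data (inputs X and labels Y) for a single station; for testing/demo only."""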
    
        def __init__(self, name, number_of_samples=None):
            self.name = name
            self.number_of_samples = number_of_samples if number_of_samples is not None else np.random.randint(100, 150)
    
        def get_X(self):
            X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))  # samples, window, variables
            datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist()
            return xr.DataArray(X1, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist,
                                                                                      "window": range(14),
                                                                                      "variables": range(5)})
    
        def get_Y(self):
            Y1 = np.round(0.5 * np.random.randn(self.number_of_samples, 5, 1), 1)  # samples, window, variables
            datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist()
            return xr.DataArray(Y1, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist,
                                                                                      "window": range(5),
                                                                                      "variables": range(1)})
    
    
    class DataPreparation:
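        """
        Prepare X/Y data for a station (id) and its neighbours: build the collection from the given data class,
        harmonise the inputs along interpolate_dim, duplicate extreme samples (multiply_extremes), and pickle the
        result to store_path for later lazy loading.
        """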
    
        def __init__(self, id, data_class, interpolate_dim: str, store_path, neighbor_ids=None, min_length=0,
                     extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False):
            self.id = id
            self.neighbor_ids = sorted(to_list(neighbor_ids)) if neighbor_ids is not None else []
            self.interpolate_dim = interpolate_dim
            self.min_length = min_length
            self._X = None
            self._Y = None
            self._X_extreme = None
            self._Y_extreme = None
            self._path = os.path.join(store_path, f"data_preparation_{self.id}.pickle")
            self._collection = []
            self._create_collection(data_class)
            self.harmonise_X()
            self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolate_dim)
            self._store(fresh_store=True)
    
        def _reset_data(self):
            self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
    
        def _cleanup(self):
            directory = os.path.dirname(self._path)
            os.makedirs(directory, exist_ok=True)
            if os.path.exists(self._path):
                os.remove(self._path)  # self._path is a pickle file, not a directory, so remove it with os.remove
    
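        # X/Y are pickled to disk and dropped from memory after every access (_reset_data), so many
        # DataPreparation instances can coexist without keeping all stations' arrays in RAM at once.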
        def _store(self, fresh_store=False):
            if fresh_store:
                self._cleanup()
            data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
            with open(self._path, "wb") as f:
                pickle.dump(data, f)
            logging.debug(f"save pickle data to {self._path}")
            self._reset_data()
    
        def _load(self):
            try:
                with open(self._path, "rb") as f:
                    data = pickle.load(f)
                logging.debug(f"load pickle data from {self._path}")
                self._X, self._Y = data["X"], data["Y"]
                self._X_extreme, self._Y_extreme = data["X_extreme"], data["Y_extreme"]
            except FileNotFoundError:
                pass
    
        def get_data(self, upsampling=False):
            self._load()
            X = self.get_X(upsampling)
            Y = self.get_Y(upsampling)
            self._reset_data()
            return X, Y
    
        def _create_collection(self, data_class, **kwargs):
            for name in [self.id] + self.neighbor_ids:  # use the station's own id, not the builtin id()
                data = data_class(name, **kwargs)
                self._collection.append(data)
    
        def get_X_original(self):
            X = []
            for data in self._collection:
                X.append(data.get_X())
            return X
    
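        # Labels come from the first entry of the collection, i.e. the station itself;
        # neighbours contribute inputs only.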
        def get_Y_original(self):
            Y = self._collection[0].get_Y()
            return Y
    
        @staticmethod
        def _to_numpy(d):
            return list(map(lambda x: np.copy(x), d))
    
        def get_X(self, upsampling=False):
            no_data = (self._X is None)
            if no_data:
                self._load()
            X = self._X if upsampling is False else self._X_extreme
            if no_data:
                self._reset_data()
            return self._to_numpy(X)

        def get_Y(self, upsampling=False):
            no_data = (self._Y is None)
            if no_data:
                self._load()
            Y = self._Y if upsampling is False else self._Y_extreme
            if no_data:
                self._reset_data()
            return self._to_numpy([Y])
    
        def harmonise_X(self):
            X_original, Y_original = self.get_X_original(), self.get_Y_original()
            dim = self.interpolate_dim
            intersect = reduce(np.intersect1d, map(lambda x: x.coords[dim].values, X_original))
            if len(intersect) < max(self.min_length, 1):
                X, Y = None, None
            else:
                X = list(map(lambda x: x.sel({dim: intersect}), X_original))
                Y = Y_original.sel({dim: intersect})
            self._X, self._Y = X, Y
    
        def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
                              timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"):
            """
            Multiply extremes.
    
            This method extracts extreme values from self.labels which are defined in the argument extreme_values. One can
            also decide only to extract extremes on the right tail of the distribution. When extreme_values is a list of
            floats/ints all values larger (and smaller than negative extreme_values; extraction is performed in standardised
            space) than are extracted iteratively. If for example extreme_values = [1.,2.] then a value of 1.5 would be
            extracted once (for 0th entry in list), while a 2.5 would be extracted twice (once for each entry). Timedelta is
            used to mark those extracted values by adding one min to each timestamp. As TOAR Data are hourly one can
            identify those "artificial" data points later easily. Extreme inputs and labels are stored in
            self.extremes_history and self.extreme_labels, respectively.
    
            :param extreme_values: user definition of extreme
            :param extremes_on_right_tail_only: if False also multiply values which are smaller then -extreme_values,
                if True only extract values larger than extreme_values
            :param timedelta: used as arguments for np.timedelta in order to mark extreme values on datetime
            """
            # check if X or Y is None
            if (self._X is None) or (self._Y is None):
                logging.debug(f"{self.id} has no data for X or Y, skip multiply extremes")
                return
    
            # check type of inputs
            extreme_values = to_list(extreme_values)
            for i in extreme_values:
                if not isinstance(i, number.__args__):
                    raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
                                    f"{i} is type {type(i)}")
    
            for extr_val in sorted(extreme_values):
                # check if some extreme values are already extracted
                if (self._X_extreme is None) or (self._Y_extreme is None):
                    X = self._X
                    Y = self._Y
                else:  # at least one extreme value iteration is already done: self._Y_extreme is NOT None
                    X = self._X_extreme
                    Y = self._Y_extreme
    
                # extract extremes based on their occurrence in the labels
                other_dims = remove_items(list(Y.dims), dim)
                if extremes_on_right_tail_only:
                    extreme_Y_idx = (Y > extr_val).any(dim=other_dims)
                else:
                    extreme_Y_idx = xr.concat([(Y < -extr_val).any(dim=other_dims[0]),
                                               (Y > extr_val).any(dim=other_dims[0])],
                                              dim=other_dims[1]).any(dim=other_dims[1])
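                # extreme_Y_idx is a boolean mask along `dim`: True wherever any value across the remaining
                # dimensions exceeds +extr_val (or falls below -extr_val in the two-tailed case)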
    
                extremes_X = list(map(lambda x: x.sel(**{dim: extreme_Y_idx}), X))
                self._add_timedelta(extremes_X, dim, timedelta)
    
                extremes_Y = Y.sel(**{dim: extreme_Y_idx})
                extremes_Y.coords[dim].values += np.timedelta64(*timedelta)
    
                self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
                self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
    
        @staticmethod
        def _add_timedelta(data, dim, timedelta):
            for d in data:
                d.coords[dim].values += np.timedelta64(*timedelta)
    
    
    
    if __name__ == "__main__":
    
        data = DummyDataSingleStation("main_class")
        data.get_X()
        data.get_Y()
    
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
        data_prep = DataPreparation("main_class", DummyDataSingleStation, "datetime", path, neighbor_ids=["neighbor1", "neighbor2"],
                                    extreme_values=[1., 1.2])
        data_prep.get_data(upsampling=False)
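
        # Illustrative addition (not in the original script): compare sample counts with and without
        # upsampling. With upsampling=True, the extreme samples appended by multiply_extremes are
        # returned as well, so the arrays hold more samples than the originals.
        X, Y = data_prep.get_data(upsampling=False)
        X_up, Y_up = data_prep.get_data(upsampling=True)
        print(f"samples without upsampling: {Y[0].shape[0]}, with upsampling: {Y_up[0].shape[0]}")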