From d083f8bd6239b55f9049f2a033fd9f2b75ed540d Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Mon, 21 Sep 2020 13:19:39 +0200
Subject: [PATCH] split advanced data handler into abstract and default data
 handler, updated import statements, Station Preparation inherits from
 AbstractDataHandler and states requirements explicitly (not from kwargs
 anymore)

---
 mlair/data_handler/__init__.py               |   3 +-
 mlair/data_handler/abstract_data_handler.py  |  47 +++
 mlair/data_handler/advanced_data_handler.py  | 290 +-----------------
 mlair/data_handler/bootstraps.py             |   2 +-
 .../data_preparation_neighbors.py            |   2 +-
 mlair/data_handler/default_data_handler.py   | 238 ++++++++++++++
 mlair/data_handler/station_preparation.py    |  31 +-
 mlair/run_modules/experiment_setup.py        |   2 +-
 8 files changed, 301 insertions(+), 314 deletions(-)
 create mode 100644 mlair/data_handler/abstract_data_handler.py
 create mode 100644 mlair/data_handler/default_data_handler.py

diff --git a/mlair/data_handler/__init__.py b/mlair/data_handler/__init__.py
index 6510b336..01d66003 100644
--- a/mlair/data_handler/__init__.py
+++ b/mlair/data_handler/__init__.py
@@ -11,5 +11,6 @@ __date__ = '2020-04-17'
 
 from .bootstraps import BootStraps
 from .iterator import KerasIterator, DataCollection
-from .advanced_data_handler import DefaultDataHandler, AbstractDataHandler
+from .default_data_handler import DefaultDataHandler
+from .abstract_data_handler import AbstractDataHandler
 from .data_preparation_neighbors import DataHandlerNeighbors
diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py
new file mode 100644
index 00000000..04b3d465
--- /dev/null
+++ b/mlair/data_handler/abstract_data_handler.py
@@ -0,0 +1,47 @@
+
+__author__ = 'Lukas Leufen'
+__date__ = '2020-09-21'
+
+import inspect
+from typing import Union, Dict
+
+from mlair.helpers import remove_items
+
+
+class AbstractDataHandler:
+
+    _requirements = []
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    @classmethod
+    def build(cls, *args, **kwargs):
+        """Return initialised class."""
+        return cls(*args, **kwargs)
+
+    @classmethod
+    def requirements(cls):
+        """Return requirements and own arguments without duplicates."""
+        return list(set(cls._requirements + cls.own_args()))
+
+    @classmethod
+    def own_args(cls, *args):
+        return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args))
+
+    @classmethod
+    def transformation(cls, *args, **kwargs):
+        return None
+
+    def get_X(self, upsampling=False, as_numpy=False):
+        raise NotImplementedError
+
+    def get_Y(self, upsampling=False, as_numpy=False):
+        raise NotImplementedError
+
+    def get_data(self, upsampling=False, as_numpy=False):
+        return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)
+
+    def get_coordinates(self) -> Union[None, Dict]:
+        """Return coordinates as dictionary with keys `lon` and `lat`."""
+        return None
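Note: for orientation, here is a minimal sketch of how a custom handler plugs into this contract. The subclass `InMemoryDataHandler` and its toy data are invented for illustration; only `AbstractDataHandler`, `build`, `requirements`, `get_X` and `get_Y` come from the file above.

```python
# Hypothetical subclass, for illustration only: the interface (build,
# requirements, get_X/get_Y) is real, the class and its data are made up.
import numpy as np

from mlair.data_handler import AbstractDataHandler


class InMemoryDataHandler(AbstractDataHandler):

    def __init__(self, data_x, data_y):
        super().__init__()
        self._x, self._y = data_x, data_y

    def get_X(self, upsampling=False, as_numpy=False):
        return np.asarray(self._x) if as_numpy else self._x

    def get_Y(self, upsampling=False, as_numpy=False):
        return np.asarray(self._y) if as_numpy else self._y


handler = InMemoryDataHandler.build(data_x=[[1, 2]], data_y=[[3]])
print(InMemoryDataHandler.requirements())  # ['data_x', 'data_y'] (order may vary)
```

This works because `own_args` inspects the subclass constructor via `inspect.getfullargspec`, so every named `__init__` parameter is automatically advertised as a requirement.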
diff --git a/mlair/data_handler/advanced_data_handler.py b/mlair/data_handler/advanced_data_handler.py
index bf7defa5..c2d210bf 100644
--- a/mlair/data_handler/advanced_data_handler.py
+++ b/mlair/data_handler/advanced_data_handler.py
@@ -2,306 +2,20 @@
 __author__ = 'Lukas Leufen'
 __date__ = '2020-07-08'
 
-
-from mlair.helpers import to_list, remove_items
 import numpy as np
 import xarray as xr
-import pickle
 import os
 import pandas as pd
 import datetime as dt
-import shutil
-import inspect
-import copy
-from typing import Union, List, Tuple, Dict
-import logging
-from functools import reduce
-from mlair.data_handler.station_preparation import DataHandlerSingleStation
-from mlair.helpers.join import EmptyQueryResult
+from mlair.data_handler import AbstractDataHandler
+from typing import Union, List
 
 number = Union[float, int]
 num_or_list = Union[number, List[number]]
 
 
-class DummyDataSingleStation:  # pragma: no cover
-
-    def __init__(self, name, number_of_samples=None):
-        self.name = name
-        self.number_of_samples = number_of_samples if number_of_samples is not None else np.random.randint(100, 150)
-
-    def get_X(self):
-        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))  # samples, window, variables
-        datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist()
-        return xr.DataArray(X1, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist,
-                                                                                  "window": range(14),
-                                                                                  "variables": range(5)})
-
-    def get_Y(self):
-        Y1 = np.round(0.5 * np.random.randn(self.number_of_samples, 5, 1), 1)  # samples, window, variables
-        datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist()
-        return xr.DataArray(Y1, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist,
-                                                                                  "window": range(5),
-                                                                                  "variables": range(1)})
-
-    def __str__(self):
-        return self.name
-
-
-class AbstractDataHandler:
-
-    _requirements = []
-
-    def __init__(self, *args, **kwargs):
-        pass
-
-    @classmethod
-    def build(cls, *args, **kwargs):
-        """Return initialised class."""
-        return cls(*args, **kwargs)
-
-    @classmethod
-    def requirements(cls):
-        """Return requirements and own arguments without duplicates."""
-        return list(set(cls._requirements + cls.own_args()))
-
-    @classmethod
-    def own_args(cls, *args):
-        return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args))
-
-    @classmethod
-    def transformation(cls, *args, **kwargs):
-        return None
-
-    def get_X(self, upsampling=False, as_numpy=False):
-        raise NotImplementedError
-
-    def get_Y(self, upsampling=False, as_numpy=False):
-        raise NotImplementedError
-
-    def get_data(self, upsampling=False, as_numpy=False):
-        return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)
-
-    def get_coordinates(self) -> Union[None, Dict]:
-        """Return coordinates as dictionary with keys `lon` and `lat`."""
-        return None
-
-
-class DefaultDataHandler(AbstractDataHandler):
-
-    _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
-
-    def __init__(self, id_class: DataHandlerSingleStation, data_path: str, min_length: int = 0,
-                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None):
-        super().__init__()
-        self.id_class = id_class
-        self.interpolation_dim = "datetime"
-        self.min_length = min_length
-        self._X = None
-        self._Y = None
-        self._X_extreme = None
-        self._Y_extreme = None
-        _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
-        self._save_file = os.path.join(data_path, f"data_preparation_{_name_affix}.pickle")
-        self._collection = self._create_collection()
-        self.harmonise_X()
-        self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolation_dim)
-        self._store(fresh_store=True)
-
-    @classmethod
-    def build(cls, station: str, **kwargs):
-        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
-        sp = DataHandlerSingleStation(station, **sp_keys)
-        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
-        return cls(sp, **dp_args)
-
-    def _create_collection(self):
-        return [self.id_class]
-
-    @classmethod
-    def requirements(cls):
-        return remove_items(super().requirements(), "id_class")
-
-    def _reset_data(self):
-        self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
-
-    def _cleanup(self):
-        directory = os.path.dirname(self._save_file)
-        if os.path.exists(directory) is False:
-            os.makedirs(directory)
-        if os.path.exists(self._save_file):
-            shutil.rmtree(self._save_file, ignore_errors=True)
-
-    def _store(self, fresh_store=False):
-        self._cleanup() if fresh_store is True else None
-        data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
-        with open(self._save_file, "wb") as f:
-            pickle.dump(data, f)
-        logging.debug(f"save pickle data to {self._save_file}")
-        self._reset_data()
-
-    def _load(self):
-        try:
-            with open(self._save_file, "rb") as f:
-                data = pickle.load(f)
-            logging.debug(f"load pickle data from {self._save_file}")
-            self._X, self._Y = data["X"], data["Y"]
-            self._X_extreme, self._Y_extreme = data["X_extreme"], data["Y_extreme"]
-        except FileNotFoundError:
-            pass
-
-    def get_data(self, upsampling=False, as_numpy=True):
-        self._load()
-        X = self.get_X(upsampling, as_numpy)
-        Y = self.get_Y(upsampling, as_numpy)
-        self._reset_data()
-        return X, Y
-
-    def __repr__(self):
-        return ";".join(list(map(lambda x: str(x), self._collection)))
-
-    def get_X_original(self):
-        X = []
-        for data in self._collection:
-            X.append(data.get_X())
-        return X
-
-    def get_Y_original(self):
-        Y = self._collection[0].get_Y()
-        return Y
-
-    @staticmethod
-    def _to_numpy(d):
-        return list(map(lambda x: np.copy(x), d))
-
-    def get_X(self, upsampling=False, as_numpy=True):
-        no_data = (self._X is None)
-        self._load() if no_data is True else None
-        X = self._X if upsampling is False else self._X_extreme
-        self._reset_data() if no_data is True else None
-        return self._to_numpy(X) if as_numpy is True else X
-
-    def get_Y(self, upsampling=False, as_numpy=True):
-        no_data = (self._Y is None)
-        self._load() if no_data is True else None
-        Y = self._Y if upsampling is False else self._Y_extreme
-        self._reset_data() if no_data is True else None
-        return self._to_numpy([Y]) if as_numpy is True else Y
-
-    def harmonise_X(self):
-        X_original, Y_original = self.get_X_original(), self.get_Y_original()
-        dim = self.interpolation_dim
-        intersect = reduce(np.intersect1d, map(lambda x: x.coords[dim].values, X_original))
-        if len(intersect) < max(self.min_length, 1):
-            X, Y = None, None
-        else:
-            X = list(map(lambda x: x.sel({dim: intersect}), X_original))
-            Y = Y_original.sel({dim: intersect})
-        self._X, self._Y = X, Y
-
-    def get_observation(self):
-        return self.id_class.observation.copy().squeeze()
-
-    def get_transformation_Y(self):
-        return self.id_class.get_transformation_information()
-
-    def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
-                          timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"):
-        """
-        Multiply extremes.
-
-        This method extracts extreme values from self.labels which are defined in the argument extreme_values. One can
-        also decide only to extract extremes on the right tail of the distribution. When extreme_values is a list of
-        floats/ints all values larger (and smaller than negative extreme_values; extraction is performed in
-        standardised space) than are extracted iteratively. If for example extreme_values = [1.,2.] then a value of 1.5
-        would be extracted once (for 0th entry in list), while a 2.5 would be extracted twice (once for each entry).
-        Timedelta is used to mark those extracted values by adding one min to each timestamp. As TOAR Data are hourly
-        one can identify those "artificial" data points later easily. Extreme inputs and labels are stored in
-        self.extremes_history and self.extreme_labels, respectively.
-
-        :param extreme_values: user definition of extreme
-        :param extremes_on_right_tail_only: if False also multiply values which are smaller then -extreme_values,
-            if True only extract values larger than extreme_values
-        :param timedelta: used as arguments for np.timedelta in order to mark extreme values on datetime
-        """
-        # check if X or Y is None
-        if (self._X is None) or (self._Y is None):
-            logging.debug(f"{str(self.id_class)} has no data for X or Y, skip multiply extremes")
-            return
-        if extreme_values is None:
-            logging.debug(f"No extreme values given, skip multiply extremes")
-            self._X_extreme, self._Y_extreme = self._X, self._Y
-            return
-
-        # check type if inputs
-        extreme_values = to_list(extreme_values)
-        for i in extreme_values:
-            if not isinstance(i, number.__args__):
-                raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
-                                f"{i} is type {type(i)}")
-
-        for extr_val in sorted(extreme_values):
-            # check if some extreme values are already extracted
-            if (self._X_extreme is None) or (self._Y_extreme is None):
-                X = self._X
-                Y = self._Y
-            else:  # one extr value iteration is done already: self.extremes_label is NOT None...
-                X = self._X_extreme
-                Y = self._Y_extreme
-
-            # extract extremes based on occurrence in labels
-            other_dims = remove_items(list(Y.dims), dim)
-            if extremes_on_right_tail_only:
-                extreme_idx = (Y > extr_val).any(dim=other_dims)
-            else:
-                extreme_idx = xr.concat([(Y < -extr_val).any(dim=other_dims[0]),
-                                         (Y > extr_val).any(dim=other_dims[0])],
-                                        dim=other_dims[1]).any(dim=other_dims[1])
-
-            extremes_X = list(map(lambda x: x.sel(**{dim: extreme_idx}), X))
-            self._add_timedelta(extremes_X, dim, timedelta)
-            # extremes_X = list(map(lambda x: x.coords[dim].values + np.timedelta64(*timedelta), extremes_X))
-
-            extremes_Y = Y.sel(**{dim: extreme_idx})
-            extremes_Y.coords[dim].values += np.timedelta64(*timedelta)
-
-            self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
-            self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
-
-    @staticmethod
-    def _add_timedelta(data, dim, timedelta):
-        for d in data:
-            d.coords[dim].values += np.timedelta64(*timedelta)
-
-    @classmethod
-    def transformation(cls, set_stations, **kwargs):
-        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
-        transformation_dict = sp_keys.pop("transformation")
-        if transformation_dict is None:
-            return
-        scope = transformation_dict.pop("scope")
-        method = transformation_dict.pop("method")
-        if transformation_dict.pop("mean", None) is not None:
-            return
-        mean, std = None, None
-        for station in set_stations:
-            try:
-                sp = DataHandlerSingleStation(station, transformation={"method": method}, **sp_keys)
-                mean = sp.mean.copy(deep=True) if mean is None else mean.combine_first(sp.mean)
-                std = sp.std.copy(deep=True) if std is None else std.combine_first(sp.std)
-            except (AttributeError, EmptyQueryResult):
-                continue
-        if mean is None:
-            return None
-        mean_estimated = mean.mean("Stations")
-        std_estimated = std.mean("Stations")
-        return {"scope": scope, "method": method, "mean": mean_estimated, "std": std_estimated}
-
-    def get_coordinates(self):
-        return self.id_class.get_coordinates()
-
-
 def run_data_prep():
     from .data_preparation_neighbors import DataHandlerNeighbors
diff --git a/mlair/data_handler/bootstraps.py b/mlair/data_handler/bootstraps.py
index f7f5c3c7..68a4bbc4 100644
--- a/mlair/data_handler/bootstraps.py
+++ b/mlair/data_handler/bootstraps.py
@@ -19,7 +19,7 @@ from itertools import chain
 import numpy as np
 import xarray as xr
 
-from mlair.data_handler.advanced_data_handler import AbstractDataHandler
+from mlair.data_handler.abstract_data_handler import AbstractDataHandler
 
 
 class BootstrapIterator(Iterator):
diff --git a/mlair/data_handler/data_preparation_neighbors.py b/mlair/data_handler/data_preparation_neighbors.py
index 37e19225..1482bb9f 100644
--- a/mlair/data_handler/data_preparation_neighbors.py
+++ b/mlair/data_handler/data_preparation_neighbors.py
@@ -5,7 +5,7 @@ __date__ = '2020-07-17'
 
 from mlair.helpers import to_list
 from mlair.data_handler.station_preparation import DataHandlerSingleStation
-from mlair.data_handler.advanced_data_handler import DefaultDataHandler
+from mlair.data_handler import DefaultDataHandler
 import os
 
 from typing import Union, List
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
new file mode 100644
index 00000000..47f63a3e
--- /dev/null
+++ b/mlair/data_handler/default_data_handler.py
@@ -0,0 +1,238 @@
+
+__author__ = 'Lukas Leufen'
+__date__ = '2020-09-21'
+
+import copy
+import inspect
+import logging
+import os
+import pickle
+import shutil
+from functools import reduce
+from typing import Tuple, Union, List
+
+import numpy as np
+import xarray as xr
+
+from mlair.data_handler.abstract_data_handler import AbstractDataHandler
+from mlair.data_handler.station_preparation import DataHandlerSingleStation
+from mlair.helpers import remove_items, to_list
+from mlair.helpers.join import EmptyQueryResult
+
+
+number = Union[float, int]
+num_or_list = Union[number, List[number]]
+
+
+class DefaultDataHandler(AbstractDataHandler):
+
+    _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
+
+    def __init__(self, id_class: DataHandlerSingleStation, data_path: str, min_length: int = 0,
+                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None):
+        super().__init__()
+        self.id_class = id_class
+        self.interpolation_dim = "datetime"
+        self.min_length = min_length
+        self._X = None
+        self._Y = None
+        self._X_extreme = None
+        self._Y_extreme = None
+        _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
+        self._save_file = os.path.join(data_path, f"data_preparation_{_name_affix}.pickle")
+        self._collection = self._create_collection()
+        self.harmonise_X()
+        self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolation_dim)
+        self._store(fresh_store=True)
+
+    @classmethod
+    def build(cls, station: str, **kwargs):
+        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        sp = DataHandlerSingleStation(station, **sp_keys)
+        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
+        return cls(sp, **dp_args)
+
+    def _create_collection(self):
+        return [self.id_class]
+
+    @classmethod
+    def requirements(cls):
+        return remove_items(super().requirements(), "id_class")
+
+    def _reset_data(self):
+        self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
+
+    def _cleanup(self):
+        directory = os.path.dirname(self._save_file)
+        if os.path.exists(directory) is False:
+            os.makedirs(directory)
+        if os.path.exists(self._save_file):
+            shutil.rmtree(self._save_file, ignore_errors=True)
+
+    def _store(self, fresh_store=False):
+        self._cleanup() if fresh_store is True else None
+        data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
+        with open(self._save_file, "wb") as f:
+            pickle.dump(data, f)
+        logging.debug(f"save pickle data to {self._save_file}")
+        self._reset_data()
+
+    def _load(self):
+        try:
+            with open(self._save_file, "rb") as f:
+                data = pickle.load(f)
+            logging.debug(f"load pickle data from {self._save_file}")
+            self._X, self._Y = data["X"], data["Y"]
+            self._X_extreme, self._Y_extreme = data["X_extreme"], data["Y_extreme"]
+        except FileNotFoundError:
+            pass
+
+    def get_data(self, upsampling=False, as_numpy=True):
+        self._load()
+        X = self.get_X(upsampling, as_numpy)
+        Y = self.get_Y(upsampling, as_numpy)
+        self._reset_data()
+        return X, Y
+
+    def __repr__(self):
+        return ";".join(list(map(lambda x: str(x), self._collection)))
+
+    def get_X_original(self):
+        X = []
+        for data in self._collection:
+            X.append(data.get_X())
+        return X
+
+    def get_Y_original(self):
+        Y = self._collection[0].get_Y()
+        return Y
+
+    @staticmethod
+    def _to_numpy(d):
+        return list(map(lambda x: np.copy(x), d))
+
+    def get_X(self, upsampling=False, as_numpy=True):
+        no_data = (self._X is None)
+        self._load() if no_data is True else None
+        X = self._X if upsampling is False else self._X_extreme
+        self._reset_data() if no_data is True else None
+        return self._to_numpy(X) if as_numpy is True else X
+
+    def get_Y(self, upsampling=False, as_numpy=True):
+        no_data = (self._Y is None)
+        self._load() if no_data is True else None
+        Y = self._Y if upsampling is False else self._Y_extreme
+        self._reset_data() if no_data is True else None
+        return self._to_numpy([Y]) if as_numpy is True else Y
+
+    def harmonise_X(self):
+        X_original, Y_original = self.get_X_original(), self.get_Y_original()
+        dim = self.interpolation_dim
+        intersect = reduce(np.intersect1d, map(lambda x: x.coords[dim].values, X_original))
+        if len(intersect) < max(self.min_length, 1):
+            X, Y = None, None
+        else:
+            X = list(map(lambda x: x.sel({dim: intersect}), X_original))
+            Y = Y_original.sel({dim: intersect})
+        self._X, self._Y = X, Y
+
+    def get_observation(self):
+        return self.id_class.observation.copy().squeeze()
+
+    def get_transformation_Y(self):
+        return self.id_class.get_transformation_information()
+
+    def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
+                          timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"):
+        """
+        Multiply extremes.
+
+        This method extracts extreme values from self.labels which are defined in the argument extreme_values. One can
+        also decide only to extract extremes on the right tail of the distribution. When extreme_values is a list of
+        floats/ints, all values larger than extreme_values (and smaller than their negative counterpart; extraction is
+        performed in standardised space) are extracted iteratively. If for example extreme_values = [1., 2.] then a
+        value of 1.5 would be extracted once (for the 0th entry in the list), while a 2.5 would be extracted twice
+        (once for each entry). Timedelta is used to mark those extracted values by adding one minute to each
+        timestamp. As TOAR data are hourly, one can identify those "artificial" data points later easily. Extreme
+        inputs and labels are stored in self.extremes_history and self.extreme_labels, respectively.
+
+        :param extreme_values: user definition of extreme
+        :param extremes_on_right_tail_only: if False, also multiply values which are smaller than -extreme_values;
+            if True, only extract values larger than extreme_values
+        :param timedelta: used as arguments for np.timedelta64 in order to mark extreme values on datetime
+        """
+        # check if X or Y is None
+        if (self._X is None) or (self._Y is None):
+            logging.debug(f"{str(self.id_class)} has no data for X or Y, skip multiply extremes")
+            return
+        if extreme_values is None:
+            logging.debug(f"No extreme values given, skip multiply extremes")
+            self._X_extreme, self._Y_extreme = self._X, self._Y
+            return
+
+        # check type of inputs
+        extreme_values = to_list(extreme_values)
+        for i in extreme_values:
+            if not isinstance(i, number.__args__):
+                raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
+                                f"{i} is type {type(i)}")
+
+        for extr_val in sorted(extreme_values):
+            # check if some extreme values are already extracted
+            if (self._X_extreme is None) or (self._Y_extreme is None):
+                X = self._X
+                Y = self._Y
+            else:  # one extr value iteration is done already: self.extremes_label is NOT None...
+                X = self._X_extreme
+                Y = self._Y_extreme
+
+            # extract extremes based on occurrence in labels
+            other_dims = remove_items(list(Y.dims), dim)
+            if extremes_on_right_tail_only:
+                extreme_idx = (Y > extr_val).any(dim=other_dims)
+            else:
+                extreme_idx = xr.concat([(Y < -extr_val).any(dim=other_dims[0]),
+                                         (Y > extr_val).any(dim=other_dims[0])],
+                                        dim=other_dims[1]).any(dim=other_dims[1])
+
+            extremes_X = list(map(lambda x: x.sel(**{dim: extreme_idx}), X))
+            self._add_timedelta(extremes_X, dim, timedelta)
+            # extremes_X = list(map(lambda x: x.coords[dim].values + np.timedelta64(*timedelta), extremes_X))
+
+            extremes_Y = Y.sel(**{dim: extreme_idx})
+            extremes_Y.coords[dim].values += np.timedelta64(*timedelta)
+
+            self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
+            self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
+
+    @staticmethod
+    def _add_timedelta(data, dim, timedelta):
+        for d in data:
+            d.coords[dim].values += np.timedelta64(*timedelta)
+
+    @classmethod
+    def transformation(cls, set_stations, **kwargs):
+        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        transformation_dict = sp_keys.pop("transformation")
+        if transformation_dict is None:
+            return
+        scope = transformation_dict.pop("scope")
+        method = transformation_dict.pop("method")
+        if transformation_dict.pop("mean", None) is not None:
+            return
+        mean, std = None, None
+        for station in set_stations:
+            try:
+                sp = DataHandlerSingleStation(station, transformation={"method": method}, **sp_keys)
+                mean = sp.mean.copy(deep=True) if mean is None else mean.combine_first(sp.mean)
+                std = sp.std.copy(deep=True) if std is None else std.combine_first(sp.std)
+            except (AttributeError, EmptyQueryResult):
+                continue
+        if mean is None:
+            return None
+        mean_estimated = mean.mean("Stations")
+        std_estimated = std.mean("Stations")
+        return {"scope": scope, "method": method, "mean": mean_estimated, "std": std_estimated}
+
+    def get_coordinates(self):
+        return self.id_class.get_coordinates()
\ No newline at end of file
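Note: `build` splits the incoming keyword arguments between `DataHandlerSingleStation` (everything in `_requirements`) and `DefaultDataHandler` itself (everything in `own_args`). A usage sketch; the station id, path and statistics dict are placeholder values, and the call only succeeds against a prepared data directory:

```python
from mlair.data_handler import DefaultDataHandler

# "DEBW107", the data_path and the statistics dict are invented placeholders.
dh = DefaultDataHandler.build("DEBW107",
                              data_path="/tmp/mlair_data",
                              statistics_per_var={"o3": "dma8eu", "temp": "maximum"},
                              extreme_values=[1., 2.])  # consumed by DefaultDataHandler, not the station handler
X, Y = dh.get_data(upsampling=True, as_numpy=True)  # includes the multiplied extremes
```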
diff --git a/mlair/data_handler/station_preparation.py b/mlair/data_handler/station_preparation.py
index a278d0df..6112e7c5 100644
--- a/mlair/data_handler/station_preparation.py
+++ b/mlair/data_handler/station_preparation.py
@@ -16,6 +16,7 @@ import xarray as xr
 from mlair.configuration import check_path_and_create
 from mlair import helpers
 from mlair.helpers import join, statistics
+from mlair.data_handler.abstract_data_handler import AbstractDataHandler
 
 # define a more general date type for type hinting
 date = Union[dt.date, dt.datetime]
@@ -39,18 +40,7 @@ DEFAULT_SAMPLING = "daily"
 DEFAULT_INTERPOLATION_METHOD = "linear"
 
 
-class AbstractDataHandlerSingleStation(object):
-    def __init__(self):  #, path, station, statistics_per_var, transformation, **kwargs):
-        pass
-
-    def get_X(self):
-        raise NotImplementedError
-
-    def get_Y(self):
-        raise NotImplementedError
-
-
-class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
+class DataHandlerSingleStation(AbstractDataHandler):
 
     def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE,
                  network=DEFAULT_NETWORK, sampling=DEFAULT_SAMPLING, target_dim=DEFAULT_TARGET_DIM,
@@ -58,7 +48,7 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
                  window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_lead_time=DEFAULT_WINDOW_LEAD_TIME,
                  interpolation_limit: int = 0, interpolation_method: str = DEFAULT_INTERPOLATION_METHOD,
                  overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
-                 min_length: int = 0, start=None, end=None, **kwargs):
+                 min_length: int = 0, start=None, end=None, variables=None, **kwargs):
         super().__init__()  # path, station, statistics_per_var, transformation, **kwargs)
         self.station = helpers.to_list(station)
         self.path = os.path.abspath(data_path)
@@ -86,7 +76,7 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
         # internal
         self.data = None
         self.meta = None
-        self.variables = kwargs.get('variables', list(statistics_per_var.keys()))
+        self.variables = statistics_per_var.keys() if variables is None else variables
         self.history = None
         self.label = None
         self.observation = None
@@ -98,10 +88,7 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
         self.min = None
         self._transform_method = None
 
-        self.kwargs = kwargs
-        # self.kwargs["overwrite_local_data"] = overwrite_local_data
-
-        # self.make_samples()
+        # create samples
         self.setup_samples()
 
     def __str__(self):
@@ -123,7 +110,7 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
               f"time_dim='{self.time_dim}', window_history_size={self.window_history_size}, " \
               f"window_lead_time={self.window_lead_time}, interpolation_limit={self.interpolation_limit}, " \
               f"interpolation_method='{self.interpolation_method}', overwrite_local_data={self.overwrite_local_data}, " \
-              f"transformation={self._print_transformation_as_string}, **{self.kwargs})"
+              f"transformation={self._print_transformation_as_string})"
 
     @property
     def _print_transformation_as_string(self):
@@ -155,10 +142,10 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
         """
         return self.label.squeeze("Stations").transpose("datetime", "window").copy()
 
-    def get_X(self):
+    def get_X(self, **kwargs):
         return self.get_transposed_history()
 
-    def get_Y(self):
+    def get_Y(self, **kwargs):
         return self.get_transposed_label()
 
     def get_coordinates(self):
@@ -498,7 +485,7 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
         """
         chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
                      "propane", "so2", "toluene"]
-        used_chem_vars = list(set(chem_vars) & set(self.variables))
+        used_chem_vars = list(set(chem_vars) & set(self.statistics_per_var.keys()))
         data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
         return data
 
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index d66954b0..f5d7d80f 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -18,7 +18,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, \
     DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
     DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
     DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST
-from mlair.data_handler.advanced_data_handler import DefaultDataHandler
+from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.model_class import MyLittleModel as VanillaModel
-- 
GitLab
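Note: the commit subject's point about explicit requirements can be checked directly. `DefaultDataHandler._requirements` is collected from the `DataHandlerSingleStation` signature via introspection, so only parameters named there propagate through `build`; with `variables` now a named parameter instead of a `**kwargs` entry, it becomes part of the advertised requirements. A quick check, assuming this patch is applied:

```python
import inspect

from mlair.data_handler.station_preparation import DataHandlerSingleStation

# 'variables' is visible to getfullargspec only because it is now an explicit
# parameter; before this patch it was hidden inside **kwargs.
args = inspect.getfullargspec(DataHandlerSingleStation).args
print("variables" in args)  # True with this patch, False before
```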