diff --git a/docs/_source/_plots/separation_of_scales.png b/docs/_source/_plots/separation_of_scales.png new file mode 100755 index 0000000000000000000000000000000000000000..d2bbc625a5d50051d8ec2babe976f88d7446e39e Binary files /dev/null and b/docs/_source/_plots/separation_of_scales.png differ diff --git a/mlair/configuration/.gitignore b/mlair/configuration/.gitignore index 8e2358dc56797578fe0de020aa827b1fef8663bf..91eccc695f4ea58374a14a1ba0272f98f210c203 100644 --- a/mlair/configuration/.gitignore +++ b/mlair/configuration/.gitignore @@ -1 +1,2 @@ -join_settings.py \ No newline at end of file +join_settings.py +join_rest \ No newline at end of file diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index d191af2edd8a6fe2c1093b3f1c3f5d419cc42b76..ce42fc0eed6e891bc0a0625666da3dccfcc8a3ee 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -1,6 +1,7 @@ __author__ = "Lukas Leufen" __date__ = '2020-06-25' +from mlair.helpers.statistics import TransformationClass DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', @@ -13,8 +14,7 @@ DEFAULT_START = "1997-01-01" DEFAULT_END = "2017-12-31" DEFAULT_WINDOW_HISTORY_SIZE = 13 DEFAULT_OVERWRITE_LOCAL_DATA = False -# DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"} -DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise"} +DEFAULT_TRANSFORMATION = TransformationClass(inputs_method="standardise", targets_method="standardise") DEFAULT_HPC_LOGIN_LIST = ["ju", "hdfmll"] # ju[wels} #hdfmll(ogin) DEFAULT_HPC_HOST_LIST = ["jw", "hdfmlc"] # first part of node names for Juwels (jw[comp], hdfmlc(ompute). 
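# Editor's note (hedged sketch): DEFAULT_TRANSFORMATION above is no longer a plain dict but a
# TransformationClass from mlair.helpers.statistics that carries separate settings for inputs and
# targets. Its implementation is not part of this diff; the toy classes below only illustrate the
# structure the rest of the diff appears to rely on (attributes .inputs/.targets with
# .transform_method, .mean, .std) and are not MLAir code.
class DataClassSketch:
    """Container for one data stream plus its transformation settings."""
    def __init__(self, transform_method="standardise", mean=None, std=None):
        self.data = None
        self.transform_method = transform_method
        self.mean = mean
        self.std = std


class TransformationClassSketch:
    """Pairs one DataClassSketch for inputs with one for targets."""
    def __init__(self, inputs_method="standardise", targets_method="standardise"):
        self.inputs = DataClassSketch(transform_method=inputs_method)
        self.targets = DataClassSketch(transform_method=targets_method)


# mirrors the new default shown above
DEFAULT_TRANSFORMATION_SKETCH = TransformationClassSketch(inputs_method="standardise",
                                                          targets_method="standardise")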
DEFAULT_CREATE_NEW_MODEL = True @@ -46,13 +46,13 @@ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS = True DEFAULT_EVALUATE_BOOTSTRAPS = True DEFAULT_CREATE_NEW_BOOTSTRAPS = False DEFAULT_NUMBER_OF_BOOTSTRAPS = 20 -#DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries", -# "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles", -# "PlotAvailability"] -DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", +DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles", - "PlotAvailability"] - + "PlotAvailability", "PlotSeparationOfScales"] +DEFAULT_SAMPLING = "daily" +DEFAULT_DATA_ORIGIN = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA", "press": "REA", "relhum": "REA", + "temp": "REA", "totprecip": "REA", "u": "REA", "v": "REA", "no": "", "no2": "", "o3": "", + "pm10": "", "so2": ""} def get_defaults(): diff --git a/mlair/configuration/path_config.py b/mlair/configuration/path_config.py index 9b3d6f250d97d93dd1d06004690885f44de30073..bf40c361e121c409efec08b85fdf4e19848049ee 100644 --- a/mlair/configuration/path_config.py +++ b/mlair/configuration/path_config.py @@ -20,7 +20,7 @@ def prepare_host(create_new=True, data_path=None, sampling="daily") -> str: :param create_new: Create new path if enabled :param data_path: Parse your custom path (and therefore ignore preset paths fitting to known hosts) - :param sampling: sampling rate to separate data physically by temporal resolution + :param sampling: sampling rate to separate data physically by temporal resolution (deprecated) :return: full path to data """ @@ -32,17 +32,14 @@ def prepare_host(create_new=True, data_path=None, sampling="daily") -> str: data_path = f"/home/{user}/Data/toar_{sampling}/" elif hostname == "zam347": data_path = f"/home/{user}/Data/toar_{sampling}/" - elif hostname == "linux-aa9b": - data_path = f"/home/{user}/mlair/data/toar_{sampling}/" elif (len(hostname) > 2) and (hostname[:2] == "jr"): data_path = f"/p/project/cjjsc42/{user}/DATA/toar_{sampling}/" elif (len(hostname) > 2) and (hostname[:2] in ['jw', 'ju'] or hostname[:5] in ['hdfml']): - data_path = f"/p/project/deepacf/intelliaq/{user}/DATA/toar_{sampling}/" + data_path = f"/p/project/deepacf/intelliaq/{user}/DATA/MLAIR/" elif runner_regex.match(hostname) is not None: - data_path = f"/home/{user}/mlair/data/toar_{sampling}/" + data_path = f"/home/{user}/mlair/data/" else: - data_path = os.path.join(os.getcwd(), "data", sampling) - # raise OSError(f"unknown host '{hostname}'") + data_path = os.path.join(os.getcwd(), "data") if not os.path.exists(data_path): try: @@ -97,7 +94,7 @@ def set_experiment_name(name: str = None, sampling: str = None) -> str: return experiment_name -def set_bootstrap_path(bootstrap_path: str, data_path: str, sampling: str) -> str: +def set_bootstrap_path(bootstrap_path: str, data_path: str) -> str: """ Set path for bootstrap input data. 
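# Editor's note (hedged sketch): the data-path fallback above no longer appends the sampling rate,
# and (as the next hunk shows) the bootstrap directory now defaults to a "bootstrap" folder inside
# the data path instead of a sibling "bootstrap_<sampling>" folder. Stand-alone illustration of the
# new defaults; paths and the helper name are examples only, not MLAir API.
import os

def resolve_default_paths(data_path=None, bootstrap_path=None):
    data_path = data_path or os.path.join(os.getcwd(), "data")  # no sampling suffix anymore
    bootstrap_path = bootstrap_path or os.path.join(data_path, "bootstrap")
    return os.path.abspath(data_path), os.path.abspath(bootstrap_path)

print(resolve_default_paths())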
@@ -105,12 +102,11 @@ def set_bootstrap_path(bootstrap_path: str, data_path: str, sampling: str) -> st :param bootstrap_path: custom path to store bootstrap data :param data_path: path of data for default bootstrap path - :param sampling: sampling rate to add, if path is set to default :return: full bootstrap path """ if bootstrap_path is None: - bootstrap_path = os.path.join(data_path, "..", f"bootstrap_{sampling}") + bootstrap_path = os.path.join(data_path, "bootstrap") check_path_and_create(bootstrap_path) return os.path.abspath(bootstrap_path) diff --git a/mlair/data_handler/__init__.py b/mlair/data_handler/__init__.py index 01d660031bbbdda08eba80044a08fcb034d8171b..495b6e7c8604a839a084a2b78a54563c13eb06e6 100644 --- a/mlair/data_handler/__init__.py +++ b/mlair/data_handler/__init__.py @@ -13,4 +13,4 @@ from .bootstraps import BootStraps from .iterator import KerasIterator, DataCollection from .default_data_handler import DefaultDataHandler from .abstract_data_handler import AbstractDataHandler -from .data_preparation_neighbors import DataHandlerNeighbors +from .data_handler_neighbors import DataHandlerNeighbors diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py index 04b3d4651347759130da15a05056f6ace3d0fc1f..26ccf69c85e999c540e656a2ceac5737390a579e 100644 --- a/mlair/data_handler/abstract_data_handler.py +++ b/mlair/data_handler/abstract_data_handler.py @@ -27,7 +27,10 @@ class AbstractDataHandler: @classmethod def own_args(cls, *args): - return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args)) + """Return all arguments (including kwonlyargs).""" + arg_spec = inspect.getfullargspec(cls) + list_of_args = arg_spec.args + arg_spec.kwonlyargs + return remove_items(list_of_args, ["self"] + list(args)) @classmethod def transformation(cls, *args, **kwargs): diff --git a/mlair/data_handler/advanced_data_handler.py b/mlair/data_handler/advanced_data_handler.py index c2d210bffdb598b23c025f60b903ddef84e4509d..f04748e82f11116b265796afba7f401c1cad9342 100644 --- a/mlair/data_handler/advanced_data_handler.py +++ b/mlair/data_handler/advanced_data_handler.py @@ -10,15 +10,18 @@ import datetime as dt from mlair.data_handler import AbstractDataHandler -from typing import Union, List +from typing import Union, List, Tuple, Dict +import logging +from functools import reduce +from mlair.helpers.join import EmptyQueryResult +from mlair.helpers import TimeTracking number = Union[float, int] num_or_list = Union[number, List[number]] def run_data_prep(): - - from .data_preparation_neighbors import DataHandlerNeighbors + from .data_handler_neighbors import DataHandlerNeighbors data = DummyDataHandler("main_class") data.get_X() data.get_Y() @@ -33,8 +36,7 @@ def run_data_prep(): def create_data_prep(): - - from .data_preparation_neighbors import DataHandlerNeighbors + from .data_handler_neighbors import DataHandlerNeighbors path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") station_type = None network = 'UBA' @@ -98,7 +100,7 @@ class DummyDataHandler(AbstractDataHandler): if __name__ == "__main__": - from mlair.data_handler.station_preparation import DataHandlerSingleStation + from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation from mlair.data_handler.iterator import KerasIterator, DataCollection data_prep = create_data_prep() data_collection = DataCollection(data_prep) diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py new file 
mode 100644 index 0000000000000000000000000000000000000000..adc5ee0e72694baed6ec0ab0c0bf9259126af292 --- /dev/null +++ b/mlair/data_handler/data_handler_kz_filter.py @@ -0,0 +1,93 @@ +"""Data Handler using kz-filtered data.""" + +__author__ = 'Lukas Leufen' +__date__ = '2020-08-26' + +import inspect +import numpy as np +import pandas as pd +import xarray as xr +from typing import List, Union + +from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation +from mlair.data_handler import DefaultDataHandler +from mlair.helpers import remove_items, to_list, TimeTrackingWrapper +from mlair.helpers.statistics import KolmogorovZurbenkoFilterMovingWindow as KZFilter + +# define a more general date type for type hinting +str_or_list = Union[str, List[str]] + + +class DataHandlerKzFilterSingleStation(DataHandlerSingleStation): + """Data handler for a single station to be used by a superior data handler. Inputs are kz filtered.""" + + _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) + + def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs): + self._check_sampling(**kwargs) + # self.original_data = None # ToDo: implement here something to store unfiltered data + self.kz_filter_length = to_list(kz_filter_length) + self.kz_filter_iter = to_list(kz_filter_iter) + self.cutoff_period = None + self.cutoff_period_days = None + super().__init__(*args, **kwargs) + + def _check_sampling(self, **kwargs): + assert kwargs.get("sampling") == "hourly" # This data handler requires hourly data resolution + + def setup_samples(self): + """ + Setup samples. This method prepares and creates samples X, and labels Y. + """ + data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, + self.station_type, self.network, self.store_data_locally, self.data_origin) + self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) + self.set_inputs_and_targets() + self.apply_kz_filter() + # this is just a code snippet to check the results of the kz filter + # import matplotlib + # matplotlib.use("TkAgg") + # import matplotlib.pyplot as plt + # self.input_data.data.sel(filter="74d", variables="temp", Stations="DEBW107").plot() + # self.input_data.data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") + if self.do_transformation is True: + self.call_transform() + self.make_samples() + + @TimeTrackingWrapper + def apply_kz_filter(self): + """Apply kolmogorov zurbenko filter only on inputs.""" + kz = KZFilter(self.input_data.data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim="datetime") + filtered_data: List[xr.DataArray] = kz.run() + self.cutoff_period = kz.period_null() + self.cutoff_period_days = kz.period_null_days() + self.input_data.data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name="filter")) + + def create_filter_index(self) -> pd.Index: + """ + Round cut off periods in days and append 'res' for residuum index. + + Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append + 'res' for residuum index. + """ + index = np.round(self.cutoff_period_days, 1) + f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1) + index = list(map(f, index.tolist())) + index = list(map(lambda x: str(x) + "d", index)) + ["res"] + return pd.Index(index, name="filter") + + def get_transposed_history(self) -> xr.DataArray: + """Return history. 
+ + :return: history with dimensions datetime, window, Stations, variables. + """ + return self.history.transpose("datetime", "window", "Stations", "variables", "filter").copy() + + +class DataHandlerKzFilter(DefaultDataHandler): + """Data handler using kz filtered data.""" + + data_handler = DataHandlerKzFilterSingleStation + data_handler_transformation = DataHandlerKzFilterSingleStation + _requirements = data_handler.requirements() diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..aa1f0d55b55757875b640de00f66e62dd3586b11 --- /dev/null +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -0,0 +1,203 @@ +__author__ = 'Lukas Leufen' +__date__ = '2020-11-05' + +from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation +from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation +from mlair.data_handler import DefaultDataHandler +from mlair import helpers +from mlair.helpers import remove_items +from mlair.configuration.defaults import DEFAULT_SAMPLING + +import inspect +from typing import Callable +import datetime as dt + +import numpy as np +import pandas as pd +import xarray as xr + + +class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): + _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) + + def __init__(self, *args, sampling_inputs, **kwargs): + sampling = (sampling_inputs, kwargs.get("sampling", DEFAULT_SAMPLING)) + kwargs.update({"sampling": sampling}) + super().__init__(*args, **kwargs) + + def setup_samples(self): + """ + Setup samples. This method prepares and creates samples X, and labels Y. + """ + self._data = list(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data + self.set_inputs_and_targets() + if self.do_transformation is True: + self.call_transform() + self.make_samples() + + def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: + data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind], + self.station_type, self.network, self.store_data_locally, self.data_origin, + self.start, self.end) + data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) + return data + + def set_inputs_and_targets(self): + inputs = self._data[0].sel({self.target_dim: helpers.to_list(self.variables)}) + targets = self._data[1].sel({self.target_dim: self.target_var}) + self.input_data.data = inputs + self.target_data.data = targets + + def setup_data_path(self, data_path, sampling): + """Sets two paths instead of single path. 
Expects sampling arg to be a list with two entries""" + assert len(sampling) == 2 + return list(map(lambda x: super(__class__, self).setup_data_path(data_path, x), sampling)) + + +class DataHandlerMixedSampling(DefaultDataHandler): + """Data handler using mixed sampling for input and target.""" + + data_handler = DataHandlerMixedSamplingSingleStation + data_handler_transformation = DataHandlerMixedSamplingSingleStation + _requirements = data_handler.requirements() + + +class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation, + DataHandlerKzFilterSingleStation): + _requirements1 = DataHandlerKzFilterSingleStation.requirements() + _requirements2 = DataHandlerMixedSamplingSingleStation.requirements() + _requirements = list(set(_requirements1 + _requirements2)) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _check_sampling(self, **kwargs): + assert kwargs.get("sampling") == ("hourly", "daily") + + def setup_samples(self): + """ + Setup samples. This method prepares and creates samples X, and labels Y. + + A KZ filter is applied on the input data that has hourly resolution. Labels Y are provided as aggregated values + with daily resolution. + """ + self._data = list(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data + self.set_inputs_and_targets() + self.apply_kz_filter() + if self.do_transformation is True: + self.call_transform() + self.make_samples() + + def estimate_filter_width(self): + """ + f = 0.5 / (len * sqrt(itr)) -> T = 1 / f + :return: + """ + return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2) + + @staticmethod + def _add_time_delta(date, delta): + new_date = dt.datetime.strptime(date, "%Y-%m-%d") + dt.timedelta(hours=delta) + return new_date.strftime("%Y-%m-%d") + + def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: + + if ind == 0: # for inputs + estimated_filter_width = self.estimate_filter_width() + start = self._add_time_delta(self.start, -estimated_filter_width) + end = self._add_time_delta(self.end, estimated_filter_width) + else: # target + start, end = self.start, self.end + + data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind], + self.station_type, self.network, self.store_data_locally, self.data_origin, + start, end) + data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) + return data + + +class DataHandlerMixedSamplingWithFilter(DefaultDataHandler): + """Data handler using mixed sampling for input and target. Inputs are temporally filtered.""" + + data_handler = DataHandlerMixedSamplingWithFilterSingleStation + data_handler_transformation = DataHandlerMixedSamplingWithFilterSingleStation + _requirements = data_handler.requirements() + + +class DataHandlerMixedSamplingSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFilterSingleStation): + """ + Data handler using mixed sampling for input and target. Inputs are temporally filtered, and depending on the + separation frequency of a filtered time series, the time step delta for the input data is adjusted (see image below). + + ..
image:: ../../../../../_source/_plots/separation_of_scales.png + :width: 400 + + """ + + _requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements() + + def __init__(self, *args, time_delta=np.sqrt, **kwargs): + assert isinstance(time_delta, Callable) + self.time_delta = time_delta + super().__init__(*args, **kwargs) + + def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: + """ + Create an xr.DataArray containing history data. + + Shift the data window+1 times and return an xarray which has a new dimension 'window' containing the shifted + data. This is used to represent history in the data. Results are stored in the history attribute. + + :param dim_name_of_inputs: Name of dimension which contains the input variables + :param window: number of time steps to look back in history + Note: window will be treated as a negative value. This should be in agreement with looking back on + a time line. Nonetheless, positive values are allowed but are converted to their negative + expression + :param dim_name_of_shift: Dimension along shift will be applied + """ + window = -abs(window) + data = self.input_data.data + self.history = self.stride(data, dim_name_of_shift, window) + + def stride(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray: + + # this is just a code snippet to check the results of the kz filter + # import matplotlib + # matplotlib.use("TkAgg") + # import matplotlib.pyplot as plt + # xr.concat(res, dim="filter").sel({"variables":"temp", "Stations":"DEBW107", "datetime":"2010-01-01T00:00:00"}).plot.line(hue="filter") + + time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int) + start, end = window, 1 + res = [] + window_array = self.create_index_array('window', range(start, end), squeeze_dim=self.target_dim) + for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]): + res_filter = [] + data_filter = data.sel({"filter": filter_name}) + for w in range(start, end): + res_filter.append(data_filter.shift({dim: -w * delta})) + res_filter = xr.concat(res_filter, dim=window_array).chunk() + res.append(res_filter) + res = xr.concat(res, dim="filter") + return res + + def estimate_filter_width(self): + """ + Attention: this method returns the maximum value of + * either the estimated filter width f = 0.5 / (len * sqrt(itr)) -> T = 1 / f, or + * the time delta method applied on the estimated filter width multiplied by window_history_size, + to provide a sufficiently wide filter width. + """ + est = self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2 + return int(max([self.time_delta(est) * self.window_history_size, est])) + + +class DataHandlerMixedSamplingSeparationOfScales(DefaultDataHandler): + """Data handler using mixed sampling for input and target.
Inputs are temporal filtered and different time step + sizes are applied in relation to frequencies.""" + + data_handler = DataHandlerMixedSamplingSeparationOfScalesSingleStation + data_handler_transformation = DataHandlerMixedSamplingSeparationOfScalesSingleStation + _requirements = data_handler.requirements() diff --git a/mlair/data_handler/data_preparation_neighbors.py b/mlair/data_handler/data_handler_neighbors.py similarity index 85% rename from mlair/data_handler/data_preparation_neighbors.py rename to mlair/data_handler/data_handler_neighbors.py index 1482bb9fe20afcc2b92d2b91ae523a6dca19c54d..a004e659969232a080d49eb6905007d353bbe99c 100644 --- a/mlair/data_handler/data_preparation_neighbors.py +++ b/mlair/data_handler/data_handler_neighbors.py @@ -4,9 +4,9 @@ __date__ = '2020-07-17' from mlair.helpers import to_list -from mlair.data_handler.station_preparation import DataHandlerSingleStation from mlair.data_handler import DefaultDataHandler import os +import copy from typing import Union, List @@ -15,6 +15,7 @@ num_or_list = Union[number, List[number]] class DataHandlerNeighbors(DefaultDataHandler): + """Data handler including neighboring stations.""" def __init__(self, id_class, data_path, neighbors=None, min_length=0, extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False): @@ -24,14 +25,14 @@ class DataHandlerNeighbors(DefaultDataHandler): @classmethod def build(cls, station, **kwargs): - sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs} - sp = DataHandlerSingleStation(station, **sp_keys) + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs} + sp = cls.data_handler(station, **sp_keys) n_list = [] for neighbor in kwargs.get("neighbors", []): - n_list.append(DataHandlerSingleStation(neighbor, **sp_keys)) + n_list.append(cls.data_handler(neighbor, **sp_keys)) else: kwargs["neighbors"] = n_list if len(n_list) > 0 else None - dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs} + dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs} return cls(sp, **dp_args) def _create_collection(self): diff --git a/mlair/data_handler/station_preparation.py b/mlair/data_handler/data_handler_single_station.py similarity index 64% rename from mlair/data_handler/station_preparation.py rename to mlair/data_handler/data_handler_single_station.py index f3428e91bae3dc1d94a45dd7ff2bf931cff1fa54..e554a3b32d8e4e2f5482a388374cfba87f7add15 100644 --- a/mlair/data_handler/station_preparation.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -3,6 +3,7 @@ __author__ = 'Lukas Leufen, Felix Kleinert' __date__ = '2020-07-20' +import copy import datetime as dt import logging import os @@ -15,7 +16,7 @@ import xarray as xr from mlair.configuration import check_path_and_create from mlair import helpers -from mlair.helpers import join, statistics +from mlair.helpers import join, statistics, TimeTrackingWrapper from mlair.data_handler.abstract_data_handler import AbstractDataHandler # define a more general date type for type hinting @@ -48,12 +49,14 @@ class DataHandlerSingleStation(AbstractDataHandler): window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_lead_time=DEFAULT_WINDOW_LEAD_TIME, interpolation_limit: int = 0, interpolation_method: str = DEFAULT_INTERPOLATION_METHOD, overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True, - min_length: int = 0, start=None, end=None, variables=None, **kwargs): + min_length: int = 0, start=None, end=None, 
variables=None, data_origin: Dict = None, **kwargs): super().__init__() # path, station, statistics_per_var, transformation, **kwargs) self.station = helpers.to_list(station) - self.path = os.path.abspath(data_path) + self.path = self.setup_data_path(data_path, sampling) self.statistics_per_var = statistics_per_var - self.transformation = self.setup_transformation(transformation) + self.data_origin = data_origin + self.do_transformation = transformation is not None + self.input_data, self.target_data = self.setup_transformation(transformation) self.station_type = station_type self.network = network @@ -74,20 +77,13 @@ class DataHandlerSingleStation(AbstractDataHandler): self.end = end # internal - self.data = None + self._data: xr.DataArray = None # loaded raw data self.meta = None self.variables = list(statistics_per_var.keys()) if variables is None else variables self.history = None self.label = None self.observation = None - # internal for transformation - self.mean = None - self.std = None - self.max = None - self.min = None - self._transform_method = None - # create samples self.setup_samples() @@ -100,7 +96,7 @@ class DataHandlerSingleStation(AbstractDataHandler): @property def shape(self): - return self.data.shape, self.get_X().shape, self.get_Y().shape + return self._data.shape, self.get_X().shape, self.get_Y().shape def __repr__(self): return f"StationPrep(station={self.station}, data_path='{self.path}', " \ @@ -109,24 +105,7 @@ class DataHandlerSingleStation(AbstractDataHandler): f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \ f"time_dim='{self.time_dim}', window_history_size={self.window_history_size}, " \ f"window_lead_time={self.window_lead_time}, interpolation_limit={self.interpolation_limit}, " \ - f"interpolation_method='{self.interpolation_method}', overwrite_local_data={self.overwrite_local_data}, " \ - f"transformation={self._print_transformation_as_string})" - - @property - def _print_transformation_as_string(self): - str_name = '' - if self.transformation is None: - str_name = f'{None}' - else: - for k, v in self.transformation.items(): - if v is not None: - try: - v_pr = f"xr.DataArray.from_dict({v.to_dict()})" - except AttributeError: - v_pr = f"'{v}'" - str_name += f"'{k}':{v_pr}, " - str_name = f"{{{str_name}}}" - return str_name + f"interpolation_method='{self.interpolation_method}', overwrite_local_data={self.overwrite_local_data})" def get_transposed_history(self) -> xr.DataArray: """Return history. 
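# Editor's note (hedged sketch): with this refactor the handler keeps the loaded raw data in _data
# and works with two separate streams, input_data and target_data, selected along the variables
# dimension (see set_inputs_and_targets further below). Toy, self-contained illustration of that
# selection; dimension names follow the diff, the values and variable list are made up.
import numpy as np
import pandas as pd
import xarray as xr

time_index = pd.date_range("2020-01-01", periods=4, freq="D")
raw = xr.DataArray(np.random.rand(4, 3),
                   coords={"datetime": time_index, "variables": ["o3", "temp", "relhum"]},
                   dims=["datetime", "variables"])
inputs = raw.sel({"variables": ["o3", "temp", "relhum"]})   # all variables become the input stream
targets = raw.sel({"variables": "o3"})                      # target_var becomes the target stream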
@@ -153,36 +132,40 @@ class DataHandlerSingleStation(AbstractDataHandler): return coords.rename(index={"station_lon": "lon", "station_lat": "lat"}).to_dict()[str(self)] def call_transform(self, inverse=False): - self.transform(dim=self.time_dim, method=self.transformation["method"], - mean=self.transformation['mean'], std=self.transformation["std"], - min_val=self.transformation["min"], max_val=self.transformation["max"], - inverse=inverse - ) - - def set_transformation(self, transformation: dict): - if self._transform_method is not None: - self.call_transform(inverse=True) - self.transformation = self.setup_transformation(transformation) - self.call_transform() - self.make_samples() + kwargs = helpers.remove_items(self.input_data.as_dict(), ["data"]) + self.transform(self.input_data, dim=self.time_dim, inverse=inverse, **kwargs) + kwargs = helpers.remove_items(self.target_data.as_dict(), ["data"]) + self.transform(self.target_data, dim=self.time_dim, inverse=inverse, **kwargs) + @TimeTrackingWrapper def setup_samples(self): """ Setup samples. This method prepares and creates samples X, and labels Y. """ - self.load_data() - self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) - if self.transformation is not None: + data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, + self.station_type, self.network, self.store_data_locally, self.data_origin, + self.start, self.end) + self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) + self.set_inputs_and_targets() + if self.do_transformation is True: self.call_transform() self.make_samples() + def set_inputs_and_targets(self): + inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)}) + targets = self._data.sel({self.target_dim: self.target_var}) + self.input_data.data = inputs + self.target_data.data = targets + def make_samples(self): self.make_history_window(self.target_dim, self.window_history_size, self.time_dim) self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time) self.make_observation(self.target_dim, self.target_var, self.time_dim) self.remove_nan(self.time_dim) - def read_data_from_disk(self, source_name=""): + def load_data(self, path, station, statistics_per_var, sampling, station_type=None, network=None, + store_data_locally=False, data_origin: Dict = None, start=None, end=None): """ Load data and meta data either from local disk (preferred) or download new data by using a custom download method. @@ -190,35 +173,42 @@ class DataHandlerSingleStation(AbstractDataHandler): cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not set, it is assumed, that data should be saved locally. 
""" - source_name = source_name if len(source_name) == 0 else f" from {source_name}" - check_path_and_create(self.path) - file_name = self._set_file_name() - meta_file = self._set_meta_file_name() + check_path_and_create(path) + file_name = self._set_file_name(path, station, statistics_per_var) + meta_file = self._set_meta_file_name(path, station, statistics_per_var) if self.overwrite_local_data is True: - logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}") + logging.debug(f"overwrite_local_data is true, therefore reload {file_name}") if os.path.exists(file_name): os.remove(file_name) if os.path.exists(meta_file): os.remove(meta_file) - data, self.meta = self.download_data(file_name, meta_file) - logging.debug(f"loaded new data{source_name}") + data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling, + station_type=station_type, network=network, + store_data_locally=store_data_locally, data_origin=data_origin) + logging.debug(f"loaded new data") else: try: logging.debug(f"try to load local data from: {file_name}") data = xr.open_dataarray(file_name) - self.meta = pd.read_csv(meta_file, index_col=0) - self.check_station_meta() + meta = pd.read_csv(meta_file, index_col=0) + self.check_station_meta(meta, station, station_type, network) logging.debug("loading finished") except FileNotFoundError as e: logging.debug(e) - logging.debug(f"load new data{source_name}") - data, self.meta = self.download_data(file_name, meta_file) + logging.debug(f"load new data") + data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling, + station_type=station_type, network=network, + store_data_locally=store_data_locally, data_origin=data_origin) logging.debug("loading finished") # create slices and check for negative concentration. - data = self._slice_prep(data) - self.data = self.check_for_negative_concentrations(data) + data = self._slice_prep(data, start=start, end=end) + data = self.check_for_negative_concentrations(data) + return data, meta - def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]: + @staticmethod + def download_data_from_join(file_name: str, meta_file: str, station, statistics_per_var, sampling, + station_type=None, network=None, store_data_locally=True, data_origin: Dict = None) \ + -> [xr.DataArray, pd.DataFrame]: """ Download data from TOAR database using the JOIN interface. 
@@ -231,36 +221,37 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: downloaded data and its meta data """ df_all = {} - df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var, - station_type=self.station_type, network_name=self.network, sampling=self.sampling) - df_all[self.station[0]] = df + df, meta = join.download_join(station_name=station, stat_var=statistics_per_var, station_type=station_type, + network_name=network, sampling=sampling, data_origin=data_origin) + df_all[station[0]] = df # convert df_all to xarray xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()} xarr = xr.Dataset(xarr).to_array(dim='Stations') - if self.store_data_locally is True: + if store_data_locally is True: # save locally as nc/csv file xarr.to_netcdf(path=file_name) meta.to_csv(meta_file) return xarr, meta - def download_data(self, file_name, meta_file): - data, meta = self.download_data_from_join(file_name, meta_file) + def download_data(self, *args, **kwargs): + data, meta = self.download_data_from_join(*args, **kwargs) return data, meta - def check_station_meta(self): + @staticmethod + def check_station_meta(meta, station, station_type, network): """ Search for the entries in meta data and compare the value with the requested values. Will raise a FileNotFoundError if the values mismatch. """ - if self.station_type is not None: - check_dict = {"station_type": self.station_type, "network_name": self.network} + if station_type is not None: + check_dict = {"station_type": station_type, "network_name": network} for (k, v) in check_dict.items(): if v is None: continue - if self.meta.at[k, self.station[0]] != v: + if meta.at[k, station[0]] != v: logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != " - f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new " + f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new " f"grapping from web.") raise FileNotFoundError @@ -279,14 +270,19 @@ class DataHandlerSingleStation(AbstractDataHandler): """ chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", "toluene"] + # used_chem_vars = list(set(chem_vars) & set(self.statistics_per_var.keys())) used_chem_vars = list(set(chem_vars) & set(self.variables)) data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum) return data - def shift(self, dim: str, window: int) -> xr.DataArray: + def setup_data_path(self, data_path: str, sampling: str): + return os.path.join(os.path.abspath(data_path), sampling) + + def shift(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray: """ Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0). 
+ :param data: data set to shift :param dim: dimension along shift is applied :param window: number of steps to shift (corresponds to the window length) @@ -300,7 +296,7 @@ class DataHandlerSingleStation(AbstractDataHandler): end = window + 1 res = [] for w in range(start, end): - res.append(self.data.shift({dim: -w})) + res.append(data.shift({dim: -w})) window_array = self.create_index_array('window', range(start, end), squeeze_dim=self.target_dim) res = xr.concat(res, dim=window_array) return res @@ -324,15 +320,18 @@ class DataHandlerSingleStation(AbstractDataHandler): res.name = index_name return res - def _set_file_name(self): - all_vars = sorted(self.statistics_per_var.keys()) - return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}.nc") + @staticmethod + def _set_file_name(path, station, statistics_per_var): + all_vars = sorted(statistics_per_var.keys()) + return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}.nc") - def _set_meta_file_name(self): - all_vars = sorted(self.statistics_per_var.keys()) - return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv") + @staticmethod + def _set_meta_file_name(path, station, statistics_per_var): + all_vars = sorted(statistics_per_var.keys()) + return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv") - def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True, + @staticmethod + def interpolate(data, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True, **kwargs): """ Interpolate values according to different methods. @@ -370,8 +369,7 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: xarray.DataArray """ - self.data = self.data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, - **kwargs) + return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs) def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: """ @@ -388,7 +386,8 @@ class DataHandlerSingleStation(AbstractDataHandler): :param dim_name_of_shift: Dimension along shift will be applied """ window = -abs(window) - self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables}) + data = self.input_data.data + self.history = self.shift(data, dim_name_of_shift, window) def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str, window: int) -> None: @@ -404,7 +403,8 @@ class DataHandlerSingleStation(AbstractDataHandler): :param window: lead time of label """ window = abs(window) - self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var}) + data = self.target_data.data + self.label = self.shift(data, dim_name_of_shift, window) def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None: """ @@ -416,7 +416,8 @@ class DataHandlerSingleStation(AbstractDataHandler): :param target_var: Name of observation variable(s) in 'dimension' :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied """ - self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var}) + data = self.target_data.data + self.observation = self.shift(data, dim_name_of_shift, 0) def remove_nan(self, dim: str) -> None: """ @@ -443,7 +444,7 @@ class DataHandlerSingleStation(AbstractDataHandler): self.label = 
self.label.sel({dim: intersect}) self.observation = self.observation.sel({dim: intersect}) - def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray: + def _slice_prep(self, data: xr.DataArray, start=None, end=None) -> xr.DataArray: """ Set start and end date for slicing and execute self._slice(). @@ -452,9 +453,9 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: sliced data """ - start = self.start if self.start is not None else data.coords[coord][0].values - end = self.end if self.end is not None else data.coords[coord][-1].values - return self._slice(data, start, end, coord) + start = start if start is not None else data.coords[self.time_dim][0].values + end = end if end is not None else data.coords[self.time_dim][-1].values + return self._slice(data, start, end, self.time_dim) @staticmethod def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray: @@ -470,119 +471,28 @@ class DataHandlerSingleStation(AbstractDataHandler): """ return data.loc[{coord: slice(str(start), str(end))}] - def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray: - """ - Set all negative concentrations to zero. - - Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/ - #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox", - "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene". - - :param data: data array containing variables to check - :param minimum: minimum value, by default this should be 0 - - :return: corrected data - """ - chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", - "propane", "so2", "toluene"] - used_chem_vars = list(set(chem_vars) & set(self.statistics_per_var.keys())) - data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum) - return data - @staticmethod - def setup_transformation(transformation: Dict): + def setup_transformation(transformation: statistics.TransformationClass): """ Set up transformation by extracting all relevant information. - Extract all information from transformation dictionary. Possible keys are method, mean, std, min, max. - * If a transformation should be applied on base of existing values, these need to be provided in the respective - keys "mean" and "std" (again only if required for given method). - - :param transformation: the transformation dictionary as described above. - - :return: updated transformation dictionary - - ## Transformation - - There are two different approaches (called scopes) to transform the data: - 1) `station`: transform data for each station independently (somehow like batch normalisation) - 1) `data`: transform all data of each station with shared metrics - - Transformation must be set by the `transformation` attribute. If `transformation = None` is given to `ExperimentSetup`, - data is not transformed at all. For all other setups, use the following dictionary structure to specify the - transformation. - ``` - transformation = {"scope": <...>, - "method": <...>, - "mean": <...>, - "std": <...>} - ExperimentSetup(..., transformation=transformation, ...) - ``` - - ### scopes - - **station**: mean and std are not used - - **data**: either provide already calculated values for mean and std (if required by transformation method), or choose - from different calculation schemes, explained in the mean and std section. 
- - ### supported transformation methods - Currently supported methods are: - * standardise (default, if method is not given) - * centre - - ### mean and std - `"mean"="accurate"`: calculate the accurate values of mean and std (depending on method) by using all data. Although, - this method is accurate, it may take some time for the calculation. Furthermore, this could potentially lead to memory - issue (not explored yet, but could appear for a very big amount of data) - - `"mean"="estimate"`: estimate mean and std (depending on method). For each station, mean and std are calculated and - afterwards aggregated using the mean value over all station-wise metrics. This method is less accurate, especially - regarding the std calculation but therefore much faster. - - We recommend to use the later method *estimate* because of following reasons: - * much faster calculation - * real accuracy of mean and std is less important, because it is "just" a transformation / scaling - * accuracy of mean is almost as high as in the *accurate* case, because of - $\bar{x_{ij}} = \bar{\left(\bar{x_i}\right)_j}$. The only difference is, that in the *estimate* case, each mean is - equally weighted for each station independently of the actual data count of the station. - * accuracy of std is lower for *estimate* because of $\var{x_{ij}} \ne \bar{\left(\var{x_i}\right)_j}$, but still the mean of all - station-wise std is a decent estimate of the true std. - - `"mean"=<value, e.g. xr.DataArray>`: If mean and std are already calculated or shall be set manually, just add the - scaling values instead of the calculation method. For method *centre*, std can still be None, but is required for the - *standardise* method. **Important**: Format of given values **must** match internal data format of DataPreparation - class: `xr.DataArray` with `dims=["variables"]` and one value for each variable. - + * Either return new empty DataClass instances if given transformation arg is None, + * or return given object twice if transformation is a DataClass instance, + * or return the inputs and targets attributes if transformation is a TransformationClass instance (default + design behaviour) """ if transformation is None: - return - elif not isinstance(transformation, dict): - raise TypeError(f"`transformation' must be either `None' or dict like e.g. 
`{{'method': 'standardise'}}," - f" but transformation is of type {type(transformation)}.") - transformation = transformation.copy() - method = transformation.get("method", None) - mean = transformation.get("mean", None) - std = transformation.get("std", None) - max_val = transformation.get("max", None) - min_val = transformation.get("min", None) - - transformation["method"] = method - transformation["mean"] = mean - transformation["std"] = std - transformation["max"] = max_val - transformation["min"] = min_val - return transformation - - def load_data(self): - try: - self.read_data_from_disk() - except FileNotFoundError: - self.download_data() - self.load_data() - - def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean=None, - std=None, min_val=None, max_val=None) -> None: + return statistics.DataClass(), statistics.DataClass() + elif isinstance(transformation, statistics.DataClass): + return transformation, transformation + elif isinstance(transformation, statistics.TransformationClass): + return copy.deepcopy(transformation.inputs), copy.deepcopy(transformation.targets) + else: + raise NotImplementedError("Cannot handle this.") + + def transform(self, data_class, dim: Union[str, int] = 0, transform_method: str = 'standardise', + inverse: bool = False, mean=None, + std=None, min=None, max=None) -> None: """ Transform data according to given transformation settings. @@ -602,9 +512,9 @@ class DataHandlerSingleStation(AbstractDataHandler): calculated over the data in this class instance. :param std: Used for transformation (if required by 'method') based on external data. If 'None' the std is calculated over the data in this class instance. - :param min_val: Used for transformation (if required by 'method') based on external data. If 'None' min_val is + :param min: Used for transformation (if required by 'method') based on external data. If 'None' min_val is extracted from the data in this class instance. - :param max_val: Used for transformation (if required by 'method') based on external data. If 'None' max_val is + :param max: Used for transformation (if required by 'method') based on external data. If 'None' max_val is extracted from the data in this class instance. :return: xarray.DataArrays or pandas.DataFrames: @@ -614,36 +524,37 @@ class DataHandlerSingleStation(AbstractDataHandler): """ def f(data): - if method == 'standardise': + if transform_method == 'standardise': return statistics.standardise(data, dim) - elif method == 'centre': + elif transform_method == 'centre': return statistics.centre(data, dim) - elif method == 'normalise': + elif transform_method == 'normalise': # use min/max of data or given min/max raise NotImplementedError else: raise NotImplementedError def f_apply(data): - if method == "standardise": + if transform_method == "standardise": return mean, std, statistics.standardise_apply(data, mean, std) - elif method == "centre": + elif transform_method == "centre": return mean, None, statistics.centre_apply(data, mean) else: raise NotImplementedError if not inverse: - if self._transform_method is not None: - raise AssertionError(f"Transform method is already set. Therefore, data was already transformed with " - f"{self._transform_method}. Please perform inverse transformation of data first.") + if data_class._method is not None: + raise AssertionError(f"Internal _method is already set. Therefore, data was already transformed with " + f"{data_class._method}. 
Please perform inverse transformation of data first.") # apply transformation on local data instance (f) if mean is None, else apply by using mean (and std) from # external data. - self.mean, self.std, self.data = locals()["f" if mean is None else "f_apply"](self.data) + data_class.mean, data_class.std, data_class.data = locals()["f" if mean is None else "f_apply"]( + data_class.data) # set transform method to find correct method for inverse transformation. - self._transform_method = method + data_class._method = transform_method else: - self.inverse_transform() + self.inverse_transform(data_class) @staticmethod def check_inverse_transform_params(mean: data_or_none, std: data_or_none, method: str) -> None: @@ -665,7 +576,7 @@ class DataHandlerSingleStation(AbstractDataHandler): if len(msg) > 0: raise AttributeError(f"Inverse transform {method} can not be executed because following is None: {msg}") - def inverse_transform(self) -> None: + def inverse_transform(self, data_class) -> None: """ Perform inverse transformation. @@ -685,36 +596,26 @@ class DataHandlerSingleStation(AbstractDataHandler): else: raise NotImplementedError - if self._transform_method is None: + if data_class.transform_method is None: raise AssertionError("Inverse transformation method is not set. Data cannot be inverse transformed.") - self.check_inverse_transform_params(self.mean, self.std, self._transform_method) - self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method) - self._transform_method = None + self.check_inverse_transform_params(data_class.mean, data_class.std, data_class._method) + data_class.data, data_class.mean, data_class.std = f_inverse(data_class.data, data_class.mean, data_class.std, + data_class._method) + data_class.transform_method = None # update X and Y self.make_samples() - def get_transformation_information(self, variable: str = None) -> Tuple[data_or_none, data_or_none, str]: + def get_transformation_targets(self) -> Tuple[data_or_none, data_or_none, str]: """ Extract transformation statistics and method. - Get mean and standard deviation for given variable and the transformation method if set. If a transformation + Get mean and standard deviation for target values and the transformation method if set. If a transformation depends only on particular statistics (e.g. only mean is required for centering), the remaining statistics are returned with None as fill value. - :param variable: Variable for which the information on transformation is requested. 
- :return: mean, standard deviation and transformation method """ - variable = self.target_var if variable is None else variable - try: - mean = self.mean.sel({'variables': variable}).values - except AttributeError: - mean = None - try: - std = self.std.sel({'variables': variable}).values - except AttributeError: - std = None - return mean, std, self._transform_method + return self.target_data.mean, self.target_data.std, self.target_data.transform_method if __name__ == "__main__": @@ -727,7 +628,6 @@ if __name__ == "__main__": time_dim='datetime', window_history_size=7, window_lead_time=3, interpolation_limit=0 ) # transformation={'method': 'standardise'}) - # sp.set_transformation({'method': 'standardise', 'mean': sp.mean+2, 'std': sp.std+1}) sp2 = DataHandlerSingleStation(data_path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122', statistics_per_var=statistics_per_var, station_type='background', network='UBA', sampling='daily', target_dim='variables', target_var='o3', diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 47f63a3e7bcbebd131c2a0da47d2e0833b02efed..584151e36fd0c9621d089e88b8ad61cffa0c5925 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -4,6 +4,7 @@ __date__ = '2020-09-21' import copy import inspect +import gc import logging import os import pickle @@ -15,7 +16,6 @@ import numpy as np import xarray as xr from mlair.data_handler.abstract_data_handler import AbstractDataHandler -from mlair.data_handler.station_preparation import DataHandlerSingleStation from mlair.helpers import remove_items, to_list from mlair.helpers.join import EmptyQueryResult @@ -25,11 +25,14 @@ num_or_list = Union[number, List[number]] class DefaultDataHandler(AbstractDataHandler): + from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler + from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler_transformation - _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) + _requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"]) - def __init__(self, id_class: DataHandlerSingleStation, data_path: str, min_length: int = 0, - extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None): + def __init__(self, id_class: data_handler, experiment_path: str, min_length: int = 0, + extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None, + store_processed_data=True): super().__init__() self.id_class = id_class self.interpolation_dim = "datetime" @@ -39,16 +42,16 @@ class DefaultDataHandler(AbstractDataHandler): self._X_extreme = None self._Y_extreme = None _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self)) - self._save_file = os.path.join(data_path, f"data_preparation_{_name_affix}.pickle") + self._save_file = os.path.join(experiment_path, "data", f"{_name_affix}.pickle") self._collection = self._create_collection() self.harmonise_X() self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolation_dim) - self._store(fresh_store=True) + self._store(fresh_store=True, store_processed_data=store_processed_data) @classmethod def build(cls, station: str, **kwargs): sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs} - sp = DataHandlerSingleStation(station, 
**sp_keys) + sp = cls.data_handler(station, **sp_keys) dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs} return cls(sp, **dp_args) @@ -61,6 +64,7 @@ class DefaultDataHandler(AbstractDataHandler): def _reset_data(self): self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None + gc.collect() def _cleanup(self): directory = os.path.dirname(self._save_file) @@ -69,13 +73,14 @@ class DefaultDataHandler(AbstractDataHandler): if os.path.exists(self._save_file): shutil.rmtree(self._save_file, ignore_errors=True) - def _store(self, fresh_store=False): - self._cleanup() if fresh_store is True else None - data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme} - with open(self._save_file, "wb") as f: - pickle.dump(data, f) - logging.debug(f"save pickle data to {self._save_file}") - self._reset_data() + def _store(self, fresh_store=False, store_processed_data=True): + if store_processed_data is True: + self._cleanup() if fresh_store is True else None + data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme} + with open(self._save_file, "wb") as f: + pickle.dump(data, f) + logging.debug(f"save pickle data to {self._save_file}") + self._reset_data() def _load(self): try: @@ -140,7 +145,7 @@ return self.id_class.observation.copy().squeeze() def get_transformation_Y(self): - return self.id_class.get_transformation_information() + return self.id_class.get_transformation_targets() def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False, timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"): @@ -212,27 +217,55 @@ @classmethod def transformation(cls, set_stations, **kwargs): + """ + ### supported transformation methods + + Currently supported methods are: + + * standardise (default, if method is not given) + * centre + + ### mean and std estimation + + Mean and std (depending on method) are estimated. For each station, mean and std are calculated and afterwards + aggregated using the mean value over all station-wise metrics. This method is not exactly accurate, especially + regarding the std calculation, but therefore much faster. Furthermore, it is a weighted mean weighted by the + time series length / number of data itself - a longer time series has more influence on the transformation + settings than a short time series. The estimation of the std is less accurate, because the unweighted mean of + all stds is not equal to the true std, but still the mean of all station-wise stds is a decent estimate. Finally, + the real accuracy of mean and std is less important, because it is "just" a transformation / scaling. + + ### mean and std given + + If mean and std are not None, the default data handler expects these parameters to match the data and applies + these values to the data. Make sure that all dimensions and/or coordinates are in agreement.
+ """ + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs} - transformation_dict = sp_keys.pop("transformation") - if transformation_dict is None: + transformation_class = sp_keys.get("transformation", None) + if transformation_class is None: return - scope = transformation_dict.pop("scope") - method = transformation_dict.pop("method") - if transformation_dict.pop("mean", None) is not None: + + transformation_inputs = transformation_class.inputs + if transformation_inputs.mean is not None: return - mean, std = None, None + means = [None, None] + stds = [None, None] for station in set_stations: try: - sp = DataHandlerSingleStation(station, transformation={"method": method}, **sp_keys) - mean = sp.mean.copy(deep=True) if mean is None else mean.combine_first(sp.mean) - std = sp.std.copy(deep=True) if std is None else std.combine_first(sp.std) + sp = cls.data_handler_transformation(station, **sp_keys) + for i, data in enumerate([sp.input_data, sp.target_data]): + means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean) + stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std) except (AttributeError, EmptyQueryResult): continue - if mean is None: + if means[0] is None: return None - mean_estimated = mean.mean("Stations") - std_estimated = std.mean("Stations") - return {"scope": scope, "method": method, "mean": mean_estimated, "std": std_estimated} + transformation_class.inputs.mean = means[0].mean("Stations") + transformation_class.inputs.std = stds[0].mean("Stations") + transformation_class.targets.mean = means[1].mean("Stations") + transformation_class.targets.std = stds[1].mean("Stations") + return transformation_class def get_coordinates(self): return self.id_class.get_coordinates() \ No newline at end of file diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index b12d9028747aa677802c4a99e35852b514128e4c..3ecf1f6213bf39d2e3571a1b451173b981a3dadf 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -32,16 +32,21 @@ def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray: :return: combined xarray """ - xarray = None - for k, v in d.items(): - if xarray is None: - xarray = v - xarray.coords[coordinate_name] = k - else: - tmp_xarray = v - tmp_xarray.coords[coordinate_name] = k - xarray = xr.concat([xarray, tmp_xarray], coordinate_name) - return xarray + if len(d.keys()) == 1: + k = list(d.keys()) + xarray: xr.DataArray = d[k[0]] + return xarray.expand_dims(dim={coordinate_name: k}, axis=0) + else: + xarray = None + for k, v in d.items(): + if xarray is None: + xarray = v + xarray.coords[coordinate_name] = k + else: + tmp_xarray = v + tmp_xarray.coords[coordinate_name] = k + xarray = xr.concat([xarray, tmp_xarray], coordinate_name) + return xarray def float_round(number: float, decimals: int = 0, round_type: Callable = math.ceil) -> float: diff --git a/mlair/helpers/join.py b/mlair/helpers/join.py index a3c6876e3ea43ff4d03243430cf6cd791d62dec2..43a0176811b54fba2983c1dba108f4c7977f1431 100644 --- a/mlair/helpers/join.py +++ b/mlair/helpers/join.py @@ -23,7 +23,8 @@ class EmptyQueryResult(Exception): def download_join(station_name: Union[str, List[str]], stat_var: dict, station_type: str = None, - network_name: str = None, sampling: str = "daily") -> [pd.DataFrame, pd.DataFrame]: + network_name: str = None, sampling: str = "daily", data_origin: Dict = None) -> [pd.DataFrame, + pd.DataFrame]: """ Read data from JOIN/TOAR. 
@@ -32,6 +33,8 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t :param station_type: set the station type like "traffic" or "background", can be none :param network_name: set the measurement network like "UBA" or "AIRBASE", can be none :param sampling: sampling rate of the downloaded data, either set to daily or hourly (default daily) + :param data_origin: additional dictionary to specify the data origin as variable (key) and origin (value) pairs. Valid + origins are "REA" for reanalysis data and "" (empty string) for observational data. :returns: data frame with all variables and statistics and meta data frame with all meta information """ @@ -42,11 +45,11 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t join_url_base, headers = join_settings(sampling) # load series information - vars_dict = load_series_information(station_name, station_type, network_name, join_url_base, headers) + vars_dict = load_series_information(station_name, station_type, network_name, join_url_base, headers, data_origin) # correct stat_var values if data is not aggregated (hourly) if sampling == "hourly": - [stat_var.update({k: "values"}) for k in stat_var.keys()] + stat_var = {key: "values" for key in stat_var.keys()} # download all variables with given statistic data = None @@ -55,7 +58,7 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t for var in _lower_list(sorted(vars_dict.keys())): if var in stat_var.keys(): - logging.debug('load: {}'.format(var)) + logging.debug('load: {}'.format(var)) # ToDo start here for #206 # create data link opts = {'base': join_url_base, 'service': 'stats', 'id': vars_dict[var], 'statistics': stat_var[var], @@ -123,7 +126,7 @@ def get_data(opts: Dict, headers: Dict) -> Union[Dict, List]: def load_series_information(station_name: List[str], station_type: str_or_none, network_name: str_or_none, - join_url_base: str, headers: Dict) -> Dict: + join_url_base: str, headers: Dict, data_origin: Dict = None) -> Dict: """ List all series ids that are available for given station id and network name. @@ -132,14 +135,36 @@ def load_series_information(station_name: List[str], station_type: str_or_none, :param network_name: measurement network of the station like "UBA" or "AIRBASE" :param join_url_base: base url name to download data from :param headers: additional headers information like authorization, can be empty + :param data_origin: additional information to select a distinct series e.g. from reanalysis (REA) or from observation + ("", empty string). This dictionary should contain a key for each variable and the origin as value :return: all available series for requested station stored in a dictionary with parameter name (variable) as key and the series id as value. 
""" - opts = {"base": join_url_base, "service": "series", "station_id": station_name[0], "station_type": station_type, - "network_name": network_name} + opts = {"base": join_url_base, "service": "search", "station_id": station_name[0], "station_type": station_type, + "network_name": network_name, "as_dict": "true", + "columns": "id,network_name,station_id,parameter_name,parameter_label,parameter_attribute"} station_vars = get_data(opts, headers) - vars_dict = {item[3].lower(): item[0] for item in station_vars} - return vars_dict + logging.debug(f"{station_name}: {station_vars}") # ToDo start here for #206 + return _select_distinct_series(station_vars, data_origin) + + +def _select_distinct_series(vars: List[Dict], data_origin: Dict = None): + """ + Select distinct series ids for all variables. Also check if a parameter is from REA or not. + """ + if data_origin is None: + data_origin = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA", "press": "REA", "relhum": "REA", + "temp": "REA", "totprecip": "REA", "u": "REA", "v": "REA", + "no": "", "no2": "", "o3": "", "pm10": "", "so2": ""} + # ToDo: maybe press, wdir, wspeed from obs? or also temp, ... ? + selected = {} + for var in vars: + name = var["parameter_name"].lower() + var_attr = var["parameter_attribute"].lower() + attr = data_origin.get(name, "").lower() + if var_attr == attr: + selected[name] = var["id"] + return selected def _save_to_pandas(df: Union[pd.DataFrame, None], data: dict, stat: str, var: str) -> pd.DataFrame: diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py index 056f92bec25b8d5216988f4dacb8fcd1e5257ab5..3db6618a5e8ebd575d61bc261144ff47ccaf9b53 100644 --- a/mlair/helpers/statistics.py +++ b/mlair/helpers/statistics.py @@ -9,10 +9,36 @@ import numpy as np import xarray as xr import pandas as pd from typing import Union, Tuple, Dict +from matplotlib import pyplot as plt + +from mlair.helpers import to_list, remove_items Data = Union[xr.DataArray, pd.DataFrame] +class DataClass: + + def __init__(self, data=None, mean=None, std=None, max=None, min=None, transform_method=None): + self.data = data + self.mean = mean + self.std = std + self.max = max + self.min = min + self.transform_method = transform_method + self._method = None + + def as_dict(self): + return remove_items(self.__dict__, "_method") + + +class TransformationClass: + + def __init__(self, inputs_mean=None, inputs_std=None, inputs_method=None, targets_mean=None, targets_std=None, + targets_method=None): + self.inputs = DataClass(mean=inputs_mean, std=inputs_std, transform_method=inputs_method) + self.targets = DataClass(mean=targets_mean, std=targets_std, transform_method=targets_method) + + def apply_inverse_transformation(data: Data, mean: Data, std: Data = None, method: str = "standardise") -> Data: """ Apply inverse transformation for given statistics. @@ -345,3 +371,168 @@ class SkillScores: monthly_mean[monthly_mean.index.dt.month == month, :] = mu[mu.month == month].values return monthly_mean + + +class KolmogorovZurbenkoBaseClass: + + def __init__(self, df, wl, itr, is_child=False, filter_dim="window"): + """ + It create the variables associate with the Kolmogorov-Zurbenko-filter. 
+ + Args: + df(pd.DataFrame, None): time series of a variable + wl(list of int): window length + itr(list of int): number of iteration + """ + self.df = df + self.filter_dim = filter_dim + self.wl = to_list(wl) + self.itr = to_list(itr) + if abs(len(self.wl) - len(self.itr)) > 0: + raise ValueError("Length of lists for wl and itr must agree!") + self._isChild = is_child + self.child = self.set_child() + self.type = type(self).__name__ + + def set_child(self): + if len(self.wl) > 1: + return KolmogorovZurbenkoBaseClass(None, self.wl[1:], self.itr[1:], True, self.filter_dim) + else: + return None + + def kz_filter(self, df, m, k): + pass + + def spectral_calc(self): + df_start = self.df + kz = self.kz_filter(df_start, self.wl[0], self.itr[0]) + filtered = self.subtract(df_start, kz) + # case I: no child avail -> return kz and remaining + if self.child is None: + return [kz, filtered] + # case II: has child -> return current kz and all child results + else: + self.child.df = filtered + kz_next = self.child.spectral_calc() + return [kz] + kz_next + + @staticmethod + def subtract(minuend, subtrahend): + try: # pandas implementation + return minuend.sub(subtrahend, axis=0) + except AttributeError: # general implementation + return minuend - subtrahend + + def run(self): + return self.spectral_calc() + + def transfer_function(self): + m = self.wl[0] + k = self.itr[0] + omega = np.linspace(0.00001, 0.15, 5000) + return omega, (np.sin(m * np.pi * omega) / (m * np.sin(np.pi * omega))) ** (2 * k) + + def omega_null(self, alpha=0.5): + a = np.sqrt(6) / np.pi + b = 1 / (2 * np.array(self.itr)) + c = 1 - alpha ** b + d = np.array(self.wl) ** 2 - alpha ** b + return a * np.sqrt(c / d) + + def period_null(self, alpha=0.5): + return 1. / self.omega_null(alpha) + + def period_null_days(self, alpha=0.5): + return self.period_null(alpha) / 24. + + def plot_transfer_function(self, fig=None, name=None): + if fig is None: + fig = plt.figure() + omega, transfer_function = self.transfer_function() + if self.child is not None: + transfer_function_child = self.child.plot_transfer_function(fig) + else: + transfer_function_child = transfer_function * 0 + plt.semilogx(omega, transfer_function - transfer_function_child, + label="m={:3.0f}, k={:3.0f}, T={:6.2f}d".format(self.wl[0], + self.itr[0], + self.period_null_days())) + plt.axvline(x=self.omega_null()) + if not self._isChild: + locs, labels = plt.xticks() + plt.xticks(locs, np.round(1. / (locs * 24), 1)) + plt.xlim([0.00001, 0.15]) + plt.legend() + if name is None: + plt.show() + else: + plt.savefig(name) + else: + return transfer_function + + +class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass): + + def __init__(self, df, wl: Union[list, int], itr: Union[list, int], is_child=False, filter_dim="window", + method="mean", percentile=0.5): + """ + It create the variables associate with the KolmogorovZurbenkoFilterMovingWindow class. + + Args: + df(pd.DataFrame, xr.DataArray): time series of a variable + wl: window length + itr: number of iteration + """ + self.valid_methods = ["mean", "percentile", "median", "max", "min"] + if method not in self.valid_methods: + raise ValueError("Method '{}' is not supported. Please select from [{}].".format( + method, ", ".join(self.valid_methods))) + else: + self.method = method + if percentile > 1 or percentile < 0: + raise ValueError("Percentile must be in range [0, 1]. 
Given was {}!".format(percentile)) + else: + self.percentile = percentile + super().__init__(df, wl, itr, is_child, filter_dim) + + def set_child(self): + if len(self.wl) > 1: + return KolmogorovZurbenkoFilterMovingWindow(self.df, self.wl[1:], self.itr[1:], is_child=True, + filter_dim=self.filter_dim, method=self.method, + percentile=self.percentile) + else: + return None + + def kz_filter(self, df, wl, itr): + """ + It passes the low frequency time series. + + Args: + wl(int): a window length + itr(int): a number of iteration + """ + df_itr = df.__deepcopy__() + try: + kwargs = {"min_periods": 1, + "center": True, + self.filter_dim: wl} + iter_vars = df_itr.coords["variables"].values + for var in iter_vars: + df_itr_var = df_itr.sel(variables=[var]).chunk() + for _ in np.arange(0, itr): + rolling = df_itr_var.rolling(**kwargs) + if self.method == "median": + df_mv_avg_tmp = rolling.median() + elif self.method == "percentile": + df_mv_avg_tmp = rolling.quantile(self.percentile) + elif self.method == "max": + df_mv_avg_tmp = rolling.max() + elif self.method == "min": + df_mv_avg_tmp = rolling.min() + else: + df_mv_avg_tmp = rolling.mean() + df_itr_var = df_mv_avg_tmp.compute() + df_itr = df_itr.drop_sel(variables=var).combine_first(df_itr_var) + return df_itr + except ValueError: + raise ValueError diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py index c9cc13bd8108e43b5a9f03682942eacdf5a55f04..a603b466e4dab0dc30b6b6b22d10b6c27ee59767 100644 --- a/mlair/model_modules/model_class.py +++ b/mlair/model_modules/model_class.py @@ -396,8 +396,66 @@ class MyLittleModel(AbstractModelClass): def set_compile_options(self): self.initial_lr = 1e-2 self.optimizer = keras.optimizers.adam(lr=self.initial_lr) - self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, - epochs_drop=10) + # self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, + # epochs_drop=10) + self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]} + + +class MyLittleModelHourly(AbstractModelClass): + """ + A customised model with a 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the + output layer depending on the window_lead_time parameter. Dropout is used between the Convolution and the first + Dense layer. + """ + + def __init__(self, input_shape: list, output_shape: list): + """ + Sets model and loss depending on the given arguments. + + :param shape_inputs: list of input shapes (expect len=1 with shape=(window_hist, station, variables)) + :param shape_outputs: list of output shapes (expect len=1 with shape=(window_forecast)) + """ + + assert len(input_shape) == 1 + assert len(output_shape) == 1 + super().__init__(input_shape[0], output_shape[0]) + + # settings + self.dropout_rate = 0.1 + self.regularizer = keras.regularizers.l2(0.001) + self.activation = keras.layers.PReLU + + # apply to model + self.set_model() + self.set_compile_options() + self.set_custom_objects(loss=self.compile_options['loss']) + + def set_model(self): + """ + Build the model. 
+ """ + + # add 1 to window_size to include current time step t0 + x_input = keras.layers.Input(shape=self._input_shape) + x_in = keras.layers.Conv2D(128, (1, 1), padding='same', name='{}_Conv_1x1_128'.format("major"))(x_input) + x_in = self.activation()(x_in) + x_in = keras.layers.Conv2D(64, (1, 1), padding='same', name='{}_Conv_1x1_64'.format("major"))(x_in) + x_in = self.activation()(x_in) + x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1_32'.format("major"))(x_in) + x_in = self.activation()(x_in) + x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in) + x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in) + x_in = keras.layers.Dense(128, name='{}_Dense_128'.format("major"))(x_in) + x_in = self.activation()(x_in) + x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in) + x_in = self.activation()(x_in) + x_in = keras.layers.Dense(self._output_shape, name='{}_Dense'.format("major"))(x_in) + out_main = self.activation()(x_in) + self.model = keras.Model(inputs=x_input, outputs=[out_main]) + + def set_compile_options(self): + self.initial_lr = 1e-2 + self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9) self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]} diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 675e5ade587011a9ac835e9afb45f89173bc7653..f0b6baeb0b56126ccccb80c9da993fb406428d93 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -25,6 +25,11 @@ from mlair.helpers import TimeTrackingWrapper logging.getLogger('matplotlib').setLevel(logging.WARNING) +# import matplotlib +# matplotlib.use("TkAgg") +# import matplotlib.pyplot as plt + + class AbstractPlotClass: """ Abstract class for all plotting routines to unify plot workflow. @@ -72,6 +77,9 @@ class AbstractPlotClass: def __init__(self, plot_folder, plot_name, resolution=500): """Set up plot folder and name, and plot resolution (default 500dpi).""" + plot_folder = os.path.abspath(plot_folder) + if not os.path.exists(plot_folder): + os.makedirs(plot_folder) self.plot_folder = plot_folder self.plot_name = plot_name self.resolution = resolution @@ -82,7 +90,7 @@ class AbstractPlotClass: def _save(self, **kwargs): """Store plot locally. Name of and path to plot need to be set on initialisation.""" - plot_name = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf") + plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf") logging.debug(f"... 
save plot to {plot_name}") plt.savefig(plot_name, dpi=self.resolution, **kwargs) plt.close('all') @@ -137,15 +145,16 @@ class PlotMonthlySummary(AbstractPlotClass): data_cnn = data.sel(type="CNN").squeeze() if len(data_cnn.shape) > 1: - data_cnn.coords["ahead"].values = [f"{days}d" for days in data_cnn.coords["ahead"].values] + data_cnn = data_cnn.assign_coords(ahead=[f"{days}d" for days in data_cnn.coords["ahead"].values]) data_obs = data.sel(type="obs", ahead=1).squeeze() data_obs.coords["ahead"] = "obs" data_concat = xr.concat([data_obs, data_cnn], dim="ahead") - data_concat = data_concat.drop("type") + data_concat = data_concat.drop_vars("type") - data_concat.index.values = data_concat.index.values.astype("datetime64[M]").astype(int) % 12 + 1 + new_index = data_concat.index.values.astype("datetime64[M]").astype(int) % 12 + 1 + data_concat = data_concat.assign_coords(index=new_index) data_concat = data_concat.clip(min=0) forecasts = xr.concat([forecasts, data_concat], 'index') if forecasts is not None else data_concat @@ -829,8 +838,8 @@ class PlotTimeSeries: factor = 1 if self._sampling == "h": factor = 2 - f, ax = plt.subplots((end - start + 1) * factor, sharey=True, figsize=(50, 30)) - return f, ax, factor + f, ax = plt.subplots((end - start + 1) * factor, sharey=True, figsize=(50, 30), squeeze=False) + return f, ax[:, 0], factor def _plot_ahead(self, ax, data): color = sns.color_palette("Blues_d", self._window_lead_time).as_hex() @@ -902,6 +911,7 @@ class PlotAvailability(AbstractPlotClass): # create standard Gantt plot for all stations (currently in single pdf file with single page) super().__init__(plot_folder, "data_availability") self.dim = time_dimension + self.linewidth = None self.sampling = self._get_sampling(sampling) plot_dict = self._prepare_data(generators) lgd = self._plot(plot_dict) @@ -917,11 +927,11 @@ class PlotAvailability(AbstractPlotClass): lgd = self._plot(plot_dict_summary) self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") - @staticmethod - def _get_sampling(sampling): + def _get_sampling(self, sampling): if sampling == "daily": return "D" elif sampling == "hourly": + self.linewidth = 0.001 return "h" def _prepare_data(self, generators: Dict[str, DataCollection]): @@ -982,7 +992,7 @@ class PlotAvailability(AbstractPlotClass): plt_data = d.get(subset) if plt_data is None: continue - ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white") + ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth) yticklabels.append(station) ax.set_ylim([height, number_of_stations + 1]) @@ -993,10 +1003,31 @@ class PlotAvailability(AbstractPlotClass): return lgd +@TimeTrackingWrapper +class PlotSeparationOfScales(AbstractPlotClass): + + def __init__(self, collection: DataCollection, plot_folder: str = "."): + """Initialise.""" + # create standard Gantt plot for all stations (currently in single pdf file with single page) + plot_folder = os.path.join(plot_folder, "separation_of_scales") + super().__init__(plot_folder, "separation_of_scales") + self._plot(collection) + + def _plot(self, collection: DataCollection): + orig_plot_name = self.plot_name + for dh in collection: + data = dh.get_X(as_numpy=False)[0] + station = dh.id_class.station[0] + data = data.sel(Stations=station) + # plt.subplots() + data.plot(x="datetime", y="window", col="filter", row="variables", robust=True) + self.plot_name = f"{orig_plot_name}_{station}" + self._save() + + if __name__ == "__main__": stations = ['DEBW107', 'DEBY081', 'DEBW013', 
'DEBW076', 'DEBW087'] path = "../../testrun_network/forecasts" plt_path = "../../" con_quan_cls = PlotConditionalQuantiles(stations, path, plt_path) - diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index f5d7d80f01de9e04a4e1e2d41901b402a17816df..9a9253eda522c39f348dd96700ed38730e87f9a8 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -4,7 +4,7 @@ __date__ = '2019-11-15' import argparse import logging import os -from typing import Union, Dict, Any, List +from typing import Union, Dict, Any, List, Callable from mlair.configuration import path_config from mlair import helpers @@ -17,7 +17,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_TRAIN_START, DEFAULT_TRAIN_END, DEFAULT_TRAIN_MIN_LENGTH, DEFAULT_VAL_START, DEFAULT_VAL_END, \ DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \ - DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST + DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN from mlair.data_handler import DefaultDataHandler from mlair.run_modules.run_environment import RunEnvironment from mlair.model_modules.model_class import MyLittleModel as VanillaModel @@ -184,7 +184,7 @@ class ExperimentSetup(RunEnvironment): training) set for a second time to the sample. If multiple values are given, a sample is added for each exceedance once. E.g. a sample with `value=2.5` occurs twice in the training set for given `extreme_values=[2, 3]`, whereas a sample with `value=5` occurs three times in the training set. By default, - upsampling of extreme values is disabled (`None`). Upsamling can be modified to manifold only values that are + upsampling of extreme values is disabled (`None`). Upsampling can be modified to manifold only values that are actually larger than given values from ``extreme_values`` (apply only on right side of distribution) by using ``extremes_on_right_tail_only``. This can be useful for positively skewed variables. :param extremes_on_right_tail_only: applies only if ``extreme_values`` are given. 
If ``extremes_on_right_tail_only`` @@ -214,20 +214,25 @@ class ExperimentSetup(RunEnvironment): dimensions=None, time_dim=None, interpolation_method=None, - interpolation_limit=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, - test_end=None, use_all_stations_on_all_data_sets=None, train_model: bool = None, fraction_of_train: float = None, - experiment_path=None, plot_path: str = None, forecast_path: str = None, overwrite_local_data = None, sampling: str = "daily", - create_new_model = None, bootstrap_path=None, permute_data_on_training = None, transformation=None, + interpolation_limit=None, train_start=None, train_end=None, val_start=None, val_end=None, + test_start=None, + test_end=None, use_all_stations_on_all_data_sets=None, train_model: bool = None, + fraction_of_train: float = None, + experiment_path=None, plot_path: str = None, forecast_path: str = None, overwrite_local_data=None, + sampling: str = None, + create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None, train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None, - extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None, + extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, + number_of_bootstraps=None, create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None, - hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, **kwargs): + hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, sampling_inputs=None, + sampling_outputs=None, data_origin: Dict = None, **kwargs): # create run framework super().__init__() # experiment setup, hyperparameters - self._set_param("data_path", path_config.prepare_host(data_path=data_path, sampling=sampling)) + self._set_param("data_path", path_config.prepare_host(data_path=data_path)) self._set_param("hostname", path_config.get_host()) self._set_param("hpc_hosts", hpc_hosts, default=DEFAULT_HPC_HOST_LIST + DEFAULT_HPC_LOGIN_LIST) self._set_param("login_nodes", login_nodes, default=DEFAULT_HPC_LOGIN_LIST) @@ -235,7 +240,7 @@ class ExperimentSetup(RunEnvironment): if self.data_store.get("create_new_model"): train_model = True data_path = self.data_store.get("data_path") - bootstrap_path = path_config.set_bootstrap_path(bootstrap_path, data_path, sampling) + bootstrap_path = path_config.set_bootstrap_path(bootstrap_path, data_path) self._set_param("bootstrap_path", bootstrap_path) self._set_param("train_model", train_model, default=DEFAULT_TRAIN_MODEL) self._set_param("fraction_of_training", fraction_of_train, default=DEFAULT_FRACTION_OF_TRAINING) @@ -250,6 +255,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("epochs", epochs, default=DEFAULT_EPOCHS) # set experiment name + sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING) # always related to output sampling experiment_name = path_config.set_experiment_name(name=experiment_date, sampling=sampling) experiment_path = path_config.set_experiment_path(name=experiment_name, path=experiment_path) self._set_param("experiment_name", experiment_name) @@ -279,15 +285,16 @@ class ExperimentSetup(RunEnvironment): path_config.check_path_and_create(self.data_store.get("logging_path")) # setup for data - self._set_param("stations", stations, default=DEFAULT_STATIONS) + self._set_param("stations", stations, default=DEFAULT_STATIONS, apply=helpers.to_list) 
self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT) self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys())) + self._set_param("data_origin", data_origin, default=DEFAULT_DATA_ORIGIN) self._set_param("start", start, default=DEFAULT_START) self._set_param("end", end, default=DEFAULT_END) self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE) self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA, scope="preprocessing") - self._set_param("sampling", sampling) + self._set_param("sampling_inputs", sampling_inputs, default=sampling) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) self._set_param("transformation", None, scope="preprocessing") self._set_param("data_handler", data_handler, default=DefaultDataHandler) @@ -355,12 +362,17 @@ class ExperimentSetup(RunEnvironment): raise KeyError(f"Given argument {k} with value {v} cannot be set for this experiment due to a " f"conflict with an existing entry with same naming: {k}={self.data_store.get(k)}") - def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: - """Set given parameter and log in debug.""" + def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general", + apply: Callable = None) -> Any: + """Set given parameter and log in debug. Use apply parameter to adjust the stored value (e.g. to transform value + to a list use apply=helpers.to_list).""" if value is None and default is not None: value = default + if apply is not None: + value = apply(value) self.data_store.set(param, value, scope) logging.debug(f"set experiment attribute: {param}({scope})={value}") + return value def _compare_variables_and_statistics(self): """ diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index de43f30d929db1de12681d92c9c585df5c07944e..3dc91cbd54094f116f0d959fb9c845751e998464 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -19,7 +19,8 @@ from mlair.helpers import TimeTracking, statistics, extract_value from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules.model_class import AbstractModelClass from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \ - PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles + PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles, \ + PlotSeparationOfScales from mlair.run_modules.run_environment import RunEnvironment @@ -262,7 +263,10 @@ class PostProcessing(RunEnvironment): plot_list = self.data_store.get("plot_list", "postprocessing") time_dimension = self.data_store.get("time_dim") - if self.bootstrap_skill_scores is not None and "PlotBootstrapSkillScore" in plot_list: + if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and ("PlotSeparationOfScales" in plot_list): + PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path) + + if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list): PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="CNN") if "PlotConditionalQuantiles" in plot_list: @@ -399,10 +403,10 @@ class PostProcessing(RunEnvironment): :return: filled 
data array with ols predictions """ tmp_ols = self.ols_model.predict(input_data) - if not normalised: - tmp_ols = statistics.apply_inverse_transformation(tmp_ols, mean, std, transformation_method) target_shape = ols_prediction.values.shape ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols + if not normalised: + ols_prediction = statistics.apply_inverse_transformation(ols_prediction, mean, std, transformation_method) return ols_prediction def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, mean: xr.DataArray, @@ -423,9 +427,10 @@ class PostProcessing(RunEnvironment): :return: filled data array with persistence predictions """ tmp_persi = data.copy() - if not normalised: - tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) persistence_prediction.values = np.tile(tmp_persi, (self.window_lead_time, 1)).T + if not normalised: + persistence_prediction = statistics.apply_inverse_transformation(persistence_prediction, mean, std, + transformation_method) return persistence_prediction def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, mean: xr.DataArray, @@ -447,8 +452,6 @@ class PostProcessing(RunEnvironment): :return: filled data array with nn predictions """ tmp_nn = self.model.predict(input_data) - if not normalised: - tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method) if isinstance(tmp_nn, list): nn_prediction.values = tmp_nn[-1] elif tmp_nn.ndim == 3: @@ -457,6 +460,8 @@ class PostProcessing(RunEnvironment): nn_prediction.values = tmp_nn else: raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.") + if not normalised: + nn_prediction = statistics.apply_inverse_transformation(nn_prediction, mean, std, transformation_method) return nn_prediction @staticmethod @@ -528,7 +533,7 @@ class PostProcessing(RunEnvironment): # external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"]) external_data = self._create_observation(observation, None, mean, std, transformation_method, normalised=False) return external_data.rename({external_data.dims[0]: 'index'}) - except IndexError: + except (IndexError, KeyError): return None def calculate_skill_scores(self) -> Tuple[Dict, Dict]: diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index ed972896e7a39b0b56df23dbc8a8d1ae64fb4183..4cee4a9744f33c86e8802aad27125cf0e0b30f3a 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -56,7 +56,8 @@ class PreProcessing(RunEnvironment): def _run(self): stations = self.data_store.get("stations") data_handler = self.data_store.get("data_handler") - _, valid_stations = self.validate_station(data_handler, stations, "preprocessing", overwrite_local_data=True) + _, valid_stations = self.validate_station(data_handler, stations, + "preprocessing") # , store_processed_data=False) if len(valid_stations) == 0: raise ValueError("Couldn't find any valid data according to given parameters. 
Abort experiment run.") self.data_store.set("stations", valid_stations) @@ -192,26 +193,21 @@ class PreProcessing(RunEnvironment): self.data_store.set("stations", valid_stations, scope=set_name) self.data_store.set("data_collection", collection, scope=set_name) - def validate_station(self, data_handler: AbstractDataHandler, set_stations, set_name=None, overwrite_local_data=False): + def validate_station(self, data_handler: AbstractDataHandler, set_stations, set_name=None, + store_processed_data=True): """ Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the loading time are logged in debug mode. - :param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`, - `variables`, `time_dim`, `target_dim`, `target_var`). - :param kwargs: positional parameters for the DataGenerator class (e.g. `start`, `interpolation_method`, - `window_lead_time`). - :param all_stations: All stations to check. - :param name: name to display in the logging info message - :return: Corrected list containing only valid station IDs. """ t_outer = TimeTracking() logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}") # calculate transformation using train data if set_name == "train": + logging.info("setup transformation using train data exclusively") self.transformation(data_handler, set_stations) # start station check collection = DataCollection() @@ -219,7 +215,8 @@ class PreProcessing(RunEnvironment): kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name) for station in set_stations: try: - dp = data_handler.build(station, name_affix=set_name, **kwargs) + dp = data_handler.build(station, name_affix=set_name, store_processed_data=store_processed_data, + **kwargs) collection.add(dp) valid_stations.append(station) except (AttributeError, EmptyQueryResult): @@ -234,6 +231,3 @@ class PreProcessing(RunEnvironment): transformation_dict = data_handler.transformation(stations, **kwargs) if transformation_dict is not None: self.data_store.set("transformation", transformation_dict) - - - diff --git a/mlair/workflows/abstract_workflow.py b/mlair/workflows/abstract_workflow.py index bced90bbe848cc9ebe36c583d05b62549f0ae80b..3a627d9f72a5c1c97c35b464af1b0944bc397ea5 100644 --- a/mlair/workflows/abstract_workflow.py +++ b/mlair/workflows/abstract_workflow.py @@ -16,15 +16,17 @@ class Workflow: execution but not the dependencies (workflow would probably fail in this case).""" def __init__(self, name=None): - self._registry = OrderedDict() + self._registry_kwargs = {} + self._registry = [] self._name = name if name is not None else self.__class__.__name__ def add(self, stage, **kwargs): """Add a new stage with optional kwargs.""" - self._registry[stage] = kwargs + self._registry.append(stage) + self._registry_kwargs[len(self._registry) - 1] = kwargs def run(self): """Run workflow embedded in a run environment and according to the stage's ordering.""" with RunEnvironment(name=self._name): - for stage, kwargs in self._registry.items(): - stage(**kwargs) + for pos, stage in enumerate(self._registry): + stage(**self._registry_kwargs[pos]) diff --git a/mlair/workflows/default_workflow.py b/mlair/workflows/default_workflow.py index 85d6726b70b699968933bf9af7580895490b8a6d..4d113190fdc90ec852d7db2b33459b9162867a24 100644 --- a/mlair/workflows/default_workflow.py +++ b/mlair/workflows/default_workflow.py 
@@ -14,28 +14,29 @@ class DefaultWorkflow(Workflow): the mentioned ordering.""" def __init__(self, stations=None, - train_model=None, create_new_model=None, - window_history_size=None, - experiment_date="testrun", - variables=None, statistics_per_var=None, - start=None, end=None, - target_var=None, target_dim=None, - window_lead_time=None, - dimensions=None, - interpolation_method=None, time_dim=None, limit_nan_fill=None, - train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, - use_all_stations_on_all_data_sets=None, fraction_of_train=None, - experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None, overwrite_local_data=None, - sampling=None, - permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None, - transformation=None, - train_min_length=None, val_min_length=None, test_min_length=None, - evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None, - plot_list=None, - model=None, - batch_size=None, - epochs=None, - data_preparation=None, + train_model=None, create_new_model=None, + window_history_size=None, + experiment_date="testrun", + variables=None, statistics_per_var=None, + start=None, end=None, + target_var=None, target_dim=None, + window_lead_time=None, + dimensions=None, + interpolation_method=None, time_dim=None, limit_nan_fill=None, + train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, + use_all_stations_on_all_data_sets=None, fraction_of_train=None, + experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None, + overwrite_local_data=None, + sampling=None, + permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None, + transformation=None, + train_min_length=None, val_min_length=None, test_min_length=None, + evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None, + plot_list=None, + model=None, + batch_size=None, + epochs=None, + data_handler=None, **kwargs): super().__init__() @@ -58,28 +59,29 @@ class DefaultWorkflowHPC(Workflow): Training and PostProcessing in exact the mentioned ordering.""" def __init__(self, stations=None, - train_model=None, create_new_model=None, - window_history_size=None, - experiment_date="testrun", - variables=None, statistics_per_var=None, - start=None, end=None, - target_var=None, target_dim=None, - window_lead_time=None, - dimensions=None, - interpolation_method=None, time_dim=None, limit_nan_fill=None, - train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, - use_all_stations_on_all_data_sets=None, fraction_of_train=None, - experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None, overwrite_local_data=None, - sampling=None, - permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None, - transformation=None, - train_min_length=None, val_min_length=None, test_min_length=None, - evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None, - plot_list=None, - model=None, - batch_size=None, - epochs=None, - data_preparation=None, **kwargs): + train_model=None, create_new_model=None, + window_history_size=None, + experiment_date="testrun", + variables=None, statistics_per_var=None, + start=None, end=None, + target_var=None, target_dim=None, + window_lead_time=None, + dimensions=None, + interpolation_method=None, time_dim=None, limit_nan_fill=None, + train_start=None, train_end=None, val_start=None, val_end=None, 
test_start=None, test_end=None, + use_all_stations_on_all_data_sets=None, fraction_of_train=None, + experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None, + overwrite_local_data=None, + sampling=None, + permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None, + transformation=None, + train_min_length=None, val_min_length=None, test_min_length=None, + evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None, + plot_list=None, + model=None, + batch_size=None, + epochs=None, + data_handler=None, **kwargs): super().__init__() # extract all given kwargs arguments diff --git a/requirements.txt b/requirements.txt index be76eab5b74797b039682a292ae8890488c058ec..371bb776e581925e507bf06c60bd866061c52791 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,7 +61,7 @@ typing-extensions urllib3==1.25.8 wcwidth==0.1.8 Werkzeug==1.0.0 -xarray==0.15.0 +xarray==0.16.1 zipp==3.1.0 setuptools~=49.6.0 diff --git a/run_hourly_kz.py b/run_hourly_kz.py new file mode 100644 index 0000000000000000000000000000000000000000..5536b56e732d81b84dfee7f34bd68d0d2ba49020 --- /dev/null +++ b/run_hourly_kz.py @@ -0,0 +1,31 @@ +__author__ = "Lukas Leufen" +__date__ = '2019-11-14' + +import argparse + +from mlair.workflows import DefaultWorkflow +from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilter + + +def main(parser_args): + args = dict(sampling="hourly", + window_history_size=24, **parser_args.__dict__, + data_handler=DataHandlerKzFilter, + kz_filter_length=[365 * 24, 20 * 24], # 13,5# , 4 * 24, 12, 6], + kz_filter_iter=[3, 5], # 3,4# , 3, 4, 4], + start="2006-01-01", + train_start="2006-01-01", + end="2011-12-31", + test_end="2011-12-31", + stations=["DEBW107", "DEBW013"] + ) + workflow = DefaultWorkflow(**args) + workflow.run() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None, + help="set experiment date as string") + args = parser.parse_args(["--experiment_date", "testrun"]) + main(args) diff --git a/run_mixed_sampling.py b/run_mixed_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..a70e2aa36e5c1da83c4f667fbbe8b27b5949b4d6 --- /dev/null +++ b/run_mixed_sampling.py @@ -0,0 +1,36 @@ +__author__ = "Lukas Leufen" +__date__ = '2019-11-14' + +import argparse + +from mlair.workflows import DefaultWorkflow +from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter, \ + DataHandlerMixedSamplingSeparationOfScales + + +def main(parser_args): + args = dict(sampling="daily", + sampling_inputs="hourly", + window_history_size=24, + **parser_args.__dict__, + data_handler=DataHandlerMixedSamplingSeparationOfScales, + kz_filter_length=[100 * 24, 15 * 24], + kz_filter_iter=[4, 5], + start="2006-01-01", + train_start="2006-01-01", + end="2011-12-31", + test_end="2011-12-31", + stations=["DEBW107", "DEBW013"], + epochs=100, + network="UBA", + ) + workflow = DefaultWorkflow(**args) + workflow.run() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None, + help="set experiment date as string") + args = parser.parse_args(["--experiment_date", "testrun"]) + main(args) diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py new file mode 100644 index 
0000000000000000000000000000000000000000..fffe7c84075eeeab37ebf59d52bc42dbf87bf522 --- /dev/null +++ b/test/test_configuration/test_defaults.py @@ -0,0 +1,73 @@ +from mlair.configuration.defaults import * + + +class TestGetDefaults: + + def test_get_defaults(self): + defaults = get_defaults() + assert isinstance(defaults, dict) + assert all(map(lambda k: k in defaults.keys(), ["DEFAULT_STATIONS", "DEFAULT_BATCH_SIZE", "DEFAULT_PLOT_LIST"])) + assert all(map(lambda x: x.startswith("DEFAULT"), defaults.keys())) + + +class TestAllDefaults: + + def test_training_parameters(self): + assert DEFAULT_CREATE_NEW_MODEL is True + assert DEFAULT_TRAIN_MODEL is True + assert DEFAULT_FRACTION_OF_TRAINING == 0.8 + assert DEFAULT_EXTREME_VALUES is None + assert DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY is False + assert DEFAULT_PERMUTE_DATA is False + assert DEFAULT_BATCH_SIZE == int(256 * 2) + assert DEFAULT_EPOCHS == 20 + + def test_data_handler_parameters(self): + assert DEFAULT_STATIONS == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] + assert DEFAULT_VAR_ALL_DICT == {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', + 'u': 'average_values', + 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', + 'cloudcover': 'average_values', + 'pblheight': 'maximum'} + assert DEFAULT_NETWORK == "AIRBASE" + assert DEFAULT_STATION_TYPE == "background" + assert DEFAULT_VARIABLES == DEFAULT_VAR_ALL_DICT.keys() + assert DEFAULT_START == "1997-01-01" + assert DEFAULT_END == "2017-12-31" + assert DEFAULT_WINDOW_HISTORY_SIZE == 13 + assert DEFAULT_OVERWRITE_LOCAL_DATA is False + assert isinstance(DEFAULT_TRANSFORMATION, TransformationClass) + assert DEFAULT_TRANSFORMATION.inputs.transform_method == "standardise" + assert DEFAULT_TRANSFORMATION.targets.transform_method == "standardise" + assert DEFAULT_TARGET_VAR == "o3" + assert DEFAULT_TARGET_DIM == "variables" + assert DEFAULT_WINDOW_LEAD_TIME == 3 + assert DEFAULT_DIMENSIONS == {"new_index": ["datetime", "Stations"]} + assert DEFAULT_TIME_DIM == "datetime" + assert DEFAULT_INTERPOLATION_METHOD == "linear" + assert DEFAULT_INTERPOLATION_LIMIT == 1 + + def test_subset_parameters(self): + assert DEFAULT_TRAIN_START == "1997-01-01" + assert DEFAULT_TRAIN_END == "2007-12-31" + assert DEFAULT_TRAIN_MIN_LENGTH == 90 + assert DEFAULT_VAL_START == "2008-01-01" + assert DEFAULT_VAL_END == "2009-12-31" + assert DEFAULT_VAL_MIN_LENGTH == 90 + assert DEFAULT_TEST_START == "2010-01-01" + assert DEFAULT_TEST_END == "2017-12-31" + assert DEFAULT_TEST_MIN_LENGTH == 90 + assert DEFAULT_TRAIN_VAL_MIN_LENGTH == 180 + assert DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS is True + + def test_hpc_parameters(self): + assert DEFAULT_HPC_HOST_LIST == ["jw", "hdfmlc"] + assert DEFAULT_HPC_LOGIN_LIST == ["ju", "hdfmll"] + + def test_postprocessing_parameters(self): + assert DEFAULT_EVALUATE_BOOTSTRAPS is True + assert DEFAULT_CREATE_NEW_BOOTSTRAPS is False + assert DEFAULT_NUMBER_OF_BOOTSTRAPS == 20 + assert DEFAULT_PLOT_LIST == ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", + "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", + "PlotConditionalQuantiles", "PlotAvailability", "PlotSeparationOfScales"] diff --git a/test/test_configuration/test_path_config.py b/test/test_configuration/test_path_config.py index b97763632922fc2aaffaf267cfbc76ff99e25b6f..2ba80a3bdf62b7fdf10b645da75769435cf7b6b9 100644 --- a/test/test_configuration/test_path_config.py +++ b/test/test_configuration/test_path_config.py @@ -11,22 +11,21 @@ from mlair.helpers 
import PyTestRegex class TestPrepareHost: - @mock.patch("socket.gethostname", side_effect=["linux-aa9b", "ZAM144", "zam347", "jrtest", "jwtest", + @mock.patch("socket.gethostname", side_effect=["ZAM144", "zam347", "jrtest", "jwtest", "runner-6HmDp9Qd-project-2411-concurrent-01"]) @mock.patch("getpass.getuser", return_value="testUser") @mock.patch("os.path.exists", return_value=True) def test_prepare_host(self, mock_host, mock_user, mock_path): - assert prepare_host() == "/home/testUser/mlair/data/toar_daily/" assert prepare_host() == "/home/testUser/Data/toar_daily/" assert prepare_host() == "/home/testUser/Data/toar_daily/" assert prepare_host() == "/p/project/cjjsc42/testUser/DATA/toar_daily/" - assert prepare_host() == "/p/project/deepacf/intelliaq/testUser/DATA/toar_daily/" - assert prepare_host() == '/home/testUser/mlair/data/toar_daily/' + assert prepare_host() == "/p/project/deepacf/intelliaq/testUser/DATA/MLAIR/" + assert prepare_host() == '/home/testUser/mlair/data/' @mock.patch("socket.gethostname", return_value="NotExistingHostName") @mock.patch("getpass.getuser", return_value="zombie21") def test_prepare_host_unknown(self, mock_user, mock_host): - assert prepare_host() == os.path.join(os.path.abspath(os.getcwd()), 'data', 'daily') + assert prepare_host() == os.path.join(os.path.abspath(os.getcwd()), 'data') @mock.patch("getpass.getuser", return_value="zombie21") @mock.patch("mlair.configuration.path_config.check_path_and_create", side_effect=PermissionError) @@ -42,13 +41,13 @@ class TestPrepareHost: # assert "does not exist for host 'linux-aa9b'" in e.value.args[0] assert PyTestRegex(r"path '.*' does not exist for host '.*'\.") == e.value.args[0] - @mock.patch("socket.gethostname", side_effect=["linux-aa9b"]) + @mock.patch("socket.gethostname", side_effect=["zam347"]) @mock.patch("getpass.getuser", return_value="testUser") @mock.patch("os.path.exists", return_value=False) @mock.patch("os.makedirs", side_effect=None) def test_os_path_exists(self, mock_host, mock_user, mock_path, mock_check): path = prepare_host() - assert path == "/home/testUser/mlair/data/toar_daily/" + assert path == "/home/testUser/Data/toar_daily/" class TestSetExperimentName: @@ -80,12 +79,12 @@ class TestSetBootstrapPath: @mock.patch("os.makedirs", side_effect=None) def test_bootstrap_path_is_none(self, mock_makedir): - bootstrap_path = set_bootstrap_path(None, 'TestDataPath/', 'daily') - assert bootstrap_path == os.path.abspath('TestDataPath/../bootstrap_daily') + bootstrap_path = set_bootstrap_path(None, 'TestDataPath/') + assert bootstrap_path == os.path.abspath('TestDataPath/bootstrap') @mock.patch("os.makedirs", side_effect=None) def test_bootstap_path_is_given(self, mock_makedir): - bootstrap_path = set_bootstrap_path('Test/path/to/boots', None, None) + bootstrap_path = set_bootstrap_path('Test/path/to/boots', None) assert bootstrap_path == os.path.abspath('./Test/path/to/boots') diff --git a/test/test_data_handler/test_data_handler.py b/test/test_data_handler/test_data_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..418c7946efe160c9bbfeccff9908a6cf17dec17f --- /dev/null +++ b/test/test_data_handler/test_data_handler.py @@ -0,0 +1,67 @@ +import pytest +import inspect + +from mlair.data_handler.abstract_data_handler import AbstractDataHandler + + +class TestDefaultDataHandler: + + def test_required_attributes(self): + dh = AbstractDataHandler + assert hasattr(dh, "_requirements") + assert hasattr(dh, "__init__") + assert hasattr(dh, "build") + assert hasattr(dh, 
"requirements") + assert hasattr(dh, "own_args") + assert hasattr(dh, "transformation") + assert hasattr(dh, "get_X") + assert hasattr(dh, "get_Y") + assert hasattr(dh, "get_data") + assert hasattr(dh, "get_coordinates") + + def test_init(self): + assert isinstance(AbstractDataHandler(), AbstractDataHandler) + + def test_build(self): + assert isinstance(AbstractDataHandler.build(), AbstractDataHandler) + + def test_requirements(self): + dh = AbstractDataHandler() + assert isinstance(dh._requirements, list) + assert len(dh._requirements) == 0 + assert isinstance(dh.requirements(), list) + assert len(dh.requirements()) == 0 + + def test_own_args(self): + dh = AbstractDataHandler() + assert isinstance(dh.own_args(), list) + assert len(dh.own_args()) == 0 + assert "self" not in dh.own_args() + + def test_transformation(self): + assert AbstractDataHandler.transformation() is None + + def test_get_X(self): + dh = AbstractDataHandler() + with pytest.raises(NotImplementedError): + dh.get_X() + assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_X).args) + assert (False, False) == inspect.getfullargspec(dh.get_X).defaults + + def test_get_Y(self): + dh = AbstractDataHandler() + with pytest.raises(NotImplementedError): + dh.get_Y() + assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_Y).args) + assert (False, False) == inspect.getfullargspec(dh.get_Y).defaults + + def test_get_data(self): + dh = AbstractDataHandler() + with pytest.raises(NotImplementedError): + dh.get_data() + assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_data).args) + assert (False, False) == inspect.getfullargspec(dh.get_data).defaults + + def test_get_coordinates(self): + dh = AbstractDataHandler() + assert dh.get_coordinates() is None diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py index 281d60e07463c6b5118f36714d80144443a03050..723b4a87d70453327ed6b7e355d3ef78a246652a 100644 --- a/test/test_helpers/test_helpers.py +++ b/test/test_helpers/test_helpers.py @@ -124,14 +124,22 @@ class TestPytestRegex: class TestDictToXarray: def test_dict_to_xarray(self): - array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) - array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) + array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]}) + array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]}) d = {"number1": array1, "number2": array2} res = dict_to_xarray(d, "merge_dim") assert type(res) == xr.DataArray - assert sorted(list(res.coords)) == ["merge_dim", "x"] + assert sorted(list(res.coords)) == ["merge_dim", "x", "y"] assert res.shape == (2, 2, 3) + def test_dict_to_xarray_single_entry(self): + array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]}) + d = {"number1": array1} + res = dict_to_xarray(d, "merge_dim") + assert type(res) == xr.DataArray + assert sorted(list(res.coords)) == ["merge_dim", "x", "y"] + assert res.shape == (1, 2, 3) + class TestFloatRound: diff --git a/test/test_join.py b/test/test_join.py index 791723335e16cf2124512629414ebe626bc20e9c..a9a4c381cbf58a272389b0b11283c8b0cce3ab42 100644 --- a/test/test_join.py +++ b/test/test_join.py @@ -3,7 +3,7 @@ from typing import Iterable import pytest from mlair.helpers.join import * -from mlair.helpers.join import 
_save_to_pandas, _correct_stat_name, _lower_list +from mlair.helpers.join import _save_to_pandas, _correct_stat_name, _lower_list, _select_distinct_series from mlair.configuration.join_settings import join_settings @@ -52,7 +52,7 @@ class TestGetData: class TestLoadSeriesInformation: def test_standard_query(self): - expected_subset = {'o3': 23031, 'no2': 39002, 'temp--lubw': 17059, 'wspeed': 17060} + expected_subset = {'o3': 23031, 'no2': 39002, 'temp': 85584, 'wspeed': 17060} assert expected_subset.items() <= load_series_information(['DEBW107'], None, None, join_settings()[0], {}).items() @@ -60,6 +60,38 @@ class TestLoadSeriesInformation: assert load_series_information(['DEBW107'], "traffic", None, join_settings()[0], {}) == {} +class TestSelectDistinctSeries: + + @pytest.fixture + def vars(self): + return [{'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2', + 'parameter_label': 'NO2', 'parameter_attribute': ''}, + {'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3', + 'parameter_label': 'O3', + 'parameter_attribute': ''}, + {'id': 16692, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS--LANUV', 'parameter_attribute': ''}, + {'id': 16693, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP--LANUV', 'parameter_attribute': ''}, + {'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover', + 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'}, + {'id': 88491, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP-REA-MIUB', 'parameter_attribute': 'REA'}, + {'id': 102660, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS-REA-MIUB', 'parameter_attribute': 'REA'}] + + def test_no_origin_given(self, vars): + res = _select_distinct_series(vars) + assert res == {"no2": 16686, "o3": 16687, "cloudcover": 54036, "temp": 88491, "press": 102660} + + def test_different_origins(self, vars): + origin = {"no2": "test", "temp": "", "cloudcover": "REA"} + res = _select_distinct_series(vars, data_origin=origin) + assert res == {"o3": 16687, "press": 16692, "temp": 16693, "cloudcover": 54036} + res = _select_distinct_series(vars, data_origin={}) + assert res == {"no2": 16686, "o3": 16687, "press": 16692, "temp": 16693} + + class TestSaveToPandas: @staticmethod diff --git a/test/test_run_modules/test_experiment_setup.py b/test/test_run_modules/test_experiment_setup.py index ff35508542b694eb1def0ba791d9a5f70043f19c..7c63d3d101176a40749ce903f569263b9c884d5e 100644 --- a/test/test_run_modules/test_experiment_setup.py +++ b/test/test_run_modules/test_experiment_setup.py @@ -4,7 +4,7 @@ import os import pytest -from mlair.helpers import TimeTracking +from mlair.helpers import TimeTracking, to_list from mlair.configuration.path_config import prepare_host from mlair.run_modules.experiment_setup import ExperimentSetup @@ -33,6 +33,16 @@ class TestExperimentSetup: empty_obj._set_param("AnotherNoneTester", None) assert empty_obj.data_store.get("AnotherNoneTester", "general") is None + def test_set_param_with_apply(self, caplog, empty_obj): + empty_obj._set_param("NoneTester", None, default="notNone", apply=None) + assert empty_obj.data_store.get("NoneTester") == "notNone" + empty_obj._set_param("NoneTester", None, default="notNone", apply=to_list) + assert empty_obj.data_store.get("NoneTester") == ["notNone"] + 
diff --git a/test/test_run_modules/test_experiment_setup.py b/test/test_run_modules/test_experiment_setup.py
index ff35508542b694eb1def0ba791d9a5f70043f19c..7c63d3d101176a40749ce903f569263b9c884d5e 100644
--- a/test/test_run_modules/test_experiment_setup.py
+++ b/test/test_run_modules/test_experiment_setup.py
@@ -4,7 +4,7 @@ import os
 
 import pytest
 
-from mlair.helpers import TimeTracking
+from mlair.helpers import TimeTracking, to_list
 from mlair.configuration.path_config import prepare_host
 from mlair.run_modules.experiment_setup import ExperimentSetup
 
@@ -33,6 +33,16 @@ class TestExperimentSetup:
         empty_obj._set_param("AnotherNoneTester", None)
         assert empty_obj.data_store.get("AnotherNoneTester", "general") is None
 
+    def test_set_param_with_apply(self, caplog, empty_obj):
+        empty_obj._set_param("NoneTester", None, default="notNone", apply=None)
+        assert empty_obj.data_store.get("NoneTester") == "notNone"
+        empty_obj._set_param("NoneTester", None, default="notNone", apply=to_list)
+        assert empty_obj.data_store.get("NoneTester") == ["notNone"]
+        empty_obj._set_param("NoneTester", None, apply=to_list)
+        assert empty_obj.data_store.get("NoneTester") == [None]
+        empty_obj._set_param("NoneTester", 2.3, apply=int)
+        assert empty_obj.data_store.get("NoneTester") == 2
+
     def test_init_default(self):
         exp_setup = ExperimentSetup()
         data_store = exp_setup.data_store
diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py
index c0b625ef70deeb0686b236275e6bd1182ad48d41..c2b58cbd2160bd958c76ba67649ef8caba09fcb4 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -61,7 +61,8 @@ class TestTraining:
         obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_train_gen"), "general.train")
         obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_val_gen"), "general.val")
         obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_test_gen"), "general.test")
-        os.makedirs(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
         obj.data_store.set("experiment_path", path, "general")
         os.makedirs(batch_path)
         obj.data_store.set("batch_path", batch_path, "general")
@@ -125,7 +126,8 @@ class TestTraining:
 
     @pytest.fixture
     def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var):
-        data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(os.path.dirname(__file__), 'data'),
+        data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(path, 'data'),
+                                             experiment_path=os.path.join(path, 'exp_path'),
                                              statistics_per_var=statistics_per_var, station_type="background",
                                              network="AIRBASE", sampling="daily", target_dim="variables",
                                              target_var="o3", time_dim="datetime",
@@ -169,7 +171,8 @@ class TestTraining:
 
     @pytest.fixture
    def ready_to_init(self, data_collection, model, callbacks, path, model_path, batch_path):
-        os.makedirs(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
         os.makedirs(model_path)
         obj = RunEnvironment()
         obj.data_store.set("data_collection", data_collection, "general.train")
diff --git a/test/test_statistics.py b/test/test_statistics.py
index d4a72674ae89ecd106ff1861aa6ee26567da3243..76adc1bdd210e072b4fc9be717269c6ceb951fec 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -3,7 +3,9 @@ import pandas as pd
 import pytest
 import xarray as xr
 
-from mlair.helpers.statistics import standardise, standardise_inverse, standardise_apply, centre, centre_inverse, centre_apply, \
+from mlair.helpers.statistics import DataClass, TransformationClass
+from mlair.helpers.statistics import standardise, standardise_inverse, standardise_apply, centre, centre_inverse, \
+    centre_apply, \
     apply_inverse_transformation
 
 lazy = pytest.lazy_fixture
@@ -113,3 +115,50 @@ class TestCentre:
         data = centre_apply(data_orig, mean)
         mean_expected = np.array([2, -5, 10]) - np.array([2, 10, 3])
         assert np.testing.assert_almost_equal(data.mean(dim), mean_expected, decimal=1) is None
+
+
+class TestDataClass:
+
+    def test_init(self):
+        dc = DataClass()
+        assert all([obj is None for obj in [dc.data, dc.mean, dc.std, dc.max, dc.min, dc.transform_method, dc._method]])
+
+    def test_init_values(self):
+        dc = DataClass(data=12, mean=2, std="test", max=23.4, min=np.array([3]), transform_method="f")
+        assert dc.data == 12
+        assert dc.mean == 2
+        assert dc.std == "test"
+        assert dc.max == 23.4
+        assert np.testing.assert_array_equal(dc.min, np.array([3])) is None
+        assert dc.transform_method == "f"
+        assert dc._method is None
+
+    def test_as_dict(self):
+        dc = DataClass(std=23)
+        dc._method = "f(x)"
"f(x)" + assert dc.as_dict() == {"data": None, "mean": None, "std": 23, "max": None, "min": None, + "transform_method": None} + + +class TestTransformationClass: + + def test_init(self): + tc = TransformationClass() + assert hasattr(tc, "inputs") + assert isinstance(tc.inputs, DataClass) + assert hasattr(tc, "targets") + assert isinstance(tc.targets, DataClass) + assert tc.inputs.mean is None + assert tc.targets.std is None + + def test_init_values(self): + tc = TransformationClass(inputs_mean=1, inputs_std=2, inputs_method="f", targets_mean=3, targets_std=4, + targets_method="g") + assert tc.inputs.mean == 1 + assert tc.inputs.std == 2 + assert tc.inputs.transform_method == "f" + assert tc.inputs.max is None + assert tc.targets.mean == 3 + assert tc.targets.std == 4 + assert tc.targets.transform_method == "g" + assert tc.inputs.min is None