diff --git a/conftest.py b/conftest.py index abb0c0f52757e4b2228d7d48e3dc07e08b302841..b63d3efb33f5b2c02185f16e8753231d1853e66c 100644 --- a/conftest.py +++ b/conftest.py @@ -66,5 +66,5 @@ def default_session_fixture(request): # request.addfinalizer(unpatch) - with mock.patch("multiprocessing.cpu_count", return_value=1): + with mock.patch("psutil.cpu_count", return_value=1): yield diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index 31b58a56375ea26a857ee132c2170680bab4e55a..00815419b43ffc2466f41bfede3a96311a752fdf 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -46,21 +46,25 @@ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS = True DEFAULT_EVALUATE_BOOTSTRAPS = True DEFAULT_CREATE_NEW_BOOTSTRAPS = False DEFAULT_NUMBER_OF_BOOTSTRAPS = 20 +DEFAULT_BOOTSTRAP_TYPE = "singleinput" +DEFAULT_BOOTSTRAP_METHOD = "shuffle" DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles", "PlotAvailability", "PlotAvailabilityHistogram", "PlotDataHistogram", "PlotOversampling", - "PlotOversamplingContingency"] + "PlotOversamplingContingency", "PlotPeriodogram"] DEFAULT_SAMPLING = "daily" DEFAULT_DATA_ORIGIN = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA", "press": "REA", "relhum": "REA", "temp": "REA", "totprecip": "REA", "u": "REA", "v": "REA", "no": "", "no2": "", "o3": "", "pm10": "", "so2": ""} DEFAULT_USE_MULTIPROCESSING = True DEFAULT_USE_MULTIPROCESSING_ON_DEBUG = False +DEFAULT_MAX_NUMBER_MULTIPROCESSING = 16 DEFAULT_OVERSAMPLING_BINS = 10 DEFAULT_OVERSAMPLING_RATES_CAP = 100 DEFAULT_OVERSAMPLING_METHOD = None + def get_defaults(): """Return all default parameters set in defaults.py""" return {key: value for key, value in globals().items() if key.startswith('DEFAULT')} diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py index 419db059a58beeb4ed7e3e198e41b565f8dc7d25..36d6e9ae5394705af4b9fbcfd1d8ff77572642b5 100644 --- a/mlair/data_handler/abstract_data_handler.py +++ b/mlair/data_handler/abstract_data_handler.py @@ -11,6 +11,7 @@ from mlair.helpers import remove_items class AbstractDataHandler: _requirements = [] + _store_attributes = [] def __init__(self, *args, **kwargs): pass @@ -32,6 +33,31 @@ class AbstractDataHandler: list_of_args = arg_spec.args + arg_spec.kwonlyargs return remove_items(list_of_args, ["self"] + list(args)) + @classmethod + def store_attributes(cls) -> list: + """ + Let MLAir know that some data should be stored in the data store. This is used for calculations on the train + subset that should be applied to validation and test subset. + + To work properly, add a class variable cls._store_attributes to your data handler. If your custom data handler + is constructed on different data handlers (e.g. like the DefaultDataHandler), it is required to overwrite the + get_store_attributs method in addition to return attributes from the corresponding subclasses. This is not + required, if only attributes from the main class are to be returned. + + Note, that MLAir will store these attributes with the data handler's identification. This depends on the custom + data handler setting. When loading an attribute from the data handler, it is therefore required to extract the + right information by using the class identification. 
In case of the DefaultDataHandler this can be achieved to + convert all keys of the attribute to string and compare these with the station parameter. + """ + return list(set(cls._store_attributes)) + + def get_store_attributes(self): + """Returns all attribute names and values that are indicated by the store_attributes method.""" + attr_dict = {} + for attr in self.store_attributes(): + attr_dict[attr] = self.__getattribute__(attr) + return attr_dict + @classmethod def transformation(cls, *args, **kwargs): return None diff --git a/mlair/data_handler/bootstraps.py b/mlair/data_handler/bootstraps.py index 68a4bbc4bc9620bfb54ba23fef1ce882e76c8626..e03881484bfc9b8275ede8a4432072c74643994a 100644 --- a/mlair/data_handler/bootstraps.py +++ b/mlair/data_handler/bootstraps.py @@ -15,69 +15,175 @@ __date__ = '2020-02-07' import os from collections import Iterator, Iterable from itertools import chain +from typing import Union, List import numpy as np import xarray as xr from mlair.data_handler.abstract_data_handler import AbstractDataHandler +from mlair.helpers.helpers import to_list class BootstrapIterator(Iterator): _position: int = None - def __init__(self, data: "BootStraps"): + def __init__(self, data: "BootStraps", method): assert isinstance(data, BootStraps) self._data = data self._dimension = data.bootstrap_dimension - self._collection = self._data.bootstraps() + self.boot_dim = "boots" + self._method = method + self._collection = self.create_collection(self._data.data, self._dimension) self._position = 0 + def __next__(self): + """Return next element or stop iteration.""" + raise NotImplementedError + + @classmethod + def create_collection(cls, data, dim): + raise NotImplementedError + + def _reshape(self, d): + if isinstance(d, list): + return list(map(lambda x: self._reshape(x), d)) + # return list(map(lambda x: np.rollaxis(x, -1, 0).reshape(x.shape[0] * x.shape[-1], *x.shape[1:-1]), d)) + else: + shape = d.shape + return np.rollaxis(d, -1, 0).reshape(shape[0] * shape[-1], *shape[1:-1]) + + def _to_numpy(self, d): + if isinstance(d, list): + return list(map(lambda x: self._to_numpy(x), d)) + else: + return d.values + + def apply_bootstrap_method(self, data: np.ndarray) -> Union[np.ndarray, List[np.ndarray]]: + """ + Apply predefined bootstrap method from given data. 
+ + :param data: data to apply bootstrap method on + :return: processed data as numpy array + """ + if isinstance(data, list): + return list(map(lambda x: self.apply_bootstrap_method(x.values), data)) + else: + return self._method.apply(data) + + +class BootstrapIteratorSingleInput(BootstrapIterator): + _position: int = None + + def __init__(self, *args): + super().__init__(*args) + def __next__(self): """Return next element or stop iteration.""" try: index, dimension = self._collection[self._position] nboot = self._data.number_of_bootstraps _X, _Y = self._data.data.get_data(as_numpy=False) - _X = list(map(lambda x: x.expand_dims({'boots': range(nboot)}, axis=-1), _X)) - _Y = _Y.expand_dims({"boots": range(nboot)}, axis=-1) + _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X)) + _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1) single_variable = _X[index].sel({self._dimension: [dimension]}) - shuffled_variable = self.shuffle(single_variable.values) - shuffled_data = xr.DataArray(shuffled_variable, coords=single_variable.coords, dims=single_variable.dims) - _X[index] = shuffled_data.combine_first(_X[index]).reindex_like(_X[index]) + bootstrapped_variable = self.apply_bootstrap_method(single_variable.values) + bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords, + dims=single_variable.dims) + _X[index] = bootstrapped_data.combine_first(_X[index]).reindex_like(_X[index]) self._position += 1 except IndexError: raise StopIteration() _X, _Y = self._to_numpy(_X), self._to_numpy(_Y) return self._reshape(_X), self._reshape(_Y), (index, dimension) - @staticmethod - def _reshape(d): - if isinstance(d, list): - return list(map(lambda x: np.rollaxis(x, -1, 0).reshape(x.shape[0] * x.shape[-1], *x.shape[1:-1]), d)) - else: - shape = d.shape - return np.rollaxis(d, -1, 0).reshape(shape[0] * shape[-1], *shape[1:-1]) + @classmethod + def create_collection(cls, data, dim): + l = [] + for i, x in enumerate(data.get_X(as_numpy=False)): + l.append(list(map(lambda y: (i, y), x.indexes[dim]))) + return list(chain(*l)) - @staticmethod - def _to_numpy(d): - if isinstance(d, list): - return list(map(lambda x: x.values, d)) - else: - return d.values - @staticmethod - def shuffle(data: np.ndarray) -> np.ndarray: - """ - Shuffle randomly from given data (draw elements with replacement). 
+class BootstrapIteratorVariable(BootstrapIterator): - :param data: data to shuffle - :return: shuffled data as numpy array - """ + def __init__(self, *args): + super().__init__(*args) + + def __next__(self): + """Return next element or stop iteration.""" + try: + dimension = self._collection[self._position] + nboot = self._data.number_of_bootstraps + _X, _Y = self._data.data.get_data(as_numpy=False) + _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X)) + _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1) + for index in range(len(_X)): + single_variable = _X[index].sel({self._dimension: [dimension]}) + bootstrapped_variable = self.apply_bootstrap_method(single_variable.values) + bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords, + dims=single_variable.dims) + _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims) + self._position += 1 + except IndexError: + raise StopIteration() + _X, _Y = self._to_numpy(_X), self._to_numpy(_Y) + return self._reshape(_X), self._reshape(_Y), (None, dimension) + + @classmethod + def create_collection(cls, data, dim): + l = set() + for i, x in enumerate(data.get_X(as_numpy=False)): + l.update(x.indexes[dim].to_list()) + return to_list(l) + + +class BootstrapIteratorBranch(BootstrapIterator): + + def __init__(self, *args): + super().__init__(*args) + + def __next__(self): + try: + index = self._collection[self._position] + nboot = self._data.number_of_bootstraps + _X, _Y = self._data.data.get_data(as_numpy=False) + _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X)) + _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1) + for dimension in _X[index].coords[self._dimension].values: + single_variable = _X[index].sel({self._dimension: [dimension]}) + bootstrapped_variable = self.apply_bootstrap_method(single_variable.values) + bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords, + dims=single_variable.dims) + _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims) + self._position += 1 + except IndexError: + raise StopIteration() + _X, _Y = self._to_numpy(_X), self._to_numpy(_Y) + return self._reshape(_X), self._reshape(_Y), (None, index) + + @classmethod + def create_collection(cls, data, dim): + return list(range(len(data.get_X(as_numpy=False)))) + + +class ShuffleBootstraps: + + @staticmethod + def apply(data): size = data.shape return np.random.choice(data.reshape(-1, ), size=size) +class MeanBootstraps: + + def __init__(self, mean): + self._mean = mean + + def apply(self, data): + return np.ones_like(data) * self._mean + + class BootStraps(Iterable): """ Main class to perform bootstrap operations. @@ -89,10 +195,19 @@ class BootStraps(Iterable): this variable). The tuple is interesting if X consists on mutliple input streams X_i (e.g. two or more stations) because it shows which variable of which input X_i has been bootstrapped. All bootstrap combinations can be retrieved by calling the .bootstraps() method. Further more, by calling the .get_orig_prediction() this class - imitates according to the set number of bootstraps the original prediction + imitates according to the set number of bootstraps the original prediction. + + As bootstrap method, this class can currently make use of the ShuffleBoostraps class that uses drawing with + replacement to destroy the variables information by keeping its statistical properties. 
Use `bootstrap="shuffle"` to + call this method. Another method is the zero mean bootstrapping triggered by `bootstrap="zero_mean"` and performed + by the MeanBootstraps class. This method destroy the variable's information by a mode collapse to constant value of + zero. In case, the variable is normalized with a zero mean, this is equivalent to a mode collapse to the variable's + mean value. Statistics in general are not conserved in this case, but the mean value of course. A custom mean value + for bootstrapping is currently not supported. """ + def __init__(self, data: AbstractDataHandler, number_of_bootstraps: int = 10, - bootstrap_dimension: str = "variables"): + bootstrap_dimension: str = "variables", bootstrap_type="singleinput", bootstrap_method="shuffle"): """ Create iterable class to be ready to iter. @@ -100,20 +215,24 @@ class BootStraps(Iterable): :param number_of_bootstraps: the number of bootstrap realisations """ self.data = data - self.number_of_bootstraps = number_of_bootstraps + self.number_of_bootstraps = number_of_bootstraps if bootstrap_method == "shuffle" else 1 self.bootstrap_dimension = bootstrap_dimension + self.bootstrap_method = {"shuffle": ShuffleBootstraps(), + "zero_mean": MeanBootstraps(mean=0)}.get( + bootstrap_method) # todo adjust number of bootstraps if mean bootstrapping + self.BootstrapIterator = {"singleinput": BootstrapIteratorSingleInput, + "branch": BootstrapIteratorBranch, + "variable": BootstrapIteratorVariable}.get(bootstrap_type, + BootstrapIteratorSingleInput) def __iter__(self): - return BootstrapIterator(self) + return self.BootstrapIterator(self, self.bootstrap_method) def __len__(self): - return len(self.bootstraps()) + return len(self.BootstrapIterator.create_collection(self.data, self.bootstrap_dimension)) def bootstraps(self): - l = [] - for i, x in enumerate(self.data.get_X(as_numpy=False)): - l.append(list(map(lambda y: (i, y), x.indexes['variables']))) - return list(chain(*l)) + return self.BootstrapIterator.create_collection(self.data, self.bootstrap_dimension) def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN") -> np.ndarray: """ diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py deleted file mode 100644 index 539712b39e51c32203e1c55e28ce2eff24069479..0000000000000000000000000000000000000000 --- a/mlair/data_handler/data_handler_kz_filter.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Data Handler using kz-filtered data.""" - -__author__ = 'Lukas Leufen' -__date__ = '2020-08-26' - -import inspect -import numpy as np -import pandas as pd -import xarray as xr -from typing import List, Union, Tuple, Optional -from functools import partial - -from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation -from mlair.data_handler import DefaultDataHandler -from mlair.helpers import remove_items, to_list, TimeTrackingWrapper -from mlair.helpers.statistics import KolmogorovZurbenkoFilterMovingWindow as KZFilter - -# define a more general date type for type hinting -str_or_list = Union[str, List[str]] - - -class DataHandlerKzFilterSingleStation(DataHandlerSingleStation): - """Data handler for a single station to be used by a superior data handler. 
Inputs are kz filtered.""" - - _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) - _hash = DataHandlerSingleStation._hash + ["kz_filter_length", "kz_filter_iter", "filter_dim"] - - DEFAULT_FILTER_DIM = "filter" - - def __init__(self, *args, kz_filter_length, kz_filter_iter, filter_dim=DEFAULT_FILTER_DIM, **kwargs): - self._check_sampling(**kwargs) - # self.original_data = None # ToDo: implement here something to store unfiltered data - self.kz_filter_length = to_list(kz_filter_length) - self.kz_filter_iter = to_list(kz_filter_iter) - self.filter_dim = filter_dim - self.cutoff_period = None - self.cutoff_period_days = None - super().__init__(*args, **kwargs) - - def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]: - """ - Adjust setup of transformation because kfz filtered data will have negative values which is not compatible with - the log transformation. Therefore, replace all log transformation methods by a default standardization. This is - only applied on input side. - """ - transformation = super(__class__, self).setup_transformation(transformation) - if transformation[0] is not None: - for k, v in transformation[0].items(): - if v["method"] == "log": - transformation[0][k]["method"] = "standardise" - return transformation - - def _check_sampling(self, **kwargs): - assert kwargs.get("sampling") == "hourly" # This data handler requires hourly data resolution - - def make_input_target(self): - data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, - self.station_type, self.network, self.store_data_locally, self.data_origin) - self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, - limit=self.interpolation_limit) - self.set_inputs_and_targets() - self.apply_kz_filter() - # this is just a code snippet to check the results of the kz filter - # import matplotlib - # matplotlib.use("TkAgg") - # import matplotlib.pyplot as plt - # self.input_data.sel(filter="74d", variables="temp", Stations="DEBW107").plot() - # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") - - @TimeTrackingWrapper - def apply_kz_filter(self): - """Apply kolmogorov zurbenko filter only on inputs.""" - kz = KZFilter(self.input_data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim=self.time_dim) - filtered_data: List[xr.DataArray] = kz.run() - self.cutoff_period = kz.period_null() - self.cutoff_period_days = kz.period_null_days() - self.input_data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) - - def create_filter_index(self) -> pd.Index: - """ - Round cut off periods in days and append 'res' for residuum index. - - Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append - 'res' for residuum index. - """ - index = np.round(self.cutoff_period_days, 1) - f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1) - index = list(map(f, index.tolist())) - index = list(map(lambda x: str(x) + "d", index)) + ["res"] - return pd.Index(index, name=self.filter_dim) - - def get_transposed_history(self) -> xr.DataArray: - """Return history. - - :return: history with dimensions datetime, window, Stations, variables, filter. 
- """ - return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim, - self.filter_dim).copy() - - def _create_lazy_data(self): - return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days] - - def _extract_lazy(self, lazy_data): - _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data - f_prep = partial(self._slice_prep, start=self.start, end=self.end) - self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data])) - - -class DataHandlerKzFilter(DefaultDataHandler): - """Data handler using kz filtered data.""" - - data_handler = DataHandlerKzFilterSingleStation - data_handler_transformation = DataHandlerKzFilterSingleStation - _requirements = data_handler.requirements() diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index a10364333f3671448c560b40283fb2645d251428..8205ae6c28f3683b1052c292e5d063d8bca555dc 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -2,11 +2,15 @@ __author__ = 'Lukas Leufen' __date__ = '2020-11-05' from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation -from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation +from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation, \ + DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation +from mlair.data_handler.data_handler_with_filter import DataHandlerClimateFirFilter, DataHandlerFirFilter, \ + DataHandlerKzFilter from mlair.data_handler import DefaultDataHandler from mlair import helpers from mlair.helpers import remove_items from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD +from mlair.helpers.filter import filter_width_kzf import inspect from typing import Callable @@ -66,7 +70,8 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): self.station_type, self.network, self.store_data_locally, self.data_origin, self.start, self.end) data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind], - limit=self.interpolation_limit[ind]) + limit=self.interpolation_limit[ind], sampling=self.sampling[ind]) + return data def set_inputs_and_targets(self): @@ -94,8 +99,8 @@ class DataHandlerMixedSampling(DefaultDataHandler): class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation, - DataHandlerKzFilterSingleStation): - _requirements1 = DataHandlerKzFilterSingleStation.requirements() + DataHandlerFilterSingleStation): + _requirements1 = DataHandlerFilterSingleStation.requirements() _requirements2 = DataHandlerMixedSamplingSingleStation.requirements() _requirements = list(set(_requirements1 + _requirements2)) @@ -107,19 +112,16 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi def make_input_target(self): """ - A KZ filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values + A FIR filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values with daily resolution. 
""" self._data = tuple(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data self.set_inputs_and_targets() - self.apply_kz_filter() + self.apply_filter() def estimate_filter_width(self): - """ - f = 0.5 / (len * sqrt(itr)) -> T = 1 / f - :return: - """ - return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2) + """Return maximum filter width.""" + raise NotImplementedError @staticmethod def _add_time_delta(date, delta): @@ -152,26 +154,120 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi self.station_type, self.network, self.store_data_locally, self.data_origin, start, end) data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind], - limit=self.interpolation_limit[ind]) + limit=self.interpolation_limit[ind], sampling=self.sampling[ind]) return data def _extract_lazy(self, lazy_data): - _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data + _data, self.meta, _input_data, _target_data = lazy_data start_inp, end_inp = self.update_start_end(0) self._data = tuple(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1])) self.input_data = self._slice_prep(_input_data, start_inp, end_inp) self.target_data = self._slice_prep(_target_data, self.start, self.end) -class DataHandlerMixedSamplingWithFilter(DefaultDataHandler): +class DataHandlerMixedSamplingWithKzFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation, + DataHandlerKzFilterSingleStation): + _requirements1 = DataHandlerKzFilterSingleStation.requirements() + _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements() + _requirements = list(set(_requirements1 + _requirements2)) + + def estimate_filter_width(self): + """ + f = 0.5 / (len * sqrt(itr)) -> T = 1 / f + :return: + """ + return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2) + + def _extract_lazy(self, lazy_data): + _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days, \ + self.filter_dim_order = lazy_data + super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + + +class DataHandlerMixedSamplingWithKzFilter(DataHandlerKzFilter): + """Data handler using mixed sampling for input and target. Inputs are temporal filtered.""" + + data_handler = DataHandlerMixedSamplingWithKzFilterSingleStation + data_handler_transformation = DataHandlerMixedSamplingWithKzFilterSingleStation + _requirements = data_handler.requirements() + + +class DataHandlerMixedSamplingWithFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation, + DataHandlerFirFilterSingleStation): + _requirements1 = DataHandlerFirFilterSingleStation.requirements() + _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements() + _requirements = list(set(_requirements1 + _requirements2)) + + def estimate_filter_width(self): + """Filter width is determined by the filter with the highest order.""" + return max(self.filter_order) + + def _extract_lazy(self, lazy_data): + _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data + super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + + @staticmethod + def _get_fs(**kwargs): + """Return frequency in 1/day (not Hz)""" + sampling = kwargs.get("sampling")[0] + if sampling == "daily": + return 1 + elif sampling == "hourly": + return 24 + else: + raise ValueError(f"Unknown sampling rate {sampling}. 
Only daily and hourly resolution is supported.") + + +class DataHandlerMixedSamplingWithFirFilter(DataHandlerFirFilter): + """Data handler using mixed sampling for input and target. Inputs are temporal filtered.""" + + data_handler = DataHandlerMixedSamplingWithFirFilterSingleStation + data_handler_transformation = DataHandlerMixedSamplingWithFirFilterSingleStation + _requirements = data_handler.requirements() + + +class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation, + DataHandlerClimateFirFilterSingleStation): + _requirements1 = DataHandlerClimateFirFilterSingleStation.requirements() + _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements() + _requirements = list(set(_requirements1 + _requirements2)) + + def estimate_filter_width(self): + """Filter width is determined by the filter with the highest order.""" + if isinstance(self.filter_order[0], tuple): + return max([filter_width_kzf(*e) for e in self.filter_order]) + else: + return max(self.filter_order) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _extract_lazy(self, lazy_data): + _data, _meta, _input_data, _target_data, self.climate_filter_coeff, self.apriori, self.all_apriori, \ + self.filter_dim_order = lazy_data + DataHandlerMixedSamplingWithFilterSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data)) + + @staticmethod + def _get_fs(**kwargs): + """Return frequency in 1/day (not Hz)""" + sampling = kwargs.get("sampling")[0] + if sampling == "daily": + return 1 + elif sampling == "hourly": + return 24 + else: + raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.") + + +class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter): """Data handler using mixed sampling for input and target. Inputs are temporal filtered.""" - data_handler = DataHandlerMixedSamplingWithFilterSingleStation - data_handler_transformation = DataHandlerMixedSamplingWithFilterSingleStation + data_handler = DataHandlerMixedSamplingWithClimateFirFilterSingleStation + data_handler_transformation = DataHandlerMixedSamplingWithClimateFirFilterSingleStation _requirements = data_handler.requirements() -class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFilterSingleStation): +class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithKzFilterSingleStation): """ Data handler using mixed sampling for input and target. Inputs are temporal filtered and depending on the separation frequency of a filtered time series the time step delta for input data is adjusted (see image below). 
@@ -181,8 +277,8 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil """ - _requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements() - _hash = DataHandlerMixedSamplingWithFilterSingleStation._hash + ["time_delta"] + _requirements = DataHandlerMixedSamplingWithKzFilterSingleStation.requirements() + _hash = DataHandlerMixedSamplingWithKzFilterSingleStation._hash + ["time_delta"] def __init__(self, *args, time_delta=np.sqrt, **kwargs): assert isinstance(time_delta, Callable) diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index 89aafa2c7030427e105b663c97998c3ecf09eaaf..4330efd9ee5d3ae8a64c6eb9b95a0c58e18b3c36 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -280,7 +280,7 @@ class DataHandlerSingleStation(AbstractDataHandler): self.station_type, self.network, self.store_data_locally, self.data_origin, self.start, self.end) self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, - limit=self.interpolation_limit) + limit=self.interpolation_limit, sampling=self.sampling) self.set_inputs_and_targets() def set_inputs_and_targets(self): @@ -406,7 +406,8 @@ class DataHandlerSingleStation(AbstractDataHandler): "propane", "so2", "toluene"] # used_chem_vars = list(set(chem_vars) & set(self.statistics_per_var.keys())) used_chem_vars = list(set(chem_vars) & set(data.variables.values)) - data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum) + if len(used_chem_vars) > 0: + data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum) return data def setup_data_path(self, data_path: str, sampling: str): @@ -468,9 +469,8 @@ class DataHandlerSingleStation(AbstractDataHandler): all_vars = sorted(statistics_per_var.keys()) return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv") - @staticmethod - def interpolate(data, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True, - **kwargs): + def interpolate(self, data, dim: str, method: str = 'linear', limit: int = None, + use_coordinate: Union[bool, str] = True, sampling="daily", **kwargs): """ Interpolate values according to different methods. @@ -507,8 +507,22 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: xarray.DataArray """ + data = self.create_full_time_dim(data, dim, sampling) return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs) + @staticmethod + def create_full_time_dim(data, dim, sampling): + """Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped.""" + start = data.coords[dim].values[0] + end = data.coords[dim].values[-1] + freq = {"daily": "1D", "hourly": "1H"}.get(sampling) + datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq)) + t = data.sel({dim: start}, drop=True) + res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords]) + res = res.transpose(*data.dims) + res.loc[data.coords] = data + return res + def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: """ Create a xr.DataArray containing history data. 
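The interpolate/create_full_time_dim change above first re-indexes the series onto an equidistant time axis, so dates that were dropped together with missing values reappear as NaN before interpolate_na fills them. A minimal standalone sketch of that effect (not from the patch; synthetic values and the variable name "o3" are invented, and reindex is used here instead of the handler's own array construction):

import pandas as pd
import xarray as xr

# Daily series where 2020-01-03 was dropped entirely instead of being stored as NaN.
idx = pd.DatetimeIndex(["2020-01-01", "2020-01-02", "2020-01-04"])
data = xr.DataArray([[1.0], [2.0], [4.0]],
                    coords={"datetime": idx, "variables": ["o3"]},
                    dims=["datetime", "variables"])

# Equivalent of create_full_time_dim(data, "datetime", "daily"): build the full daily
# index and re-index onto it, which reintroduces the missing date as NaN.
full_index = pd.date_range(idx[0], idx[-1], freq="1D")
data_full = data.reindex({"datetime": full_index})

# interpolate() then calls interpolate_na on the equidistant axis and fills the gap.
filled = data_full.interpolate_na(dim="datetime", method="linear", limit=1)
print(filled.sel(variables="o3").values)  # [1. 2. 3. 4.]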
diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..e76f396aea80b2db76e01ea5baacf71d024b0d23 --- /dev/null +++ b/mlair/data_handler/data_handler_with_filter.py @@ -0,0 +1,501 @@ +"""Data Handler using kz-filtered data.""" + +__author__ = 'Lukas Leufen' +__date__ = '2020-08-26' + +import inspect +import numpy as np +import pandas as pd +import xarray as xr +from typing import List, Union, Tuple, Optional +from functools import partial +import logging +from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation +from mlair.data_handler import DefaultDataHandler +from mlair.helpers import remove_items, to_list, TimeTrackingWrapper +from mlair.helpers.filter import KolmogorovZurbenkoFilterMovingWindow as KZFilter +from mlair.helpers.filter import FIRFilter, ClimateFIRFilter, omega_null_kzf + +# define a more general date type for type hinting +str_or_list = Union[str, List[str]] + + +# cutoff_p = [(None, 14), (8, 6), (2, 0.8), (0.8, None)] +# cutoff = list(map(lambda x: (1. / x[0] if x[0] is not None else None, 1. / x[1] if x[1] is not None else None), cutoff_p)) +# fs = 24. +# # order = int(60 * fs) + 1 +# order = np.array([int(14 * fs) + 1, int(14 * fs) + 1, int(4 * fs) + 1, int(2 * fs) + 1]) +# print("cutoff period", cutoff_p) +# print("cutoff", cutoff) +# print("fs", fs) +# print("order", order) +# print("delay", 0.5 * (order-1) / fs) +# window = ("kaiser", 5) +# # low pass +# y, h = fir_filter(station_data.values.flatten(), fs, order[0], cutoff_low = cutoff[0][0], cutoff_high = cutoff[0][1], window=window) +# filtered = xr.ones_like(station_data) * y.reshape(station_data.values.shape) + + +class DataHandlerFilterSingleStation(DataHandlerSingleStation): + """General data handler for a single station to be used by a superior data handler.""" + + _requirements = remove_items(DataHandlerSingleStation.requirements(), "station") + _hash = DataHandlerSingleStation._hash + ["filter_dim"] + + DEFAULT_FILTER_DIM = "filter" + + def __init__(self, *args, filter_dim=DEFAULT_FILTER_DIM, **kwargs): + # self.original_data = None # ToDo: implement here something to store unfiltered data + self.filter_dim = filter_dim + self.filter_dim_order = None + super().__init__(*args, **kwargs) + + def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]: + """ + Adjust setup of transformation because filtered data will have negative values which is not compatible with + the log transformation. Therefore, replace all log transformation methods by a default standardization. This is + only applied on input side. + """ + transformation = super(__class__, self).setup_transformation(transformation) + if transformation[0] is not None: + for k, v in transformation[0].items(): + if v["method"] == "log": + transformation[0][k]["method"] = "standardise" + elif v["method"] == "min_max": + transformation[0][k]["method"] = "standardise" + return transformation + + def _check_sampling(self, **kwargs): + assert kwargs.get("sampling") == "hourly" # This data handler requires hourly data resolution, does it? 
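The setup_transformation override above rewrites only the input side of the transformation options, because filtered inputs can become negative and are therefore incompatible with log (and min_max) scaling. A small sketch of that rewrite on a hypothetical options tuple (variable names and methods are invented for illustration; index 0 holds the input options and index 1 the target options, as in the diff):

# Hypothetical transformation options: (input_opts, target_opts).
transformation = ({"o3": {"method": "log"}, "temp": {"method": "min_max"}},
                  {"o3": {"method": "standardise"}})

input_opts = transformation[0]
for var, opts in input_opts.items():
    if opts["method"] in ("log", "min_max"):
        opts["method"] = "standardise"   # replace incompatible methods on inputs only

print(transformation)
# ({'o3': {'method': 'standardise'}, 'temp': {'method': 'standardise'}},
#  {'o3': {'method': 'standardise'}})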
+ + def make_input_target(self): + data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, + self.station_type, self.network, self.store_data_locally, self.data_origin, + self.start, self.end) + self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) + self.set_inputs_and_targets() + self.apply_filter() + # this is just a code snippet to check the results of the kz filter + # import matplotlib + # matplotlib.use("TkAgg") + # import matplotlib.pyplot as plt + # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot() + # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") + + def apply_filter(self): + raise NotImplementedError + + def create_filter_index(self) -> pd.Index: + """Create name for filter dimension.""" + raise NotImplementedError + + def get_transposed_history(self) -> xr.DataArray: + """Return history. + + :return: history with dimensions datetime, window, Stations, variables, filter. + """ + return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim, + self.filter_dim).copy() + + def _create_lazy_data(self): + raise NotImplementedError + + def _extract_lazy(self, lazy_data): + _data, self.meta, _input_data, _target_data = lazy_data + f_prep = partial(self._slice_prep, start=self.start, end=self.end) + self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data])) + + +class DataHandlerFilter(DefaultDataHandler): + """Data handler using FIR filtered data.""" + + data_handler = DataHandlerFilterSingleStation + data_handler_transformation = DataHandlerFilterSingleStation + _requirements = data_handler.requirements() + + def __init__(self, *args, use_filter_branches=False, **kwargs): + self.use_filter_branches = use_filter_branches + super().__init__(*args, **kwargs) + + @classmethod + def own_args(cls, *args): + """Return all arguments (including kwonlyargs).""" + super_own_args = DefaultDataHandler.own_args(*args) + arg_spec = inspect.getfullargspec(cls) + list_of_args = arg_spec.args + arg_spec.kwonlyargs + super_own_args + return remove_items(list_of_args, ["self"] + list(args)) + + def get_X_original(self): + if self.use_filter_branches is True: + X = [] + for data in self._collection: + X_total = data.get_X() + filter_dim = data.filter_dim + for filter_name in data.filter_dim_order: + X.append(X_total.sel({filter_dim: filter_name}, drop=True)) + return X + else: + return super().get_X_original() + + +class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): + """Data handler for a single station to be used by a superior data handler. 
Inputs are FIR filtered.""" + + _requirements = remove_items(DataHandlerFilterSingleStation.requirements(), "station") + _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type", + "_add_unfiltered"] + + DEFAULT_WINDOW_TYPE = ("kaiser", 5) + DEFAULT_ADD_UNFILTERED = False + + def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, + filter_add_unfiltered=DEFAULT_ADD_UNFILTERED, **kwargs): + # self._check_sampling(**kwargs) + # self.original_data = None # ToDo: implement here something to store unfiltered data + self.fs = self._get_fs(**kwargs) + if filter_window_type == "kzf": + filter_cutoff_period = self._get_kzf_cutoff_period(filter_order, self.fs) + self.filter_cutoff_period, removed_index = self._prepare_filter_cutoff_period(filter_cutoff_period, self.fs) + self.filter_cutoff_freq = self._period_to_freq(self.filter_cutoff_period) + assert len(self.filter_cutoff_period) == (len(filter_order) - len(removed_index)) + self.filter_order = self._prepare_filter_order(filter_order, removed_index, self.fs) + self.filter_window_type = filter_window_type + self._add_unfiltered = filter_add_unfiltered + super().__init__(*args, **kwargs) + + @staticmethod + def _prepare_filter_order(filter_order, removed_index, fs): + order = [] + for i, o in enumerate(filter_order): + if i not in removed_index: + if isinstance(o, tuple): + fo = (o[0] * fs, o[1]) + else: + fo = int(o * fs) + fo = fo + 1 if fo % 2 == 0 else fo + order.append(fo) + return order + + @staticmethod + def _prepare_filter_cutoff_period(filter_cutoff_period, fs): + """Frequency must be smaller than the sampling frequency fs. Otherwise remove given cutoff period pair.""" + cutoff_tmp = (lambda x: [x] if isinstance(x, tuple) else to_list(x))(filter_cutoff_period) + cutoff = [] + removed = [] + for i, (low, high) in enumerate(cutoff_tmp): + low = low if (low is None or low > 2. / fs) else None + high = high if (high is None or high > 2. / fs) else None + if any([low, high]): + cutoff.append((low, high)) + else: + removed.append(i) + return cutoff, removed + + @staticmethod + def _get_kzf_cutoff_period(kzf_settings, fs): + cutoff = [] + for (m, k) in kzf_settings: + w0 = omega_null_kzf(m * fs, k) * fs + cutoff.append(1. / w0) + return cutoff + + @staticmethod + def _period_to_freq(cutoff_p): + return list(map(lambda x: (1. / x[0] if x[0] is not None else None, 1. / x[1] if x[1] is not None else None), + cutoff_p)) + + @staticmethod + def _get_fs(**kwargs): + """Return frequency in 1/day (not Hz)""" + sampling = kwargs.get("sampling") + if sampling == "daily": + return 1 + elif sampling == "hourly": + return 24 + else: + raise ValueError(f"Unknown sampling rate {sampling}. 
Only daily and hourly resolution is supported.") + + @TimeTrackingWrapper + def apply_filter(self): + """Apply FIR filter only on inputs.""" + fir = FIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq, + self.filter_window_type, self.target_dim) + self.fir_coeff = fir.filter_coefficients() + fir_data = fir.filtered_data() + if self._add_unfiltered is True: + fir_data.append(self.input_data) + self.input_data = xr.concat(fir_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) + # this is just a code snippet to check the results of the kz filter + # import matplotlib + # matplotlib.use("TkAgg") + # import matplotlib.pyplot as plt + # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot() + # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") + + def create_filter_index(self) -> pd.Index: + """ + Create name for filter dimension. Use 'high' or 'low' for high/low pass data and 'bandi' for band pass data with + increasing numerator i (starting from 1). If 1 low, 2 band, and 1 high pass filter is used the filter index will + become to ['low', 'band1', 'band2', 'high']. + """ + index = [] + band_num = 1 + for (low, high) in self.filter_cutoff_period: + if low is None: + index.append("low") + elif high is None: + index.append("high") + else: + index.append(f"band{band_num}") + band_num += 1 + if self._add_unfiltered: + index.append("unfiltered") + self.filter_dim_order = index + return pd.Index(index, name=self.filter_dim) + + def _create_lazy_data(self): + return [self._data, self.meta, self.input_data, self.target_data, self.fir_coeff, self.filter_dim_order] + + def _extract_lazy(self, lazy_data): + _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data + super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + + +class DataHandlerFirFilter(DataHandlerFilter): + """Data handler using FIR filtered data.""" + + data_handler = DataHandlerFirFilterSingleStation + data_handler_transformation = DataHandlerFirFilterSingleStation + _requirements = data_handler.requirements() + + +class DataHandlerKzFilterSingleStation(DataHandlerFilterSingleStation): + """Data handler for a single station to be used by a superior data handler. 
Inputs are kz filtered.""" + + _requirements = remove_items(inspect.getfullargspec(DataHandlerFilterSingleStation).args, ["self", "station"]) + _hash = DataHandlerFilterSingleStation._hash + ["kz_filter_length", "kz_filter_iter"] + + def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs): + self._check_sampling(**kwargs) + # self.original_data = None # ToDo: implement here something to store unfiltered data + self.kz_filter_length = to_list(kz_filter_length) + self.kz_filter_iter = to_list(kz_filter_iter) + self.cutoff_period = None + self.cutoff_period_days = None + super().__init__(*args, **kwargs) + + @TimeTrackingWrapper + def apply_filter(self): + """Apply kolmogorov zurbenko filter only on inputs.""" + kz = KZFilter(self.input_data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim=self.time_dim) + filtered_data: List[xr.DataArray] = kz.run() + self.cutoff_period = kz.period_null() + self.cutoff_period_days = kz.period_null_days() + self.input_data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) + # this is just a code snippet to check the results of the kz filter + # import matplotlib + # matplotlib.use("TkAgg") + # import matplotlib.pyplot as plt + # self.input_data.sel(filter="74d", variables="temp", Stations="DEBW107").plot() + # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") + + def create_filter_index(self) -> pd.Index: + """ + Round cut off periods in days and append 'res' for residuum index. + + Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append + 'res' for residuum index. + """ + index = np.round(self.cutoff_period_days, 1) + f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1) + index = list(map(f, index.tolist())) + index = list(map(lambda x: str(x) + "d", index)) + ["res"] + self.filter_dim_order = index + return pd.Index(index, name=self.filter_dim) + + def _create_lazy_data(self): + return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days, + self.filter_dim_order] + + def _extract_lazy(self, lazy_data): + _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days, \ + self.filter_dim_order = lazy_data + super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + + +class DataHandlerKzFilter(DataHandlerFilter): + """Data handler using kz filtered data.""" + + data_handler = DataHandlerKzFilterSingleStation + data_handler_transformation = DataHandlerKzFilterSingleStation + _requirements = data_handler.requirements() + + +class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation): + """ + Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered. In contrast to + the simple DataHandlerFirFilterSingleStation, this data handler is centered around t0 to have no time delay. For + values in the future (t > t0), this data handler assumes a climatological value for the low pass data and values of + 0 for all residuum components. + + :param apriori: Data to use as apriori information. This should be either a xarray dataarray containing monthly or + any other heuristic to support the clim filter, or a list of such arrays containing heuristics for all residua + in addition. The 2nd can be used together with apriori_type `residuum_stats` which estimates the error of the + residuum when the clim filter should be applied with exogenous parameters. 
If apriori_type is None/`zeros` data + can be provided, but this is not required in this case. + :param apriori_type: set type of information that is provided to the clim filter. For the first low pass always a + calculated or given statistic is used. For residuum prediction a constant value of zero is assumed if + apriori_type is None or `zeros`, and a climatology of the residuum is used for `residuum_stats`. + :param apriori_diurnal: use diurnal anomalies of each hour as addition to the apriori information type chosen by + parameter apriori_type. This is only applicable for hourly resolution data. + """ + + _requirements = remove_items(DataHandlerFirFilterSingleStation.requirements(), "station") + _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal"] + _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"] + + def __init__(self, *args, apriori=None, apriori_type=None, apriori_diurnal=False, apriori_sel_opts=None, + plot_path=None, name_affix=None, **kwargs): + self.apriori_type = apriori_type + self.climate_filter_coeff = None # coefficents of the used FIR filter + self.apriori = apriori # exogenous apriori information or None to calculate from data (endogenous) + self.apriori_diurnal = apriori_diurnal + self.all_apriori = None # collection of all apriori information + self.apriori_sel_opts = apriori_sel_opts # ensure to separate exogenous and endogenous information + self.plot_path = plot_path # use this path to create insight plots + self.plot_name_affix = name_affix + super().__init__(*args, **kwargs) + + @TimeTrackingWrapper + def apply_filter(self): + """Apply FIR filter only on inputs.""" + self.apriori = self.apriori.get(str(self)) if isinstance(self.apriori, dict) else self.apriori + logging.info(f"{self.station}: call ClimateFIRFilter") + plot_name = str(self) # if self.plot_name_affix is None else f"{str(self)}_{self.plot_name_affix}" + climate_filter = ClimateFIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, + self.filter_cutoff_freq, + self.filter_window_type, time_dim=self.time_dim, var_dim=self.target_dim, + apriori_type=self.apriori_type, apriori=self.apriori, + apriori_diurnal=self.apriori_diurnal, sel_opts=self.apriori_sel_opts, + plot_path=self.plot_path, plot_name=plot_name, + minimum_length=self.window_history_size, new_dim=self.window_dim) + self.climate_filter_coeff = climate_filter.filter_coefficients + + # store apriori information: store all if residuum_stat method was used, otherwise just store initial apriori + if self.apriori_type == "residuum_stats": + self.apriori = climate_filter.apriori_data + else: + self.apriori = climate_filter.initial_apriori_data + self.all_apriori = climate_filter.apriori_data + + climate_filter_data = [c.sel({self.window_dim: slice(-self.window_history_size, 0)}) for c in + climate_filter.filtered_data] + + # create input data with filter index + input_data = xr.concat(climate_filter_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) + + # add unfiltered raw data + if self._add_unfiltered is True: + data_raw = self.shift(self.input_data, self.time_dim, -self.window_history_size) + data_raw = data_raw.expand_dims({self.filter_dim: ["unfiltered"]}, -1) + input_data = xr.concat([input_data, data_raw], self.filter_dim) + + self.input_data = input_data + + # this is just a code snippet to check the results of the filter + # import matplotlib + # matplotlib.use("TkAgg") + # import matplotlib.pyplot as plt + # 
self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot() + # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") + + def create_filter_index(self) -> pd.Index: + """ + Round cut off periods in days and append 'res' for residuum index. + + Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append + 'res' for residuum index. Add index unfiltered if the raw / unfiltered data is appended to data in addition. + """ + index = np.round(self.filter_cutoff_period, 1) + f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1) + index = list(map(f, index.tolist())) + index = list(map(lambda x: str(x) + "d", index)) + ["res"] + if self._add_unfiltered: + index.append("unfiltered") + self.filter_dim_order = index + return pd.Index(index, name=self.filter_dim) + + def _create_lazy_data(self): + return [self._data, self.meta, self.input_data, self.target_data, self.climate_filter_coeff, + self.apriori, self.all_apriori, self.filter_dim_order] + + def _extract_lazy(self, lazy_data): + _data, _meta, _input_data, _target_data, self.climate_filter_coeff, self.apriori, self.all_apriori, \ + self.filter_dim_order = lazy_data + DataHandlerSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data)) + + @staticmethod + def _prepare_filter_cutoff_period(filter_cutoff_period, fs): + """Frequency must be smaller than the sampling frequency fs. Otherwise remove given cutoff period pair.""" + cutoff = [] + removed = [] + for i, period in enumerate(to_list(filter_cutoff_period)): + if period > 2. / fs: + cutoff.append(period) + else: + removed.append(i) + return cutoff, removed + + @staticmethod + def _period_to_freq(cutoff_p): + return [1. / x for x in cutoff_p] + + def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: + """ + Create a xr.DataArray containing history data. + + Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted + data. This is used to represent history in the data. Results are stored in history attribute. + + :param dim_name_of_inputs: Name of dimension which contains the input variables + :param window: number of time steps to look back in history + Note: window will be treated as negative value. This should be in agreement with looking back on + a time line. 
Nonetheless positive values are allowed but they are converted to its negative + expression + :param dim_name_of_shift: Dimension along shift will be applied + """ + data = self.input_data + sampling = {"daily": "D", "hourly": "h"}.get(to_list(self.sampling)[0]) + data.coords[dim_name_of_shift] = data.coords[dim_name_of_shift] - np.timedelta64(self.window_history_offset, + sampling) + data.coords[self.window_dim] = data.coords[self.window_dim] + self.window_history_offset + self.history = data + + def call_transform(self, inverse=False): + opts_input = self._transformation[0] + self.input_data, opts_input = self.transform(self.input_data, dim=[self.time_dim, self.window_dim], + inverse=inverse, opts=opts_input, + transformation_dim=self.target_dim) + opts_target = self._transformation[1] + self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse, + opts=opts_target, transformation_dim=self.target_dim) + self._transformation = (opts_input, opts_target) + + +class DataHandlerClimateFirFilter(DataHandlerFilter): + """Data handler using climatic adjusted FIR filtered data.""" + + data_handler = DataHandlerClimateFirFilterSingleStation + data_handler_transformation = DataHandlerClimateFirFilterSingleStation + _requirements = data_handler.requirements() + _store_attributes = data_handler.store_attributes() + + # def get_X_original(self): + # X = [] + # for data in self._collection: + # X_total = data.get_X() + # filter_dim = data.filter_dim + # for filter in data.filter_dim_order: + # X.append(X_total.sel({filter_dim: filter}, drop=True)) + # return X diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 8d977e115cf7ea85d4d83bfac4c59977412ab8a7..c97d57ef7edf26c258040047343a701974a9a8f1 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -33,14 +33,16 @@ class DefaultDataHandler(AbstractDataHandler): from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler_transformation _requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"]) + _store_attributes = data_handler.store_attributes() DEFAULT_ITER_DIM = "Stations" DEFAULT_TIME_DIM = "datetime" + MAX_NUMBER_MULTIPROCESSING = 16 def __init__(self, id_class: data_handler, experiment_path: str, min_length: int = 0, extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None, store_processed_data=True, iter_dim=DEFAULT_ITER_DIM, time_dim=DEFAULT_TIME_DIM, - use_multiprocessing=True): + use_multiprocessing=True, max_number_multiprocessing=MAX_NUMBER_MULTIPROCESSING): super().__init__() self.id_class = id_class self.time_dim = time_dim @@ -51,6 +53,7 @@ class DefaultDataHandler(AbstractDataHandler): self._X_extreme = None self._Y_extreme = None self._use_multiprocessing = use_multiprocessing + self._max_number_multiprocessing = max_number_multiprocessing _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self)) self._save_file = os.path.join(experiment_path, "data", f"{_name_affix}.pickle") self._collection = self._create_collection() @@ -79,7 +82,7 @@ class DefaultDataHandler(AbstractDataHandler): def _cleanup(self): directory = os.path.dirname(self._save_file) if os.path.exists(directory) is False: - os.makedirs(directory) + os.makedirs(directory, exist_ok=True) if os.path.exists(self._save_file): shutil.rmtree(self._save_file, ignore_errors=True) @@ -93,6 
+96,16 @@ class DefaultDataHandler(AbstractDataHandler): logging.debug(f"save pickle data to {self._save_file}") self._reset_data() + def get_store_attributes(self): + attr_dict = {} + for attr in self.store_attributes(): + try: + val = self.__getattribute__(attr) + except AttributeError: + val = self.id_class.__getattribute__(attr) + attr_dict[attr] = val + return attr_dict + @staticmethod def _force_dask_computation(data): try: @@ -333,7 +346,9 @@ class DefaultDataHandler(AbstractDataHandler): if "feature_range" in opts.keys(): transformation_dict[i][var]["feature_range"] = opts.get("feature_range", None) - if multiprocessing.cpu_count() > 1 and kwargs.get("use_multiprocessing", True) is True: # parallel solution + max_process = kwargs.get("max_number_multiprocessing", 16) + n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process]) # use only physical cpus + if n_process > 1 and kwargs.get("use_multiprocessing", True) is True: # parallel solution logging.info("use parallel transformation approach") pool = multiprocessing.Pool( min([psutil.cpu_count(logical=False), len(set_stations), 16])) # use only physical cpus diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..a63cef975888162f335e4528c2f99bdfc7a892d5 --- /dev/null +++ b/mlair/helpers/filter.py @@ -0,0 +1,918 @@ +import gc +import warnings +from typing import Union, Callable, Tuple +import logging +import os +import time + +import datetime +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt +from scipy import signal +import xarray as xr +import dask.array as da + +from mlair.helpers import to_list, TimeTrackingWrapper, TimeTracking + + +class FIRFilter: + + def __init__(self, data, fs, order, cutoff, window, dim): + + filtered = [] + h = [] + for i in range(len(order)): + fi, hi = fir_filter(data, fs, order=order[i], cutoff_low=cutoff[i][0], cutoff_high=cutoff[i][1], + window=window, dim=dim, h=None, causal=True, padlen=None) + filtered.append(fi) + h.append(hi) + + self._filtered = filtered + self._h = h + + def filter_coefficients(self): + return self._h + + def filtered_data(self): + return self._filtered + # + # y, h = fir_filter(station_data.values.flatten(), fs, order[0], cutoff_low=cutoff[0][0], cutoff_high=cutoff[0][1], + # window=window) + # filtered = xr.ones_like(station_data) * y.reshape(station_data.values.shape) + # # band pass + # y_band, h_band = fir_filter(station_data.values.flatten(), fs, order[1], cutoff_low=cutoff[1][0], + # cutoff_high=cutoff[1][1], window=window) + # filtered_band = xr.ones_like(station_data) * y_band.reshape(station_data.values.shape) + # # band pass 2 + # y_band_2, h_band_2 = fir_filter(station_data.values.flatten(), fs, order[2], cutoff_low=cutoff[2][0], + # cutoff_high=cutoff[2][1], window=window) + # filtered_band_2 = xr.ones_like(station_data) * y_band_2.reshape(station_data.values.shape) + # # high pass + # y_high, h_high = fir_filter(station_data.values.flatten(), fs, order[3], cutoff_low=cutoff[3][0], + # cutoff_high=cutoff[3][1], window=window) + # filtered_high = xr.ones_like(station_data) * y_high.reshape(station_data.values.shape) + + +class ClimateFIRFilter: + from mlair.plotting.data_insight_plotting import PlotClimateFirFilter + + def __init__(self, data, fs, order, cutoff, window, time_dim, var_dim, apriori=None, apriori_type=None, + apriori_diurnal=False, sel_opts=None, plot_path=None, plot_name=None, + minimum_length=None, new_dim=None): + """ + 
:param data: data to filter + :param fs: sampling frequency in 1/days -> 1d: fs=1 -> 1H: fs=24 + :param order: a tuple with the order of the filter in same ordering like cutoff + :param cutoff: a tuple with the cutoff frequencies (all are applied as low pass) + :param window: window type of the filter (e.g. hamming) + :param time_dim: name of time dimension to apply filter along + :param var_dim: name of variables dimension + :param apriori: apriori information to use for the first low pass. If None, climatology is calculated on the + provided data. + :param apriori_type: type of apriori information to use. Climatology will be used always for first low pass. For + the residuum either the value zero is used (apriori_type is None or "zeros") or a climatology on the + residua is used ("residuum_stats"). + :param apriori_diurnal: Use diurnal cycle as additional apriori information (only applicable for hourly + resoluted data). The mean anomaly of each hour is added to the apriori_type information. + """ + logging.info(f"{plot_name}: start init ClimateFIRFilter") + self.plot_path = plot_path + self.plot_name = plot_name + self.plot_data = [] + filtered = [] + h = [] + if sel_opts is not None: + sel_opts = sel_opts if isinstance(sel_opts, dict) else {time_dim: sel_opts} + sampling = {1: "1d", 24: "1H"}.get(int(fs)) + logging.debug(f"{plot_name}: create diurnal_anomalies") + if apriori_diurnal is True and sampling == "1H": + # diurnal_anomalies = self.create_hourly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim, + # as_anomaly=True) + diurnal_anomalies = self.create_seasonal_hourly_mean(data, sel_opts=sel_opts, sampling=sampling, + time_dim=time_dim, + as_anomaly=True) + else: + diurnal_anomalies = 0 + logging.debug(f"{plot_name}: create monthly apriori") + if apriori is None: + apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, + time_dim=time_dim) + diurnal_anomalies + logging.debug(f"{plot_name}: apriori shape = {apriori.shape}") + apriori_list = to_list(apriori) + input_data = data.__deepcopy__() + + # for viz + plot_dates = None + + # create tmp dimension to apply filter, search for unused name + new_dim = self._create_tmp_dimension(input_data) if new_dim is None else new_dim + + for i in range(len(order)): + logging.info(f"{plot_name}: start filter for order {order[i]}") + # calculate climatological filter + # ToDo: remove all methods except the vectorized version + _minimum_length = self._minimum_length(order, minimum_length, i, window) + fi, hi, apriori, plot_data = self.clim_filter(input_data, fs, cutoff[i], order[i], + apriori=apriori_list[i], + sel_opts=sel_opts, sampling=sampling, time_dim=time_dim, + window=window, var_dim=var_dim, + minimum_length=_minimum_length, new_dim=new_dim, + plot_dates=plot_dates) + + logging.info(f"{plot_name}: finished clim_filter calculation") + if minimum_length is None: + filtered.append(fi) + else: + filtered.append(fi.sel({new_dim: slice(-minimum_length, 0)})) + h.append(hi) + gc.collect() + self.plot_data.append(plot_data) + plot_dates = {e["t0"] for e in plot_data} + + # calculate residuum + logging.info(f"{plot_name}: calculate residuum") + coord_range = range(fi.coords[new_dim].values.min(), fi.coords[new_dim].values.max() + 1) + if new_dim in input_data.coords: + input_data = input_data.sel({new_dim: coord_range}) - fi + else: + input_data = self._shift_data(input_data, coord_range, time_dim, var_dim, new_dim) - fi + + # create new apriori information for next iteration if no further apriori is 
provided + if len(apriori_list) <= i + 1: + logging.info(f"{plot_name}: create diurnal_anomalies") + if apriori_diurnal is True and sampling == "1H": + # diurnal_anomalies = self.create_hourly_mean(input_data.sel({new_dim: 0}, drop=True), + # sel_opts=sel_opts, sampling=sampling, + # time_dim=time_dim, as_anomaly=True) + diurnal_anomalies = self.create_seasonal_hourly_mean(input_data.sel({new_dim: 0}, drop=True), + sel_opts=sel_opts, sampling=sampling, + time_dim=time_dim, as_anomaly=True) + else: + diurnal_anomalies = 0 + logging.info(f"{plot_name}: create monthly apriori") + if apriori_type is None or apriori_type == "zeros": # zero version + apriori_list.append(xr.zeros_like(apriori_list[i]) + diurnal_anomalies) + elif apriori_type == "residuum_stats": # calculate monthly statistic on residuum + apriori_list.append( + -self.create_monthly_mean(input_data.sel({new_dim: 0}, drop=True), sel_opts=sel_opts, + sampling=sampling, + time_dim=time_dim) + diurnal_anomalies) + else: + raise ValueError(f"Cannot handle unkown apriori type: {apriori_type}. Please choose from None, " + f"`zeros` or `residuum_stats`.") + # add last residuum to filtered + if minimum_length is None: + filtered.append(input_data) + else: + filtered.append(input_data.sel({new_dim: slice(-minimum_length, 0)})) + # filtered.append(input_data) + self._filtered = filtered + self._h = h + self._apriori = apriori_list + + # visualize + if self.plot_path is not None: + self.PlotClimateFirFilter(self.plot_path, self.plot_data, sampling, plot_name) + # self._plot(sampling, new_dim=new_dim) + + @staticmethod + def _minimum_length(order, minimum_length, pos, window): + next_order = 0 + if pos + 1 < len(order): + next_order = order[pos + 1] + if window == "kzf" and isinstance(next_order, tuple): + next_order = filter_width_kzf(*next_order) + if minimum_length is not None: + next_order = next_order + minimum_length + return next_order if next_order > 0 else None + + @staticmethod + def create_unity_array(data, time_dim, extend_range=366): + """Create a xr data array filled with ones. 
time_dim is extended by extend_range days in future and past.""" + coords = data.coords + + # extend time_dim by given extend_range days + start = coords[time_dim][0].values.astype("datetime64[D]") - np.timedelta64(extend_range, "D") + end = coords[time_dim][-1].values.astype("datetime64[D]") + np.timedelta64(extend_range, "D") + new_time_axis = np.arange(start, end).astype("datetime64[ns]") + + # construct data array with updated coords + new_coords = {k: data.coords[k].values if k != time_dim else new_time_axis for k in coords} + new_array = xr.DataArray(1, coords=new_coords, dims=new_coords.keys()).transpose(*data.dims) + + # loffset is required because resampling uses last day in month as resampling timestamp + return new_array.resample({time_dim: "1m"}, loffset=datetime.timedelta(days=-15)).max() + + def create_monthly_mean(self, data, sel_opts=None, sampling="1d", time_dim="datetime"): + """Calculate monthly statistics.""" + + # create unity xarray in monthly resolution with sampling point in mid of each month + monthly = self.create_unity_array(data, time_dim) + + # apply selection if given (only use subset for monthly means) + if sel_opts is not None: + data = data.sel(**sel_opts) + + # create monthly mean and replace entries in unity array + monthly_mean = data.groupby(f"{time_dim}.month").mean() + for month in monthly_mean.month.values: + monthly = xr.where((monthly[f"{time_dim}.month"] == month), + monthly_mean.sel(month=month, drop=True), + monthly) + # transform monthly information into original sampling rate + return monthly.resample({time_dim: sampling}).interpolate() + + # for month in monthly_mean.month.values: + # loc = (monthly[f"{time_dim}.month"] == month) + # monthly.loc[{time_dim: loc}] = monthly_mean.sel(month=month, drop=True) + # aggregate monthly information (shift by half month, because resample base is last day) + # return monthly.resample({time_dim: "1m"}).max().resample({time_dim: sampling}).interpolate() + + @staticmethod + def create_hourly_mean(data, sel_opts=None, sampling="1H", time_dim="datetime", as_anomaly=True): + """Calculate hourly statistics. Either the absolute value or the anomaly (as_anomaly=True).""" + # can only be used for hourly sampling rate + assert sampling == "1H" + + # create unity xarray in hourly resolution + hourly = xr.ones_like(data) + + # apply selection if given (only use subset for hourly means) + if sel_opts is not None: + data = data.sel(**sel_opts) + + # create mean for each hour and replace entries in unity array, calculate anomaly if enabled + hourly_mean = data.groupby(f"{time_dim}.hour").mean() + if as_anomaly is True: + hourly_mean = hourly_mean - hourly_mean.mean("hour") + for hour in hourly_mean.hour.values: + loc = (hourly[f"{time_dim}.hour"] == hour) + hourly.loc[{f"{time_dim}": loc}] = hourly_mean.sel(hour=hour) + return hourly + + def create_seasonal_hourly_mean(self, data, sel_opts=None, sampling="1H", time_dim="datetime", as_anomaly=True): + """Calculate hourly statistics. 
Either the absolute value or the anomaly (as_anomaly=True).""" + # can only be used for hourly sampling rate + assert sampling == "1H" + + # apply selection if given (only use subset for seasonal hourly means) + if sel_opts is not None: + data = data.sel(**sel_opts) + + # create unity xarray in monthly resolution with sampling point in mid of each month + monthly = self.create_unity_array(data, time_dim) * np.nan + + seasonal_hourly_means = {} + + for month in data.groupby(f"{time_dim}.month").groups.keys(): + # select each month + single_month_data = data.sel({time_dim: (data[f"{time_dim}.month"] == month)}) + hourly_mean = single_month_data.groupby(f"{time_dim}.hour").mean() + if as_anomaly is True: + hourly_mean = hourly_mean - hourly_mean.mean("hour") + seasonal_hourly_means[month] = hourly_mean + + seasonal_coll = [] + for hour in data.groupby(f"{time_dim}.hour").groups.keys(): + h_coll = monthly.__deepcopy__() + for month in seasonal_hourly_means.keys(): + hourly_mean_single_month = seasonal_hourly_means[month].sel(hour=hour, drop=True) + h_coll = xr.where((h_coll[f"{time_dim}.month"] == month), + hourly_mean_single_month, + h_coll) + h_coll = h_coll.resample({time_dim: sampling}).interpolate() + h_coll = h_coll.sel({time_dim: (h_coll[f"{time_dim}.hour"] == hour)}) + seasonal_coll.append(h_coll) + hourly = xr.concat(seasonal_coll, time_dim).sortby(time_dim).resample({time_dim: sampling}).interpolate() + + return hourly + + @staticmethod + def extend_apriori(data, apriori, time_dim, sampling="1d"): + """ + Extend time range of apriori information. + + This method may not working properly if length of apriori is less then one year. + """ + dates = data.coords[time_dim].values + td_type = {"1d": "D", "1H": "h"}.get(sampling) + + # apriori starts after data + if dates[0] < apriori.coords[time_dim].values[0]: + logging.debug(f"{data.coords['Stations'].values[0]}: apriori starts after data") + + # add difference in full years + date_diff = abs(dates[0] - apriori.coords[time_dim].values[0]).astype("timedelta64[D]") + extend_range = np.ceil(date_diff / (np.timedelta64(1, "D") * 365)).astype(int) * 365 + factor = 1 if td_type == "D" else 24 + + # get fill data range + start = apriori.coords[time_dim][0].values.astype("datetime64[%s]" % td_type) + end = apriori.coords[time_dim][0].values.astype("datetime64[%s]" % td_type) + np.timedelta64( + 366 * factor + 1, td_type) + + # fill year by year + for i in range(365, extend_range + 365, 365): + apriori_tmp = apriori.sel({time_dim: slice(start, end)}) # hint: slice includes end date + new_time_axis = apriori_tmp.coords[time_dim] - np.timedelta64(i * factor, td_type) + apriori_tmp.coords[time_dim] = new_time_axis + apriori = apriori.combine_first(apriori_tmp) + + # apriori ends before data + if dates[-1] + np.timedelta64(365, "D") > apriori.coords[time_dim].values[-1]: + logging.debug(f"{data.coords['Stations'].values[0]}: apriori ends before data") + + # add difference in full years + 1 year (because apriori is used as future estimate) + date_diff = abs(dates[-1] - apriori.coords[time_dim].values[-1]).astype("timedelta64[D]") + extend_range = np.ceil(date_diff / (np.timedelta64(1, "D") * 365)).astype(int) * 365 + 365 + factor = 1 if td_type == "D" else 24 + + # get fill data range + start = apriori.coords[time_dim][-1].values.astype("datetime64[%s]" % td_type) - np.timedelta64( + 366 * factor + 1, td_type) + end = apriori.coords[time_dim][-1].values.astype("datetime64[%s]" % td_type) + + # fill year by year + for i in range(365, extend_range + 365, 
365): + apriori_tmp = apriori.sel({time_dim: slice(start, end)}) # hint: slice includes end date + new_time_axis = apriori_tmp.coords[time_dim] + np.timedelta64(i * factor, td_type) + apriori_tmp.coords[time_dim] = new_time_axis + apriori = apriori.combine_first(apriori_tmp) + + return apriori + + @TimeTrackingWrapper + def clim_filter(self, data, fs, cutoff_high, order, apriori=None, sel_opts=None, + sampling="1d", time_dim="datetime", var_dim="variables", window: Union[str, Tuple] = "hamming", + minimum_length=None, new_dim="window", plot_dates=None): + + logging.debug(f"{data.coords['Stations'].values[0]}: extend apriori") + + # calculate apriori information from data if not given and extend its range if not sufficient long enough + if apriori is None: + apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim) + apriori = apriori.astype(data.dtype) + apriori = self.extend_apriori(data, apriori, time_dim, sampling) + + # calculate FIR filter coefficients + if window == "kzf": + h = firwin_kzf(*order) + else: + h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window) + length = len(h) + + # use filter length if no minimum is given, otherwise use minimum + half filter length for extension + extend_length_history = length if minimum_length is None else minimum_length + int((length + 1) / 2) + extend_length_future = int((length + 1) / 2) + 1 + + # collect some data for visualization + plot_pos = np.array([0.25, 1.5, 2.75, 4]) * 365 * fs + if plot_dates is None: + plot_dates = [data.isel({time_dim: int(pos)}).coords[time_dim].values for pos in plot_pos if + pos < len(data.coords[time_dim])] + plot_data = [] + + coll = [] + + for var in reversed(data.coords[var_dim].values): + logging.info(f"{data.coords['Stations'].values[0]} ({var}): sel data") + + _start = pd.to_datetime(data.coords[time_dim].min().values).year + _end = pd.to_datetime(data.coords[time_dim].max().values).year + filt_coll = [] + for _year in range(_start, _end + 1): + logging.info(f"{data.coords['Stations'].values[0]} ({var}): year={_year}") + + time_slice = self._create_time_range_extend(_year, sampling, extend_length_history) + d = data.sel({var_dim: [var], time_dim: time_slice}) + a = apriori.sel({var_dim: [var], time_dim: time_slice}) + if len(d.coords[time_dim]) == 0: # no data at all for this year + continue + + # combine historical data / observation [t0-length,t0] and climatological statistics [t0+1,t0+length] + if new_dim not in d.coords: + history = self._shift_data(d, range(int(-extend_length_history), 1), time_dim, var_dim, new_dim) + else: + history = d.sel({new_dim: slice(int(-extend_length_history), 0)}) + if new_dim not in a.coords: + future = self._shift_data(a, range(1, extend_length_future), time_dim, var_dim, new_dim) + else: + future = a.sel({new_dim: slice(1, extend_length_future)}) + filter_input_data = xr.concat([history.dropna(time_dim), future], dim=new_dim, join="left") + try: + filter_input_data = filter_input_data.sel({time_dim: str(_year)}) + except KeyError: # no valid data for this year + continue + if len(filter_input_data.coords[time_dim]) == 0: # no valid data for this year + continue + + logging.debug(f"{data.coords['Stations'].values[0]} ({var}): start filter convolve") + with TimeTracking(name=f"{data.coords['Stations'].values[0]} ({var}): filter convolve", + logging_level=logging.DEBUG): + filt = xr.apply_ufunc(fir_filter_convolve, filter_input_data, + input_core_dims=[[new_dim]], + output_core_dims=[[new_dim]], + 
vectorize=True, + kwargs={"h": h}, + output_dtypes=[d.dtype]) + + if minimum_length is None: + filt_coll.append(filt.sel({new_dim: slice(-extend_length_history, 0)}, drop=True)) + else: + filt_coll.append(filt.sel({new_dim: slice(-minimum_length, 0)}, drop=True)) + + # visualization + for viz_date in set(plot_dates).intersection(filt.coords[time_dim].values): + try: + td_type = {"1d": "D", "1H": "h"}.get(sampling) + t_minus = viz_date + np.timedelta64(int(-extend_length_history), td_type) + t_plus = viz_date + np.timedelta64(int(extend_length_future), td_type) + if new_dim not in d.coords: + tmp_filter_data = self._shift_data(d.sel({time_dim: slice(t_minus, t_plus)}), + range(int(-extend_length_history), + int(extend_length_future)), + time_dim, var_dim, new_dim).sel({time_dim: viz_date}) + else: + # tmp_filter_data = d.sel({time_dim: viz_date, + # new_dim: slice(int(-extend_length_history), int(extend_length_future))}) + tmp_filter_data = None + valid_range = range(int((length + 1) / 2) if minimum_length is None else minimum_length, 1) + plot_data.append({"t0": viz_date, + "var": var, + "filter_input": filter_input_data.sel({time_dim: viz_date}), + "filter_input_nc": tmp_filter_data, + "valid_range": valid_range, + "time_range": d.sel( + {time_dim: slice(t_minus, t_plus - np.timedelta64(1, td_type))}).coords[ + time_dim].values, + "h": h, + "new_dim": new_dim}) + except: + pass + + # collect all filter results + coll.append(xr.concat(filt_coll, time_dim)) + gc.collect() + + logging.debug(f"{data.coords['Stations'].values[0]}: concat all variables") + res = xr.concat(coll, var_dim) + # create result array with same shape like input data, gabs are filled by nans + logging.debug(f"{data.coords['Stations'].values[0]}: create res_full") + + new_coords = {**{k: data.coords[k].values for k in data.coords if k != new_dim}, new_dim: res.coords[new_dim]} + dims = [*data.dims, new_dim] if new_dim not in data.dims else data.dims + res = res.transpose(*dims) + # res_full = xr.DataArray(dims=dims, coords=new_coords) + # res_full.loc[res.coords] = res + # res_full.compute() + res_full = res.broadcast_like(xr.DataArray(dims=dims, coords=new_coords)) + return res_full, h, apriori, plot_data + + @staticmethod + def _create_time_range_extend(year, sampling, extend_length): + td_type = {"1d": "D", "1H": "h"}.get(sampling) + delta = np.timedelta64(extend_length + 1, td_type) + start = np.datetime64(f"{year}-01-01") - delta + end = np.datetime64(f"{year}-12-31") + delta + return slice(start, end) + + @staticmethod + def _create_tmp_dimension(data): + new_dim = "window" + count = 0 + while new_dim in data.dims: + new_dim += new_dim + count += 1 + if count > 10: + raise ValueError("Could not create new dimension.") + return new_dim + + def _shift_data(self, data, index_value, time_dim, squeeze_dim, new_dim): + coll = [] + for i in index_value: + coll.append(data.shift({time_dim: -i})) + new_ind = self.create_index_array(new_dim, index_value, squeeze_dim) + return xr.concat(coll, dim=new_ind) + + @staticmethod + def create_index_array(index_name: str, index_value, squeeze_dim: str): + ind = pd.DataFrame({'val': index_value}, index=index_value) + res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze( + dim=squeeze_dim, + drop=True) + res.name = index_name + return res + + def _plot(self, sampling, new_dim="window"): + h = None + td_type = {"1d": "D", "1H": "h"}.get(sampling) + if self.plot_path is None: + return + plot_folder = 
os.path.join(os.path.abspath(self.plot_path), "climFIR") + if not os.path.exists(plot_folder): + os.makedirs(plot_folder) + + # set plot parameter + rc_params = {'axes.labelsize': 'large', + 'xtick.labelsize': 'large', + 'ytick.labelsize': 'large', + 'legend.fontsize': 'medium', + 'axes.titlesize': 'large', + } + plt.rcParams.update(rc_params) + + plot_dict = {} + for i, o in enumerate(range(len(self.plot_data))): + plot_data = self.plot_data[i] + for p_d in plot_data: + var = p_d.get("var") + t0 = p_d.get("t0") + filter_input = p_d.get("filter_input") + filter_input_nc = p_d.get("filter_input_nc") + valid_range = p_d.get("valid_range") + time_range = p_d.get("time_range") + new_dim = p_d.get("new_dim") + h = p_d.get("h") + plot_dict_var = plot_dict.get(var, {}) + plot_dict_t0 = plot_dict_var.get(t0, {}) + plot_dict_order = {"filter_input": filter_input, + "filter_input_nc": filter_input_nc, + "valid_range": valid_range, + "time_range": time_range, + "order": o, "h": h} + plot_dict_t0[i] = plot_dict_order + plot_dict_var[t0] = plot_dict_t0 + plot_dict[var] = plot_dict_var + + for var, viz_date_dict in plot_dict.items(): + for it0, t0 in enumerate(viz_date_dict.keys()): + viz_data = viz_date_dict[t0] + residuum_true = None + for ifilter in sorted(viz_data.keys()): + data = viz_data[ifilter] + filter_input = data["filter_input"] + filter_input_nc = data["filter_input_nc"] if residuum_true is None else residuum_true.sel( + {new_dim: filter_input.coords[new_dim]}) + valid_range = data["valid_range"] + time_axis = data["time_range"] + # time_axis = pd.date_range(t_minus, t_plus, freq=sampling) + filter_order = data["order"] + h = data["h"] + t_minus = t0 + np.timedelta64(-int(1.5 * valid_range.start), td_type) + t_plus = t0 + np.timedelta64(int(0.5 * valid_range.start), td_type) + fig, ax = plt.subplots() + ax.axvspan(t0 + np.timedelta64(-valid_range.start, td_type), + t0 + np.timedelta64(valid_range.stop - 1, td_type), color="whitesmoke", + label="valid area") + ax.axvline(t0, color="lightgrey", lw=6, label="time of interest ($t_0$)") + + # original data + ax.plot(time_axis, filter_input_nc.values.flatten(), color="darkgrey", linestyle="dashed", + label="original") + + # clim apriori + if ifilter == 0: + d_tmp = filter_input.sel( + {new_dim: slice(0, filter_input.coords[new_dim].values.max())}).values.flatten() + else: + d_tmp = filter_input.values.flatten() + ax.plot(time_axis[len(time_axis) - len(d_tmp):], d_tmp, color="darkgrey", linestyle="solid", + label="estimated future") + + # clim filter response + filt = xr.apply_ufunc(fir_filter_convolve, filter_input, + input_core_dims=[[new_dim]], + output_core_dims=[[new_dim]], + vectorize=True, + kwargs={"h": h}, + output_dtypes=[filter_input.dtype]) + ax.plot(time_axis, filt.values.flatten(), color="black", linestyle="solid", + label="clim filter response", linewidth=2) + residuum_estimated = filter_input - filt + + # ideal filter response + filt = xr.apply_ufunc(fir_filter_convolve, filter_input_nc, + input_core_dims=[[new_dim]], + output_core_dims=[[new_dim]], + vectorize=True, + kwargs={"h": h}, + output_dtypes=[filter_input.dtype]) + ax.plot(time_axis, filt.values.flatten(), color="black", linestyle="dashed", + label="ideal filter response", linewidth=2) + residuum_true = filter_input_nc - filt + + # set title, legend, and save plot + ax_start = max(t_minus, time_axis[0]) + ax_end = min(t_plus, time_axis[-1]) + ax.set_xlim((ax_start, ax_end)) + plt.title(f"Input of ClimFilter ({str(var)})") + plt.legend() + fig.autofmt_xdate() + 
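# finalize the layout and save one pdf per variable, t0 and filter stage (climFIR_<plot_name>_<var>_<it0>_<ifilter>.pdf)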
plt.tight_layout() + plot_name = os.path.join(plot_folder, + f"climFIR_{self.plot_name}_{str(var)}_{it0}_{ifilter}.pdf") + plt.savefig(plot_name, dpi=300) + plt.close('all') + + # plot residuum + fig, ax = plt.subplots() + ax.axvspan(t0 + np.timedelta64(-valid_range.start, td_type), + t0 + np.timedelta64(valid_range.stop - 1, td_type), color="whitesmoke", + label="valid area") + ax.axvline(t0, color="lightgrey", lw=6, label="time of interest ($t_0$)") + ax.plot(time_axis, residuum_true.values.flatten(), color="black", linestyle="dashed", + label="ideal filter residuum", linewidth=2) + ax.plot(time_axis, residuum_estimated.values.flatten(), color="black", linestyle="solid", + label="clim filter residuum", linewidth=2) + ax.set_xlim((ax_start, ax_end)) + plt.title(f"Residuum of ClimFilter ({str(var)})") + plt.legend(loc="upper left") + fig.autofmt_xdate() + plt.tight_layout() + plot_name = os.path.join(plot_folder, + f"climFIR_{self.plot_name}_{str(var)}_{it0}_{ifilter}_residuum.pdf") + plt.savefig(plot_name, dpi=300) + plt.close('all') + + @property + def filter_coefficients(self): + return self._h + + @property + def filtered_data(self): + return self._filtered + + @property + def apriori_data(self): + return self._apriori + + @property + def initial_apriori_data(self): + return self.apriori_data[0] + + +def fir_filter(data, fs, order=5, cutoff_low=None, cutoff_high=None, window="hamming", dim="variables", h=None, + causal=True, padlen=None): + """Expects xarray.""" + if h is None: + cutoff = [] + if cutoff_low is not None: + cutoff += [cutoff_low] + if cutoff_high is not None: + cutoff += [cutoff_high] + if len(cutoff) == 2: + filter_type = "bandpass" + elif len(cutoff) == 1 and cutoff_low is not None: + filter_type = "highpass" + elif len(cutoff) == 1 and cutoff_high is not None: + filter_type = "lowpass" + else: + raise ValueError("Please provide either cutoff_low or cutoff_high.") + h = signal.firwin(order, cutoff, pass_zero=filter_type, fs=fs, window=window) + filtered = xr.ones_like(data) + for var in data.coords[dim]: + d = data.sel({dim: var}).values.flatten() + if causal: + y = signal.lfilter(h, 1., d) + else: + padlen = padlen if padlen is not None else 3 * len(h) + y = signal.filtfilt(h, 1., d, padlen=padlen) + filtered.loc[{dim: var}] = y + return filtered, h + + +def fir_filter_convolve(data, h): + return signal.convolve(data, h, mode='same', method="direct") / sum(h) + + +class KolmogorovZurbenkoBaseClass: + + def __init__(self, df, wl, itr, is_child=False, filter_dim="window"): + """ + It create the variables associate with the Kolmogorov-Zurbenko-filter. 
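A minimal usage sketch via the concrete subclass defined below (assuming `data` is an xarray.DataArray with a "variables" coordinate and a "window" dimension to filter along; the variable name is hypothetical):

    kz = KolmogorovZurbenkoFilterMovingWindow(data, wl=[365, 30], itr=[3, 3], method="mean")
    components = kz.run()  # [low pass (wl=365), low pass (wl=30) of the residual, final residual]

run() therefore returns len(wl) + 1 components.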
+ + Args: + df(pd.DataFrame, None): time series of a variable + wl(list of int): window length + itr(list of int): number of iteration + """ + self.df = df + self.filter_dim = filter_dim + self.wl = to_list(wl) + self.itr = to_list(itr) + if abs(len(self.wl) - len(self.itr)) > 0: + raise ValueError("Length of lists for wl and itr must agree!") + self._isChild = is_child + self.child = self.set_child() + self.type = type(self).__name__ + + def set_child(self): + if len(self.wl) > 1: + return KolmogorovZurbenkoBaseClass(None, self.wl[1:], self.itr[1:], True, self.filter_dim) + else: + return None + + def kz_filter(self, df, m, k): + pass + + def spectral_calc(self): + df_start = self.df + kz = self.kz_filter(df_start, self.wl[0], self.itr[0]) + filtered = self.subtract(df_start, kz) + # case I: no child avail -> return kz and remaining + if self.child is None: + return [kz, filtered] + # case II: has child -> return current kz and all child results + else: + self.child.df = filtered + kz_next = self.child.spectral_calc() + return [kz] + kz_next + + @staticmethod + def subtract(minuend, subtrahend): + try: # pandas implementation + return minuend.sub(subtrahend, axis=0) + except AttributeError: # general implementation + return minuend - subtrahend + + def run(self): + return self.spectral_calc() + + def transfer_function(self): + m = self.wl[0] + k = self.itr[0] + omega = np.linspace(0.00001, 0.15, 5000) + return omega, (np.sin(m * np.pi * omega) / (m * np.sin(np.pi * omega))) ** (2 * k) + + def omega_null(self, alpha=0.5): + a = np.sqrt(6) / np.pi + b = 1 / (2 * np.array(self.itr)) + c = 1 - alpha ** b + d = np.array(self.wl) ** 2 - alpha ** b + return a * np.sqrt(c / d) + + def period_null(self, alpha=0.5): + return 1. / self.omega_null(alpha) + + def period_null_days(self, alpha=0.5): + return self.period_null(alpha) / 24. + + def plot_transfer_function(self, fig=None, name=None): + if fig is None: + fig = plt.figure() + omega, transfer_function = self.transfer_function() + if self.child is not None: + transfer_function_child = self.child.plot_transfer_function(fig) + else: + transfer_function_child = transfer_function * 0 + plt.semilogx(omega, transfer_function - transfer_function_child, + label="m={:3.0f}, k={:3.0f}, T={:6.2f}d".format(self.wl[0], + self.itr[0], + self.period_null_days())) + plt.axvline(x=self.omega_null()) + if not self._isChild: + locs, labels = plt.xticks() + plt.xticks(locs, np.round(1. / (locs * 24), 1)) + plt.xlim([0.00001, 0.15]) + plt.legend() + if name is None: + plt.show() + else: + plt.savefig(name) + else: + return transfer_function + + +class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass): + + def __init__(self, df, wl: Union[list, int], itr: Union[list, int], is_child=False, filter_dim="window", + method="mean", percentile=0.5): + """ + It create the variables associate with the KolmogorovZurbenkoFilterMovingWindow class. + + Args: + df(pd.DataFrame, xr.DataArray): time series of a variable + wl: window length + itr: number of iteration + """ + self.valid_methods = ["mean", "percentile", "median", "max", "min"] + if method not in self.valid_methods: + raise ValueError("Method '{}' is not supported. Please select from [{}].".format( + method, ", ".join(self.valid_methods))) + else: + self.method = method + if percentile > 1 or percentile < 0: + raise ValueError("Percentile must be in range [0, 1]. 
Given was {}!".format(percentile)) + else: + self.percentile = percentile + super().__init__(df, wl, itr, is_child, filter_dim) + + def set_child(self): + if len(self.wl) > 1: + return KolmogorovZurbenkoFilterMovingWindow(self.df, self.wl[1:], self.itr[1:], is_child=True, + filter_dim=self.filter_dim, method=self.method, + percentile=self.percentile) + else: + return None + + @TimeTrackingWrapper + def kz_filter_new(self, df, wl, itr): + """ + It passes the low frequency time series. + + If filter method is from mean, max, min this method will call construct and rechunk before the actual + calculation to improve performance. If filter method is either median or percentile this approach is not + applicable and depending on the data and window size, this method can become slow. + + Args: + wl(int): a window length + itr(int): a number of iteration + """ + warnings.filterwarnings("ignore") + df_itr = df.__deepcopy__() + try: + kwargs = {"min_periods": int(0.7 * wl), + "center": True, + self.filter_dim: wl} + for i in np.arange(0, itr): + print(i) + rolling = df_itr.chunk().rolling(**kwargs) + if self.method not in ["percentile", "median"]: + rolling = rolling.construct("construct").chunk("auto") + if self.method == "median": + df_mv_avg_tmp = rolling.median() + elif self.method == "percentile": + df_mv_avg_tmp = rolling.quantile(self.percentile) + elif self.method == "max": + df_mv_avg_tmp = rolling.max("construct") + elif self.method == "min": + df_mv_avg_tmp = rolling.min("construct") + else: + df_mv_avg_tmp = rolling.mean("construct") + df_itr = df_mv_avg_tmp.compute() + del df_mv_avg_tmp, rolling + gc.collect() + return df_itr + except ValueError: + raise ValueError + + @TimeTrackingWrapper + def kz_filter(self, df, wl, itr): + """ + It passes the low frequency time series. 
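For orientation (an illustrative identity, not used by this method directly): itr passes of a moving average of width wl are, up to edge effects, equivalent to a single convolution with the kernel firwin_kzf(wl, itr) defined at the end of this module; e.g. firwin_kzf(3, 2) gives [1, 2, 3, 2, 1] / 9 with total width filter_width_kzf(3, 2) = 5.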
+ + Args: + wl(int): a window length + itr(int): a number of iteration + """ + import warnings + warnings.filterwarnings("ignore") + df_itr = df.__deepcopy__() + try: + kwargs = {"min_periods": int(0.7 * wl), + "center": True, + self.filter_dim: wl} + iter_vars = df_itr.coords["variables"].values + for var in iter_vars: + df_itr_var = df_itr.sel(variables=[var]) + for _ in np.arange(0, itr): + df_itr_var = df_itr_var.chunk() + rolling = df_itr_var.rolling(**kwargs) + if self.method == "median": + df_mv_avg_tmp = rolling.median() + elif self.method == "percentile": + df_mv_avg_tmp = rolling.quantile(self.percentile) + elif self.method == "max": + df_mv_avg_tmp = rolling.max() + elif self.method == "min": + df_mv_avg_tmp = rolling.min() + else: + df_mv_avg_tmp = rolling.mean() + df_itr_var = df_mv_avg_tmp.compute() + df_itr.loc[{"variables": [var]}] = df_itr_var + return df_itr + except ValueError: + raise ValueError + + +def firwin_kzf(m, k): + coef = np.ones(m) + for i in range(1, k): + t = np.zeros((m, m + i * (m - 1))) + for km in range(m): + t[km, km:km + coef.size] = coef + coef = np.sum(t, axis=0) + return coef / m ** k + + +def omega_null_kzf(m, k, alpha=0.5): + a = np.sqrt(6) / np.pi + b = 1 / (2 * np.array(k)) + c = 1 - alpha ** b + d = np.array(m) ** 2 - alpha ** b + return a * np.sqrt(c / d) + + +def filter_width_kzf(m, k): + return k * (m - 1) + 1 diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index b57b733b08c4635a16d7fd18e99538a991521fd8..5ddaa3ee3fe505eeb7c8082274d9cd888cec720f 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -9,7 +9,7 @@ import numpy as np import xarray as xr import dask.array as da -from typing import Dict, Callable, Union, List, Any +from typing import Dict, Callable, Union, List, Any, Tuple def to_list(obj: Any) -> List: @@ -68,9 +68,9 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce return round_type(number * multiplier) / multiplier -def remove_items(obj: Union[List, Dict], items: Any): +def remove_items(obj: Union[List, Dict, Tuple], items: Any): """ - Remove item(s) from either list or dictionary. + Remove item(s) from either list, tuple or dictionary. :param obj: object to remove items from (either dictionary or list) :param items: elements to remove from obj. 
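For example (an illustrative call; tuples are now handled analogously to lists): remove_items(("a", "b", "c"), "b") returns ("a", "c").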
Can either be a list or single entry / key @@ -99,6 +99,8 @@ def remove_items(obj: Union[List, Dict], items: Any): return remove_from_list(obj, items) elif isinstance(obj, dict): return remove_from_dict(obj, items) + elif isinstance(obj, tuple): + return tuple(remove_from_list(to_list(obj), items)) else: raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.") @@ -177,5 +179,3 @@ def convert2xrda(arr: Union[xr.DataArray, xr.Dataset, np.ndarray, int, float], kwargs.update({'dims': dims, 'coords': coords}) return xr.DataArray(arr, **kwargs) - - diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py index 30391998c65950f12fc6824626638788e1bd721b..a1e713a8c135800d02ff7c27894485a5da7fae37 100644 --- a/mlair/helpers/statistics.py +++ b/mlair/helpers/statistics.py @@ -9,12 +9,7 @@ import numpy as np import xarray as xr import pandas as pd from typing import Union, Tuple, Dict, List -from matplotlib import pyplot as plt import itertools -import gc -import warnings - -from mlair.helpers import to_list, TimeTracking, TimeTrackingWrapper Data = Union[xr.DataArray, pd.DataFrame] @@ -262,11 +257,12 @@ class SkillScores: """ models_default = ["cnn", "persi", "ols"] - def __init__(self, external_data: Data, models=None, observation_name="obs"): + def __init__(self, external_data: Union[Data, None], models=None, observation_name="obs", ahead_dim="ahead"): """Set internal data.""" self.external_data = external_data self.models = self.set_model_names(models) self.observation_name = observation_name + self.ahead_dim = ahead_dim def set_model_names(self, models: List[str]) -> List[str]: """Either use given models or use defaults.""" @@ -288,19 +284,17 @@ class SkillScores: combination_strings = [f"{first}-{second}" for (first, second) in combinations] return combinations, combination_strings - def skill_scores(self, window_lead_time: int) -> pd.DataFrame: + def skill_scores(self) -> pd.DataFrame: """ Calculate skill scores for all combinations of model names. - :param window_lead_time: length of forecast steps - :return: skill score for each comparison and forecast step """ - ahead_names = list(range(1, window_lead_time + 1)) + ahead_names = list(self.external_data[self.ahead_dim].data) combinations, combination_strings = self.get_model_name_combinations() skill_score = pd.DataFrame(index=combination_strings) for iahead in ahead_names: - data = self.external_data.sel(ahead=iahead) + data = self.external_data.sel({self.ahead_dim: iahead}) skill_score[iahead] = [self.general_skill_score(data, forecast_name=first, reference_name=second, @@ -308,8 +302,7 @@ class SkillScores: for (first, second) in combinations] return skill_score - def climatological_skill_scores(self, internal_data: Data, window_lead_time: int, - forecast_name: str) -> xr.DataArray: + def climatological_skill_scores(self, internal_data: Data, forecast_name: str) -> xr.DataArray: """ Calculate climatological skill scores according to Murphy (1988). @@ -317,20 +310,19 @@ class SkillScores: is part of parameters. 
:param internal_data: internal data - :param window_lead_time: interested time step of forecast horizon to select data :param forecast_name: name of the forecast to use for this calculation (must be available in `data`) :return: all CASES as well as all terms """ - ahead_names = list(range(1, window_lead_time + 1)) + ahead_names = list(self.external_data[self.ahead_dim].data) all_terms = ['AI', 'AII', 'AIII', 'AIV', 'BI', 'BII', 'BIV', 'CI', 'CIV', 'CASE I', 'CASE II', 'CASE III', 'CASE IV'] skill_score = xr.DataArray(np.full((len(all_terms), len(ahead_names)), np.nan), coords=[all_terms, ahead_names], - dims=['terms', 'ahead']) + dims=['terms', self.ahead_dim]) for iahead in ahead_names: - data = internal_data.sel(ahead=iahead) + data = internal_data.sel({self.ahead_dim: iahead}) skill_score.loc[["CASE I", "AI", "BI", "CI"], iahead] = np.stack(self._climatological_skill_score( data, mu_type=1, forecast_name=forecast_name, observation_name=self.observation_name).values.flatten()) @@ -338,8 +330,8 @@ class SkillScores: skill_score.loc[["CASE II", "AII", "BII"], iahead] = np.stack(self._climatological_skill_score( data, mu_type=2, forecast_name=forecast_name, observation_name=self.observation_name).values.flatten()) - if self.external_data is not None: - external_data = self.external_data.sel(ahead=iahead, type=[self.observation_name]) + if self.external_data is not None and self.observation_name in self.external_data.coords["type"]: + external_data = self.external_data.sel({self.ahead_dim: iahead, "type": [self.observation_name]}) skill_score.loc[["CASE III", "AIII"], iahead] = np.stack(self._climatological_skill_score( data, mu_type=3, forecast_name=forecast_name, observation_name=self.observation_name, external_data=external_data).values.flatten()) @@ -378,12 +370,12 @@ class SkillScores: skill_score = 1 - mse(observation, forecast) / mse(observation, reference) return skill_score.values - @staticmethod - def skill_score_pre_calculations(data: Data, observation_name: str, forecast_name: str) -> Tuple[np.ndarray, - np.ndarray, - np.ndarray, - Data, - Dict[str, Data]]: + def skill_score_pre_calculations(self, data: Data, observation_name: str, forecast_name: str) -> Tuple[np.ndarray, + np.ndarray, + np.ndarray, + Data, + Dict[ + str, Data]]: """ Calculate terms AI, BI, and CI, mean, variance and pearson's correlation and clean up data. @@ -396,7 +388,7 @@ class SkillScores: :returns: Terms AI, BI, and CI, internal data without nans and mean, variance, correlation and its p-value """ - data = data.sel(type=[observation_name, forecast_name]).drop("ahead") + data = data.sel(type=[observation_name, forecast_name]).drop(self.ahead_dim) data = data.dropna("index") mean = data.mean("index") @@ -483,212 +475,3 @@ class SkillScores: return monthly_mean - -class KolmogorovZurbenkoBaseClass: - - def __init__(self, df, wl, itr, is_child=False, filter_dim="window"): - """ - It create the variables associate with the Kolmogorov-Zurbenko-filter. 
- - Args: - df(pd.DataFrame, None): time series of a variable - wl(list of int): window length - itr(list of int): number of iteration - """ - self.df = df - self.filter_dim = filter_dim - self.wl = to_list(wl) - self.itr = to_list(itr) - if abs(len(self.wl) - len(self.itr)) > 0: - raise ValueError("Length of lists for wl and itr must agree!") - self._isChild = is_child - self.child = self.set_child() - self.type = type(self).__name__ - - def set_child(self): - if len(self.wl) > 1: - return KolmogorovZurbenkoBaseClass(None, self.wl[1:], self.itr[1:], True, self.filter_dim) - else: - return None - - def kz_filter(self, df, m, k): - pass - - def spectral_calc(self): - df_start = self.df - kz = self.kz_filter(df_start, self.wl[0], self.itr[0]) - filtered = self.subtract(df_start, kz) - # case I: no child avail -> return kz and remaining - if self.child is None: - return [kz, filtered] - # case II: has child -> return current kz and all child results - else: - self.child.df = filtered - kz_next = self.child.spectral_calc() - return [kz] + kz_next - - @staticmethod - def subtract(minuend, subtrahend): - try: # pandas implementation - return minuend.sub(subtrahend, axis=0) - except AttributeError: # general implementation - return minuend - subtrahend - - def run(self): - return self.spectral_calc() - - def transfer_function(self): - m = self.wl[0] - k = self.itr[0] - omega = np.linspace(0.00001, 0.15, 5000) - return omega, (np.sin(m * np.pi * omega) / (m * np.sin(np.pi * omega))) ** (2 * k) - - def omega_null(self, alpha=0.5): - a = np.sqrt(6) / np.pi - b = 1 / (2 * np.array(self.itr)) - c = 1 - alpha ** b - d = np.array(self.wl) ** 2 - alpha ** b - return a * np.sqrt(c / d) - - def period_null(self, alpha=0.5): - return 1. / self.omega_null(alpha) - - def period_null_days(self, alpha=0.5): - return self.period_null(alpha) / 24. - - def plot_transfer_function(self, fig=None, name=None): - if fig is None: - fig = plt.figure() - omega, transfer_function = self.transfer_function() - if self.child is not None: - transfer_function_child = self.child.plot_transfer_function(fig) - else: - transfer_function_child = transfer_function * 0 - plt.semilogx(omega, transfer_function - transfer_function_child, - label="m={:3.0f}, k={:3.0f}, T={:6.2f}d".format(self.wl[0], - self.itr[0], - self.period_null_days())) - plt.axvline(x=self.omega_null()) - if not self._isChild: - locs, labels = plt.xticks() - plt.xticks(locs, np.round(1. / (locs * 24), 1)) - plt.xlim([0.00001, 0.15]) - plt.legend() - if name is None: - plt.show() - else: - plt.savefig(name) - else: - return transfer_function - - -class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass): - - def __init__(self, df, wl: Union[list, int], itr: Union[list, int], is_child=False, filter_dim="window", - method="mean", percentile=0.5): - """ - It create the variables associate with the KolmogorovZurbenkoFilterMovingWindow class. - - Args: - df(pd.DataFrame, xr.DataArray): time series of a variable - wl: window length - itr: number of iteration - """ - self.valid_methods = ["mean", "percentile", "median", "max", "min"] - if method not in self.valid_methods: - raise ValueError("Method '{}' is not supported. Please select from [{}].".format( - method, ", ".join(self.valid_methods))) - else: - self.method = method - if percentile > 1 or percentile < 0: - raise ValueError("Percentile must be in range [0, 1]. 
Given was {}!".format(percentile)) - else: - self.percentile = percentile - super().__init__(df, wl, itr, is_child, filter_dim) - - def set_child(self): - if len(self.wl) > 1: - return KolmogorovZurbenkoFilterMovingWindow(self.df, self.wl[1:], self.itr[1:], is_child=True, - filter_dim=self.filter_dim, method=self.method, - percentile=self.percentile) - else: - return None - - @TimeTrackingWrapper - def kz_filter_new(self, df, wl, itr): - """ - It passes the low frequency time series. - - If filter method is from mean, max, min this method will call construct and rechunk before the actual - calculation to improve performance. If filter method is either median or percentile this approach is not - applicable and depending on the data and window size, this method can become slow. - - Args: - wl(int): a window length - itr(int): a number of iteration - """ - warnings.filterwarnings("ignore") - df_itr = df.__deepcopy__() - try: - kwargs = {"min_periods": int(0.7 * wl), - "center": True, - self.filter_dim: wl} - for i in np.arange(0, itr): - print(i) - rolling = df_itr.chunk().rolling(**kwargs) - if self.method not in ["percentile", "median"]: - rolling = rolling.construct("construct").chunk("auto") - if self.method == "median": - df_mv_avg_tmp = rolling.median() - elif self.method == "percentile": - df_mv_avg_tmp = rolling.quantile(self.percentile) - elif self.method == "max": - df_mv_avg_tmp = rolling.max("construct") - elif self.method == "min": - df_mv_avg_tmp = rolling.min("construct") - else: - df_mv_avg_tmp = rolling.mean("construct") - df_itr = df_mv_avg_tmp.compute() - del df_mv_avg_tmp, rolling - gc.collect() - return df_itr - except ValueError: - raise ValueError - - @TimeTrackingWrapper - def kz_filter(self, df, wl, itr): - """ - It passes the low frequency time series. - - Args: - wl(int): a window length - itr(int): a number of iteration - """ - import warnings - warnings.filterwarnings("ignore") - df_itr = df.__deepcopy__() - try: - kwargs = {"min_periods": int(0.7 * wl), - "center": True, - self.filter_dim: wl} - iter_vars = df_itr.coords["variables"].values - for var in iter_vars: - df_itr_var = df_itr.sel(variables=[var]) - for _ in np.arange(0, itr): - df_itr_var = df_itr_var.chunk() - rolling = df_itr_var.rolling(**kwargs) - if self.method == "median": - df_mv_avg_tmp = rolling.median() - elif self.method == "percentile": - df_mv_avg_tmp = rolling.quantile(self.percentile) - elif self.method == "max": - df_mv_avg_tmp = rolling.max() - elif self.method == "min": - df_mv_avg_tmp = rolling.min() - else: - df_mv_avg_tmp = rolling.mean() - df_itr_var = df_mv_avg_tmp.compute() - df_itr.loc[{"variables": [var]}] = df_itr_var - return df_itr - except ValueError: - raise ValueError diff --git a/mlair/helpers/time_tracking.py b/mlair/helpers/time_tracking.py index c85a6a047943a589a9d076584ae40186634db767..3105ebcd04406b7d449ba312bd3af46f83e3a716 100644 --- a/mlair/helpers/time_tracking.py +++ b/mlair/helpers/time_tracking.py @@ -68,11 +68,12 @@ class TimeTracking(object): The only disadvantage of the latter implementation is, that the duration is logged but not returned. 
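A short sketch of the context-manager form with the new logging_level argument (the tracked call is a placeholder and `logging` is assumed to be imported):

    with TimeTracking(name="preprocessing", logging_level=logging.DEBUG):
        run_preprocessing()  # placeholder for the code block to be timed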
""" - def __init__(self, start=True, name="undefined job"): + def __init__(self, start=True, name="undefined job", logging_level=logging.INFO): """Construct time tracking and start if enabled.""" self.start = None self.end = None self._name = name + self._logging = {logging.INFO: logging.info, logging.DEBUG: logging.debug}.get(logging_level, logging.info) if start: self._start() @@ -128,4 +129,4 @@ class TimeTracking(object): def __exit__(self, exc_type, exc_val, exc_tb) -> None: """Stop time tracking on exit and log info about passed time.""" self.stop() - logging.info(f"{self._name} finished after {self}") \ No newline at end of file + self._logging(f"{self._name} finished after {self}") diff --git a/mlair/model_modules/fully_connected_networks.py b/mlair/model_modules/fully_connected_networks.py index 9fb08cdf6efacab12c2828ed221966586bce1d08..0338033315d294c2e54de8b038bba2123d2fee77 100644 --- a/mlair/model_modules/fully_connected_networks.py +++ b/mlair/model_modules/fully_connected_networks.py @@ -1,11 +1,11 @@ __author__ = "Lukas Leufen" -__date__ = '2021-02-' +__date__ = '2021-02-18' from functools import reduce, partial from mlair.model_modules import AbstractModelClass from mlair.helpers import select_from_dict -from mlair.model_modules.loss import var_loss, custom_loss +from mlair.model_modules.loss import var_loss, custom_loss, l_p_loss import keras @@ -20,7 +20,8 @@ class FCN(AbstractModelClass): "sigmoid": partial(keras.layers.Activation, "sigmoid"), "linear": partial(keras.layers.Activation, "linear"), "selu": partial(keras.layers.Activation, "selu"), - "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25))} + "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)), + "leakyrelu": partial(keras.layers.LeakyReLU)} _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform", "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(), "prelu": keras.initializers.he_normal()} @@ -31,12 +32,31 @@ class FCN(AbstractModelClass): def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear", optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None, - **kwargs): + batch_normalization=False, **kwargs): """ Sets model and loss depending on the given arguments. :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables)) :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast)) + + Customize this FCN model via the following parameters: + + :param activation: set your desired activation function. Chose from relu, tanh, sigmoid, linear, selu, prelu, + leakyrelu. (Default relu) + :param activation_output: same as activation parameter but exclusively applied on output layer only. (Default + linear) + :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam) + :param n_layer: define number of hidden layers in the network. Given number of hidden neurons are used in each + layer. (Default 1) + :param n_hidden: define number of hidden units per layer. This number is used in each hidden layer. (Default 10) + :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the + settings from n_layer and n_hidden. Provide a list where each element represent the number of units in the + hidden layer. 
The number of hidden layers is equal to the total length of this list. + :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the + network at all. (Default None) + :param batch_normalization: use batch normalization layer in the network if enabled. These layers are inserted + between the linear part of a layer (the nn part) and the non-linear part (activation function). No BN layer + is added if set to false. (Default false) """ assert len(input_shape) == 1 @@ -49,6 +69,7 @@ class FCN(AbstractModelClass): self.activation_output = self._set_activation(activation_output) self.activation_output_name = activation_output self.optimizer = self._set_optimizer(optimizer, **kwargs) + self.bn = batch_normalization self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration self._update_model_name() self.kernel_initializer = self._initializer.get(activation, "glorot_uniform") @@ -58,7 +79,7 @@ class FCN(AbstractModelClass): # apply to model self.set_model() self.set_compile_options() - self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss) + self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss, l_p_loss=l_p_loss(.5)) def _set_activation(self, activation): try: @@ -115,27 +136,29 @@ class FCN(AbstractModelClass): """ Build the model. """ - x_input = keras.layers.Input(shape=self._input_shape) - x_in = keras.layers.Flatten()(x_input) if isinstance(self.layer_configuration, tuple) is True: n_layer, n_hidden = self.layer_configuration - for layer in range(n_layer): - x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer, - kernel_regularizer=self.kernel_regularizer)(x_in) - x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in) - if self.dropout is not None: - x_in = self.dropout(self.dropout_rate)(x_in) + conf = [n_hidden for _ in range(n_layer)] else: assert isinstance(self.layer_configuration, list) is True - for layer, n_hidden in enumerate(self.layer_configuration): - x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer, - kernel_regularizer=self.kernel_regularizer)(x_in) - x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in) - if self.dropout is not None: - x_in = self.dropout(self.dropout_rate)(x_in) + conf = self.layer_configuration + + x_input = keras.layers.Input(shape=self._input_shape) + x_in = keras.layers.Flatten()(x_input) + + for layer, n_hidden in enumerate(conf): + x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer)(x_in) + if self.bn is True: + x_in = keras.layers.BatchNormalization()(x_in) + x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in) + if self.dropout is not None: + x_in = self.dropout(self.dropout_rate)(x_in) + x_in = keras.layers.Dense(self._output_shape)(x_in) out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in) self.model = keras.Model(inputs=x_input, outputs=[out]) + print(self.model.summary()) def set_compile_options(self): self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])], @@ -167,3 +190,191 @@ class FCN_64_32_16(FCN): def _update_model_name(self): self.model_name = "FCN" super()._update_model_name() + + +class BranchedInputFCN(AbstractModelClass): + """ + A customisable fully connected network (64, 32, 16, window_lead_time), where the last layer is the output layer depending + 
on the window_lead_time parameter. + """ + + _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"), + "sigmoid": partial(keras.layers.Activation, "sigmoid"), + "linear": partial(keras.layers.Activation, "linear"), + "selu": partial(keras.layers.Activation, "selu"), + "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)), + "leakyrelu": partial(keras.layers.LeakyReLU)} + _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform", + "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(), + "prelu": keras.initializers.he_normal()} + _optimizer = {"adam": keras.optimizers.adam, "sgd": keras.optimizers.SGD} + _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2} + _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"] + _dropout = {"selu": keras.layers.AlphaDropout} + + def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear", + optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None, + batch_normalization=False, **kwargs): + """ + Sets model and loss depending on the given arguments. + + :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables)) + :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast)) + + Customize this FCN model via the following parameters: + + :param activation: set your desired activation function. Chose from relu, tanh, sigmoid, linear, selu, prelu, + leakyrelu. (Default relu) + :param activation_output: same as activation parameter but exclusively applied on output layer only. (Default + linear) + :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam) + :param n_layer: define number of hidden layers in the network. Given number of hidden neurons are used in each + layer. (Default 1) + :param n_hidden: define number of hidden units per layer. This number is used in each hidden layer. (Default 10) + :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the + settings from n_layer and n_hidden. Provide a list where each element represent the number of units in the + hidden layer. The number of hidden layers is equal to the total length of this list. + :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the + network at all. (Default None) + :param batch_normalization: use batch normalization layer in the network if enabled. These layers are inserted + between the linear part of a layer (the nn part) and the non-linear part (activation function). No BN layer + is added if set to false. 
(Default false) + """ + + super().__init__(input_shape, output_shape[0]) + + # settings + self.activation = self._set_activation(activation) + self.activation_name = activation + self.activation_output = self._set_activation(activation_output) + self.activation_output_name = activation_output + self.optimizer = self._set_optimizer(optimizer, **kwargs) + self.bn = batch_normalization + self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration + self._update_model_name() + self.kernel_initializer = self._initializer.get(activation, "glorot_uniform") + self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs) + self.dropout, self.dropout_rate = self._set_dropout(activation, dropout) + + # apply to model + self.set_model() + self.set_compile_options() + self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss) + + def _set_activation(self, activation): + try: + return self._activation.get(activation.lower()) + except KeyError: + raise AttributeError(f"Given activation {activation} is not supported in this model class.") + + def _set_optimizer(self, optimizer, **kwargs): + try: + opt_name = optimizer.lower() + opt = self._optimizer.get(opt_name) + opt_kwargs = {} + if opt_name == "adam": + opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"]) + elif opt_name == "sgd": + opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"]) + return opt(**opt_kwargs) + except KeyError: + raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.") + + def _set_regularizer(self, regularizer, **kwargs): + if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"): + return None + try: + reg_name = regularizer.lower() + reg = self._regularizer.get(reg_name) + reg_kwargs = {} + if reg_name in ["l1", "l2"]: + reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True) + if reg_name in reg_kwargs: + reg_kwargs["l"] = reg_kwargs.pop(reg_name) + elif reg_name == "l1_l2": + reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True) + return reg(**reg_kwargs) + except KeyError: + raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.") + + def _set_dropout(self, activation, dropout_rate): + if dropout_rate is None: + return None, None + assert 0 <= dropout_rate < 1 + return self._dropout.get(activation, keras.layers.Dropout), dropout_rate + + def _update_model_name(self): + n_input = f"{len(self._input_shape)}x{str(reduce(lambda x, y: x * y, self._input_shape[0]))}" + n_output = str(self._output_shape) + + if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2: + n_layer, n_hidden = self.layer_configuration + branch = [f"{n_hidden}" for _ in range(n_layer)] + else: + branch = [f"{n}" for n in self.layer_configuration] + + concat = [] + n_neurons_concat = int(branch[-1]) * len(self._input_shape) + for exp in reversed(range(2, len(self._input_shape) + 1)): + n_neurons = self._output_shape ** exp + if n_neurons < n_neurons_concat: + if len(concat) == 0: + concat.append(f"1x{n_neurons}") + else: + concat.append(str(n_neurons)) + self.model_name += "_".join(["", n_input, *branch, *concat, n_output]) + + def set_model(self): + """ + Build the model. 
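Illustrative sizing example: with three input branches, 16 units in each branch's last hidden layer and an output of length 4, the concatenated tensor has 3 * 16 = 48 units; the candidate sizes 4**3 = 64 and 4**2 = 16 are compared against 48, so a single Dense(16) layer is added before the final Dense(4) output layer.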
+ """ + + if isinstance(self.layer_configuration, tuple) is True: + n_layer, n_hidden = self.layer_configuration + conf = [n_hidden for _ in range(n_layer)] + else: + assert isinstance(self.layer_configuration, list) is True + conf = self.layer_configuration + + x_input = [] + x_in = [] + + for branch in range(len(self._input_shape)): + x_input_b = keras.layers.Input(shape=self._input_shape[branch]) + x_input.append(x_input_b) + x_in_b = keras.layers.Flatten()(x_input_b) + + for layer, n_hidden in enumerate(conf): + x_in_b = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer, + name=f"Dense_branch{branch + 1}_{layer + 1}")(x_in_b) + if self.bn is True: + x_in_b = keras.layers.BatchNormalization()(x_in_b) + x_in_b = self.activation(name=f"{self.activation_name}_branch{branch + 1}_{layer + 1}")(x_in_b) + if self.dropout is not None: + x_in_b = self.dropout(self.dropout_rate)(x_in_b) + x_in.append(x_in_b) + x_concat = keras.layers.Concatenate()(x_in) + + n_neurons_concat = int(conf[-1]) * len(self._input_shape) + layer_concat = 0 + for exp in reversed(range(2, len(self._input_shape) + 1)): + n_neurons = self._output_shape ** exp + if n_neurons < n_neurons_concat: + layer_concat += 1 + x_concat = keras.layers.Dense(n_neurons, name=f"Dense_{layer_concat}")(x_concat) + if self.bn is True: + x_concat = keras.layers.BatchNormalization()(x_concat) + x_concat = self.activation(name=f"{self.activation_name}_{layer_concat}")(x_concat) + if self.dropout is not None: + x_concat = self.dropout(self.dropout_rate)(x_concat) + x_concat = keras.layers.Dense(self._output_shape)(x_concat) + out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat) + self.model = keras.Model(inputs=x_input, outputs=[out]) + print(self.model.summary()) + + def set_compile_options(self): + self.compile_options = {"loss": [keras.losses.mean_squared_error], + "metrics": ["mse", "mae", var_loss]} + # self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss], loss_weights=[2, 1])], + # "metrics": ["mse", "mae", var_loss]} diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index 33358e566ef80f28ee7740531b71d1a83abde115..e0f54282010e765fb3d8b0aca191a75c0b22fdf9 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -8,6 +8,7 @@ import math import pickle from typing import Union, List from typing_extensions import TypedDict +from time import time import numpy as np from keras import backend as K @@ -111,6 +112,20 @@ class LearningRateDecay(History): return K.get_value(self.model.optimizer.lr) +class EpoTimingCallback(Callback): + def __init__(self): + self.epo_timing = {'epo_timing': []} + self.logs = [] + self.starttime = None + super().__init__() + + def on_epoch_begin(self, epoch: int, logs=None): + self.starttime = time() + + def on_epoch_end(self, epoch: int, logs=None): + self.epo_timing["epo_timing"].append(time()-self.starttime) + + class ModelCheckpointAdvanced(ModelCheckpoint): """ Enhance the standard ModelCheckpoint class by additional saves of given callbacks. 
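The EpoTimingCallback added above only records the wall-clock duration of every epoch. A minimal usage sketch outside of MLAir's CallbackHandler could look like the following; the toy Sequential model and the random arrays are placeholders and not part of MLAir, the callback itself is used exactly as defined in this patch.

import json

import keras
import numpy as np

from mlair.model_modules.keras_extensions import EpoTimingCallback

# toy model and random data, only to exercise the callback
model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

epo_timing = EpoTimingCallback()
x, y = np.random.rand(32, 4), np.random.rand(32, 1)
model.fit(x, y, epochs=3, callbacks=[epo_timing], verbose=0)

# one wall-clock duration (in seconds) per epoch; the Training stage later
# dumps this dict as "epo_timing.json"
print(json.dumps(epo_timing.epo_timing, indent=2))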
diff --git a/mlair/model_modules/loss.py b/mlair/model_modules/loss.py index ba871e983ecfa1e91676d53b834ebd622c00fe49..2034c5a7795fad302d2a289e6fadbd5e295117cc 100644 --- a/mlair/model_modules/loss.py +++ b/mlair/model_modules/loss.py @@ -16,10 +16,10 @@ def l_p_loss(power: int) -> Callable: :return: loss for given power """ - def loss(y_true, y_pred): + def l_p_loss(y_true, y_pred): return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1) - return loss + return l_p_loss def var_loss(y_true, y_pred) -> Callable: diff --git a/mlair/model_modules/recurrent_networks.py b/mlair/model_modules/recurrent_networks.py new file mode 100644 index 0000000000000000000000000000000000000000..95c48bc8659354c7c669bb03a7591dafbbe9f262 --- /dev/null +++ b/mlair/model_modules/recurrent_networks.py @@ -0,0 +1,194 @@ +__author__ = "Lukas Leufen" +__date__ = '2021-05-25' + +from functools import reduce, partial + +from mlair.model_modules import AbstractModelClass +from mlair.helpers import select_from_dict +from mlair.model_modules.loss import var_loss, custom_loss + +import keras + + +class RNN(AbstractModelClass): + """ + + """ + + _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"), + "sigmoid": partial(keras.layers.Activation, "sigmoid"), + "linear": partial(keras.layers.Activation, "linear"), + "selu": partial(keras.layers.Activation, "selu"), + "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)), + "leakyrelu": partial(keras.layers.LeakyReLU)} + _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform", + "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(), + "prelu": keras.initializers.he_normal()} + _optimizer = {"adam": keras.optimizers.adam, "sgd": keras.optimizers.SGD} + _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2} + _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"] + _dropout = {"selu": keras.layers.AlphaDropout} + _rnn = {"lstm": keras.layers.LSTM, "gru": keras.layers.GRU} + + def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear", + activation_rnn="tanh", dropout_rnn=0, + optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None, + batch_normalization=False, rnn_type="lstm", add_dense_layer=False, **kwargs): + """ + Sets model and loss depending on the given arguments. + + :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables)) + :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast)) + + Customize this RNN model via the following parameters: + + :param activation: set your desired activation function for appended dense layers (add_dense_layer=True=. Choose + from relu, tanh, sigmoid, linear, selu, prelu, leakyrelu. (Default relu) + :param activation_rnn: set your desired activation function of the rnn output. Choose from relu, tanh, sigmoid, + linear, selu, prelu, leakyrelu. (Default tanh) + :param activation_output: same as activation parameter but exclusively applied on output layer only. (Default + linear) + :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam) + :param n_layer: define number of hidden layers in the network. Given number of hidden neurons are used in each + layer. 
(Default 1)
+        :param n_hidden: define number of hidden units per layer. This number is used in each hidden layer. (Default 10)
+        :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the
+            settings from n_layer and n_hidden. Provide a list where each element represents the number of units in the
+            hidden layer. The number of hidden layers is equal to the total length of this list.
+        :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the
+            network at all. (Default None)
+        :param dropout_rnn: use recurrent dropout with given rate. This is applied along the recursion and not after
+            an rnn layer. (Default 0)
+        :param batch_normalization: use batch normalization layer in the network if enabled. These layers are inserted
+            between the linear part of a layer (the nn part) and the non-linear part (activation function). No BN layer
+            is added if set to false. (Default false)
+        :param rnn_type: define which kind of recurrent network should be applied. Choose from either lstm or gru. All
+            units will be of this kind. (Default lstm)
+        """
+
+        assert len(input_shape) == 1
+        assert len(output_shape) == 1
+        super().__init__(input_shape[0], output_shape[0])
+
+        # settings
+        self.activation = self._set_activation(activation.lower())
+        self.activation_name = activation
+        self.activation_rnn = self._set_activation(activation_rnn.lower())
+        self.activation_rnn_name = activation_rnn
+        self.activation_output = self._set_activation(activation_output.lower())
+        self.activation_output_name = activation_output
+        self.optimizer = self._set_optimizer(optimizer.lower(), **kwargs)
+        self.bn = batch_normalization
+        self.add_dense_layer = add_dense_layer
+        self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
+        self.RNN = self._rnn.get(rnn_type.lower())
+        self._update_model_name(rnn_type)
+        self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
+        # self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
+        self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
+        assert 0 <= dropout_rnn <= 1
+        self.dropout_rnn = dropout_rnn
+
+        # apply to model
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
+
+    def set_model(self):
+        """
+        Build the model.
+ """ + if isinstance(self.layer_configuration, tuple) is True: + n_layer, n_hidden = self.layer_configuration + conf = [n_hidden for _ in range(n_layer)] + else: + assert isinstance(self.layer_configuration, list) is True + conf = self.layer_configuration + + x_input = keras.layers.Input(shape=self._input_shape) + x_in = keras.layers.Reshape((self._input_shape[0], reduce((lambda x, y: x * y), self._input_shape[1:])))( + x_input) + + for layer, n_hidden in enumerate(conf): + return_sequences = (layer < len(conf) - 1) + x_in = self.RNN(n_hidden, return_sequences=return_sequences, recurrent_dropout=self.dropout_rnn)(x_in) + if self.bn is True: + x_in = keras.layers.BatchNormalization()(x_in) + x_in = self.activation_rnn(name=f"{self.activation_rnn_name}_{layer + 1}")(x_in) + if self.dropout is not None: + x_in = self.dropout(self.dropout_rate)(x_in) + + if self.add_dense_layer is True: + x_in = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}", + kernel_initializer=self.kernel_initializer, )(x_in) + x_in = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_in) + x_in = keras.layers.Dense(self._output_shape)(x_in) + out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in) + self.model = keras.Model(inputs=x_input, outputs=[out]) + print(self.model.summary()) + + # x_in = keras.layers.LSTM(32)(x_in) + # if self.dropout is not None: + # x_in = self.dropout(self.dropout_rate)(x_in) + # x_in = keras.layers.RepeatVector(self._output_shape)(x_in) + # x_in = keras.layers.LSTM(32, return_sequences=True)(x_in) + # if self.dropout is not None: + # x_in = self.dropout(self.dropout_rate)(x_in) + # out = keras.layers.TimeDistributed(keras.layers.Dense(1))(x_in) + # out = keras.layers.Flatten()(out) + + def _set_dropout(self, activation, dropout_rate): + if dropout_rate is None: + return None, None + assert 0 <= dropout_rate < 1 + return self._dropout.get(activation, keras.layers.Dropout), dropout_rate + + def _set_activation(self, activation): + try: + return self._activation.get(activation.lower()) + except KeyError: + raise AttributeError(f"Given activation {activation} is not supported in this model class.") + + def set_compile_options(self): + self.compile_options = {"loss": [keras.losses.mean_squared_error], + "metrics": ["mse", "mae", var_loss]} + + def _set_optimizer(self, optimizer, **kwargs): + try: + opt_name = optimizer.lower() + opt = self._optimizer.get(opt_name) + opt_kwargs = {} + if opt_name == "adam": + opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"]) + elif opt_name == "sgd": + opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"]) + return opt(**opt_kwargs) + except KeyError: + raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.") + # + # def _set_regularizer(self, regularizer, **kwargs): + # if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"): + # return None + # try: + # reg_name = regularizer.lower() + # reg = self._regularizer.get(reg_name) + # reg_kwargs = {} + # if reg_name in ["l1", "l2"]: + # reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True) + # if reg_name in reg_kwargs: + # reg_kwargs["l"] = reg_kwargs.pop(reg_name) + # elif reg_name == "l1_l2": + # reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True) + # return reg(**reg_kwargs) + # except KeyError: + # raise AttributeError(f"Given regularizer {regularizer} is not 
supported in this model class.") + + def _update_model_name(self, rnn_type): + n_input = str(reduce(lambda x, y: x * y, self._input_shape)) + n_output = str(self._output_shape) + self.model_name = rnn_type.upper() + if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2: + n_layer, n_hidden = self.layer_configuration + self.model_name += "_".join(["", n_input, *[f"{n_hidden}" for _ in range(n_layer)], n_output]) + else: + self.model_name += "_".join(["", n_input, *[f"{n}" for n in self.layer_configuration], n_output]) diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index 26376637b947f6cd97b66d584583a70c09ae868b..c4c1f4af8c6077a0f2a07b08ebc1d97d68eaf549 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -3,6 +3,7 @@ __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2021-04-13' from typing import List, Dict +import dill import os import logging import multiprocessing @@ -16,7 +17,7 @@ from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, date from astropy.timeseries import LombScargle from mlair.data_handler import DataCollection -from mlair.helpers import TimeTrackingWrapper, to_list +from mlair.helpers import TimeTrackingWrapper, to_list, remove_items from mlair.plotting.abstract_plot_class import AbstractPlotClass @TimeTrackingWrapper @@ -526,16 +527,18 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover self.variables_dim = variables_dim self.time_dim = time_dim self.window_dim = window_dim - self.inputs, self.targets = self._get_inputs_targets(generators, self.variables_dim) + self.inputs, self.targets, number_of_branches = self._get_inputs_targets(generators, self.variables_dim) self.bins = {} self.interval_width = {} self.bin_edges = {} # input plots - self._calculate_hist(generators, self.inputs, input_data=True) - for subset in generators.keys(): - self._plot(add_name="input", subset=subset) - self._plot_combined(add_name="input") + for branch_pos in range(number_of_branches): + self._calculate_hist(generators, self.inputs, input_data=True, branch_pos=branch_pos) + add_name = "input" if number_of_branches == 1 else f"input_branch_{branch_pos}" + for subset in generators.keys(): + self._plot(add_name=add_name, subset=subset) + self._plot_combined(add_name=add_name) # target plots self._calculate_hist(generators, self.targets, input_data=False) @@ -549,16 +552,17 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover gen = gens[k][0] inputs = to_list(gen.get_X(as_numpy=False)[0].coords[dim].values.tolist()) targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist()) - return inputs, targets + n_branches = len(gen.get_X(as_numpy=False)) + return inputs, targets, n_branches - def _calculate_hist(self, generators, variables, input_data=True): + def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0): n_bins = 100 for set_type, generator in generators.items(): tmp_bins = {} tmp_edges = {} end = {} start = {} - f = lambda x: x.get_X(as_numpy=False)[0] if input_data is True else x.get_Y(as_numpy=False) + f = lambda x: x.get_X(as_numpy=False)[branch_pos] if input_data is True else x.get_Y(as_numpy=False) for gen in generator: w = min(abs(f(gen).coords[self.window_dim].values)) data = f(gen).sel({self.window_dim: w}) @@ -866,13 +870,15 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name) 
logging.info(f"... plotting {plot_name}") pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) - colors = ["blue", "red", "green", "orange", "purple", "black", "grey"] + colors = ["grey", "blue", "red", "green", "orange", "purple", "black"] label_names = ["orig"] + label_names max_iter = len(self.plot_data) var_keys = self.plot_data[0].keys() for var in var_keys: fig, ax = plt.subplots() for i in reversed(range(max_iter)): + if label_names[i] == "unfiltered": + continue # do not include the filter 'unfiltered' because this is equal to the 'orig' data plot_data = self.plot_data[i] c = colors[i] ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True, axis=0) @@ -889,9 +895,13 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover plt.close('all') -def f_proc(var, d_var, f_index): # pragma: no cover +def f_proc(var, d_var, f_index, time_dim="datetime"): # pragma: no cover var_str = str(var) - t = (d_var.datetime - d_var.datetime[0]).astype("timedelta64[h]").values / np.timedelta64(1, "D") + t = (d_var[time_dim] - d_var[time_dim][0]).astype("timedelta64[h]").values / np.timedelta64(1, "D") + if len(d_var.shape) > 1: # use only max value if dimensions are remaining (e.g. max(window) -> latest value) + to_remove = remove_items(d_var.coords.dims, time_dim) + for e in to_list(to_remove): + d_var = d_var.sel({e: d_var[e].max()}) pgram = LombScargle(t, d_var.values.flatten(), nterms=1, normalization="psd").power(f_index) # f, pgram = LombScargle(t, d_var.values.flatten(), nterms=1, normalization="psd").autopower() return var_str, f_index, pgram @@ -923,3 +933,218 @@ def f_proc_hist(data, variables, n_bins, variables_dim): # pragma: no cover res[var], bin_edges[var] = np.histogram(d.values, n_bins) interval_width[var] = bin_edges[var][1] - bin_edges[var][0] return res, interval_width, bin_edges + + +class PlotClimateFirFilter(AbstractPlotClass): + """ + Plot climate FIR filter components. + + * Creates a separate folder climFIR inside the given plot directory. + * For each station up to 4 examples are shown (1 for each season). + * Each filtered component and its residuum is drawn in a separate plot. + * A filter component plot includes the climate FIR input, the filter response, the true non-causal (ideal) filter + input, and the corresponding ideal response (containing information about future) + * A filter residuum plot include the climate FIR residuum and the ideal filter residuum. 
+ """ + + def __init__(self, plot_folder, plot_data, sampling, name): + + from mlair.helpers.filter import fir_filter_convolve + + # adjust default plot parameters + rc_params = { + 'axes.labelsize': 'large', + 'xtick.labelsize': 'large', + 'ytick.labelsize': 'large', + 'legend.fontsize': 'medium', + 'axes.titlesize': 'large'} + if plot_folder is None: + return + + self.style_dict = { + "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"}, + "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"}, + "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2}, + "ideal": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2}, + "valid_area": {"color": "whitesmoke", "label": "valid area"}, + "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"} + } + + plot_folder = os.path.join(os.path.abspath(plot_folder), "climFIR") + self.fir_filter_convolve = fir_filter_convolve + super().__init__(plot_folder, plot_name=None, rc_params=rc_params) + plot_dict, new_dim = self._prepare_data(plot_data) + self._name = name + self._plot(plot_dict, sampling, new_dim) + self._store_plot_data(plot_data) + + def _prepare_data(self, data): + """Restructure plot data.""" + plot_dict = {} + new_dim = None + for i, o in enumerate(range(len(data))): + plot_data = data[i] + for p_d in plot_data: + var = p_d.get("var") + t0 = p_d.get("t0") + filter_input = p_d.get("filter_input") + filter_input_nc = p_d.get("filter_input_nc") + valid_range = p_d.get("valid_range") + time_range = p_d.get("time_range") + if new_dim is None: + new_dim = p_d.get("new_dim") + else: + assert new_dim == p_d.get("new_dim") + h = p_d.get("h") + plot_dict_var = plot_dict.get(var, {}) + plot_dict_t0 = plot_dict_var.get(t0, {}) + plot_dict_order = {"filter_input": filter_input, + "filter_input_nc": filter_input_nc, + "valid_range": valid_range, + "time_range": time_range, + "order": len(h), "h": h} + plot_dict_t0[i] = plot_dict_order + plot_dict_var[t0] = plot_dict_t0 + plot_dict[var] = plot_dict_var + return plot_dict, new_dim + + def _plot(self, plot_dict, sampling, new_dim="window"): + td_type = {"1d": "D", "1H": "h"}.get(sampling) + for var, viz_date_dict in plot_dict.items(): + for it0, t0 in enumerate(viz_date_dict.keys()): + viz_data = viz_date_dict[t0] + residuum_true = None + for ifilter in sorted(viz_data.keys()): + data = viz_data[ifilter] + filter_input = data["filter_input"] + filter_input_nc = data["filter_input_nc"] if residuum_true is None else residuum_true.sel( + {new_dim: filter_input.coords[new_dim]}) + valid_range = data["valid_range"] + time_axis = data["time_range"] + filter_order = data["order"] + h = data["h"] + fig, ax = plt.subplots() + + # plot backgrounds + self._plot_valid_area(ax, t0, valid_range, td_type) + self._plot_t0(ax, t0) + + # original data + self._plot_original_data(ax, time_axis, filter_input_nc) + + # clim apriori + self._plot_apriori(ax, time_axis, filter_input, new_dim, ifilter) + + # clim filter response + residuum_estimated = self._plot_clim_filter(ax, time_axis, filter_input, new_dim, h, + output_dtypes=filter_input.dtype) + + # ideal filter response + residuum_true = self._plot_ideal_filter(ax, time_axis, filter_input_nc, new_dim, h, + output_dtypes=filter_input.dtype) + + # set title, legend, and save plot + xlims = self._set_xlim(ax, t0, filter_order, valid_range, td_type, time_axis) + + plt.title(f"Input of ClimFilter ({str(var)})") + plt.legend() + fig.autofmt_xdate() + 
plt.tight_layout() + self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}" + self._save() + + # plot residuum + fig, ax = plt.subplots() + self._plot_valid_area(ax, t0, valid_range, td_type) + self._plot_t0(ax, t0) + self._plot_series(ax, time_axis, residuum_true.values.flatten(), style="ideal") + self._plot_series(ax, time_axis, residuum_estimated.values.flatten(), style="clim") + ax.set_xlim(xlims) + plt.title(f"Residuum of ClimFilter ({str(var)})") + plt.legend(loc="upper left") + fig.autofmt_xdate() + plt.tight_layout() + + self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum" + self._save() + + def _set_xlim(self, ax, t0, order, valid_range, td_type, time_axis): + """ + Set xlims + + Use order and valid_range to find a good zoom in that hides edges of filter values that are effected by reduced + filter order. Limits are returned to be usable for other plots. + """ + t_minus_delta = max(1.5 * valid_range.start, 0.3 * order) + t_plus_delta = max(0.5 * valid_range.start, 0.3 * order) + t_minus = t0 + np.timedelta64(-int(t_minus_delta), td_type) + t_plus = t0 + np.timedelta64(int(t_plus_delta), td_type) + ax_start = max(t_minus, time_axis[0]) + ax_end = min(t_plus, time_axis[-1]) + ax.set_xlim((ax_start, ax_end)) + return ax_start, ax_end + + def _plot_valid_area(self, ax, t0, valid_range, td_type): + ax.axvspan(t0 + np.timedelta64(-valid_range.start, td_type), + t0 + np.timedelta64(valid_range.stop - 1, td_type), **self.style_dict["valid_area"]) + + def _plot_t0(self, ax, t0): + ax.axvline(t0, **self.style_dict["t0"]) + + def _plot_series(self, ax, time_axis, data, style): + ax.plot(time_axis, data, **self.style_dict[style]) + + def _plot_original_data(self, ax, time_axis, data): + # original data + filter_input_nc = data + self._plot_series(ax, time_axis, filter_input_nc.values.flatten(), style="original") + # self._plot_series(ax, time_axis, filter_input_nc.values.flatten(), color="darkgrey", linestyle="dashed", + # label="original") + + def _plot_apriori(self, ax, time_axis, data, new_dim, ifilter): + # clim apriori + filter_input = data + if ifilter == 0: + d_tmp = filter_input.sel( + {new_dim: slice(0, filter_input.coords[new_dim].values.max())}).values.flatten() + else: + d_tmp = filter_input.values.flatten() + self._plot_series(ax, time_axis[len(time_axis) - len(d_tmp):], d_tmp, style="apriori") + # self._plot_series(ax, time_axis[len(time_axis) - len(d_tmp):], d_tmp, color="darkgrey", linestyle="solid", + # label="estimated future") + + def _plot_clim_filter(self, ax, time_axis, data, new_dim, h, output_dtypes): + filter_input = data + # clim filter response + filt = xr.apply_ufunc(self.fir_filter_convolve, filter_input, + input_core_dims=[[new_dim]], + output_core_dims=[[new_dim]], + vectorize=True, + kwargs={"h": h}, + output_dtypes=[output_dtypes]) + self._plot_series(ax, time_axis, filt.values.flatten(), style="clim") + # self._plot_series(ax, time_axis, filt.values.flatten(), color="black", linestyle="solid", + # label="clim filter response", linewidth=2) + residuum_estimated = filter_input - filt + return residuum_estimated + + def _plot_ideal_filter(self, ax, time_axis, data, new_dim, h, output_dtypes): + filter_input_nc = data + # ideal filter response + filt = xr.apply_ufunc(self.fir_filter_convolve, filter_input_nc, + input_core_dims=[[new_dim]], + output_core_dims=[[new_dim]], + vectorize=True, + kwargs={"h": h}, + output_dtypes=[output_dtypes]) + self._plot_series(ax, time_axis, filt.values.flatten(), style="ideal") + # 
self._plot_series(ax, time_axis, filt.values.flatten(), color="black", linestyle="dashed", + # label="ideal filter response", linewidth=2) + residuum_true = filter_input_nc - filt + return residuum_true + + def _store_plot_data(self, data): + """Store plot data. Could be loaded in a notebook to redraw.""" + file = os.path.join(self.plot_folder, "plot_data.pickle") + with open(file, "wb") as f: + dill.dump(data, f) diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index b5e76e5540a06aa5ae33ec85b0e4dfe73931dc9b..29ed4054206f77ca919a416dd1792193dec4aef6 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -217,6 +217,8 @@ class PlotMonthlySummary(AbstractPlotClass): data_nn = data.sel(type=self._model_name).squeeze() if len(data_nn.shape) > 1: data_nn = data_nn.assign_coords(ahead=[f"{days}d" for days in data_nn.coords["ahead"].values]) + else: + data_nn.coords["ahead"].values = str(data_nn.coords["ahead"].values) + "d" data_obs = data.sel(type="obs", ahead=1).squeeze() data_obs.coords["ahead"] = "obs" @@ -744,7 +746,9 @@ class PlotBootstrapSkillScore(AbstractPlotClass): """ - def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = "", separate_vars: List = None): + def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = "", separate_vars: List = None, + sampling: str = "daily", ahead_dim: str = "ahead", bootstrap_type: str = None, + bootstrap_method: str = None): """ Set attributes and create plot. @@ -752,20 +756,46 @@ class PlotBootstrapSkillScore(AbstractPlotClass): :param plot_folder: path to save the plot (default: current directory) :param model_setup: architecture type to specify plot name (default "CNN") :param separate_vars: variables to plot separated (default: ['o3']) + :param sampling: type of sampling rate, should be either hourly or daily (default: "daily") + :param ahead_dim: name of the ahead dimensions (default: "ahead") + :param bootstrap_annotation: additional information to use in the file name (default: None) """ - super().__init__(plot_folder, f"skill_score_bootstrap_{model_setup}") + annotation = ["_".join([s for s in ["", bootstrap_type, bootstrap_method] if s is not None])][0] + super().__init__(plot_folder, f"skill_score_bootstrap_{model_setup}{annotation}") if separate_vars is None: separate_vars = ['o3'] self._labels = None self._x_name = "boot_var" - self._data = self._prepare_data(data) - self._plot() - self._save() - self.plot_name += '_separated' - self._plot(separate_vars=separate_vars) - self._save(bbox_inches='tight') + self._ahead_dim = ahead_dim + self._boot_type = self._set_bootstrap_type(bootstrap_type) + self._boot_method = self._set_bootstrap_method(bootstrap_method) + + self._title = f"Bootstrap analysis ({self._boot_method}, {self._boot_type})" + self._data = self._prepare_data(data, sampling) + if "branch" in self._data.columns: + plot_name = self.plot_name + for branch in self._data["branch"].unique(): + self._title = f"Bootstrap analysis ({self._boot_method}, {self._boot_type}, {branch})" + self._plot(branch=branch) + self.plot_name = f"{plot_name}_{branch}" + self._save() + else: + self._plot() + self._save() + if len(set(separate_vars).intersection(self._data[self._x_name].unique())) > 0: + self.plot_name += '_separated' + self._plot(separate_vars=separate_vars) + self._save(bbox_inches='tight') + + @staticmethod + def _set_bootstrap_type(boot_type): + return {"singleinput": "single input"}.get(boot_type, 
boot_type) + + @staticmethod + def _set_bootstrap_method(boot_method): + return {"zero_mean": "zero mean", "shuffle": "shuffled"}.get(boot_method, boot_method) - def _prepare_data(self, data: Dict) -> pd.DataFrame: + def _prepare_data(self, data: Dict, sampling: str) -> pd.DataFrame: """ Shrink given data, if only scores are relevant. @@ -775,23 +805,53 @@ class PlotBootstrapSkillScore(AbstractPlotClass): :param data: dictionary with station names as keys and 2D xarrays as values :return: pre-processed data set """ - data = helpers.dict_to_xarray(data, "station").sortby(self._x_name) - new_boot_coords = self._return_vars_without_number_tag(data.coords['boot_var'].values, split_by='_', keep=1) - data = data.assign_coords({'boot_var': new_boot_coords}) - self._labels = [str(i) + "d" for i in data.coords["ahead"].values] - if "station" not in data.dims: - data = data.expand_dims("station") - return data.to_dataframe("data").reset_index(level=[0, 1, 2]) + station_dim = "station" + data = helpers.dict_to_xarray(data, station_dim).sortby(self._x_name) + if self._boot_type == "single input": + number_tags = self._get_number_tag(data.coords[self._x_name].values, split_by='_') + new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_', + keep=1, as_unique=True) + values = data.values.reshape((data.shape[0], len(new_boot_coords), len(number_tags), data.shape[-1])) + data = xr.DataArray(values, coords={station_dim: data.coords["station"], self._x_name: new_boot_coords, + "branch": number_tags, self._ahead_dim: data.coords[self._ahead_dim]}, + dims=[station_dim, self._x_name, "branch", self._ahead_dim]) + else: + try: + new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_', + keep=1) + data = data.assign_coords({self._x_name: new_boot_coords}) + except NotImplementedError: + pass + _, sampling_letter = self._get_target_sampling(sampling, 1) + self._labels = [str(i) + sampling_letter for i in data.coords[self._ahead_dim].values] + if station_dim not in data.dims: + data = data.expand_dims(station_dim) + return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()) + + @staticmethod + def _get_target_sampling(sampling, pos): + sampling = (sampling, sampling) if isinstance(sampling, str) else sampling + sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "") + return sampling, sampling_letter - def _return_vars_without_number_tag(self, values, split_by, keep): + def _return_vars_without_number_tag(self, values, split_by, keep, as_unique=False): arr = np.array([v.split(split_by) for v in values]) num = arr[:, 0] + if arr.shape[keep] == 1: # keep dim has only length 1, no number tags required + return num new_val = arr[:, keep] if self._all_values_are_equal(num, axis=0): return new_val + elif as_unique is True: + return np.unique(new_val) else: raise NotImplementedError + @staticmethod + def _get_number_tag(values, split_by): + arr = np.array([v.split(split_by) for v in values]) + num = arr[:, 0] + return np.unique(num).tolist() @staticmethod def _all_values_are_equal(arr, axis=0): @@ -809,45 +869,36 @@ class PlotBootstrapSkillScore(AbstractPlotClass): """ return "" if score_only else "terms and " - def _plot(self, separate_vars=None): + def _plot(self, branch=None, separate_vars=None): """Plot climatological skill score.""" if separate_vars is None: - self._plot_all_variables() + self._plot_all_variables(branch) else: self._plot_selected_variables(separate_vars) def 
_plot_selected_variables(self, separate_vars: List): - # if separate_vars is None: - # separate_vars = ['o3'] data = self._data - self.raise_error_if_separate_vars_do_not_exist(data, separate_vars) - all_variables = self._get_unique_values_from_column_of_df(data, 'boot_var') - # remaining_vars = helpers.list_pop(all_variables, separate_vars) #remove_items + self.raise_error_if_separate_vars_do_not_exist(data, separate_vars, self._x_name) + all_variables = self._get_unique_values_from_column_of_df(data, self._x_name) remaining_vars = helpers.remove_items(all_variables, separate_vars) - data_first = self._select_data(df=data, variables=separate_vars, column_name='boot_var') - data_second = self._select_data(df=data, variables=remaining_vars, column_name='boot_var') - - fig, ax = plt.subplots(nrows=1, ncols=2, - gridspec_kw={'width_ratios': [len(separate_vars), - len(remaining_vars) - ] - } - ) + data_first = self._select_data(df=data, variables=separate_vars, column_name=self._x_name) + data_second = self._select_data(df=data, variables=remaining_vars, column_name=self._x_name) + + fig, ax = plt.subplots(nrows=1, ncols=2, gridspec_kw={'width_ratios': [len(separate_vars), + len(remaining_vars)]}) if len(separate_vars) > 1: first_box_width = .8 else: first_box_width = 2. - sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_first, ax=ax[0], whis=1., palette="Blues_d", - showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, - flierprops={"marker": "."}, width=first_box_width - ) + sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_first, ax=ax[0], whis=1., + palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, + flierprops={"marker": "."}, width=first_box_width) ax[0].set(ylabel=f"skill score", xlabel="") - sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_second, ax=ax[1], whis=1., palette="Blues_d", - showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, - flierprops={"marker": "."}, - ) + sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_second, ax=ax[1], whis=1., + palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, + flierprops={"marker": "."}) ax[1].set(ylabel="", xlabel="") ax[1].yaxis.tick_right() handles, _ = ax[1].get_legend_handles_labels() @@ -882,9 +933,11 @@ class PlotBootstrapSkillScore(AbstractPlotClass): align_yaxis(ax[0], ax[1]) align_yaxis(ax[0], ax[1]) + plt.title(self._title) @staticmethod def _select_data(df: pd.DataFrame, variables: List[str], column_name: str) -> pd.DataFrame: + selected_data = None for i, variable in enumerate(variables): if i == 0: selected_data = df.loc[df[column_name] == variable] @@ -893,28 +946,29 @@ class PlotBootstrapSkillScore(AbstractPlotClass): selected_data = pd.concat([selected_data, tmp_var], axis=0) return selected_data - def raise_error_if_separate_vars_do_not_exist(self, data, separate_vars): - if not self._variables_exist_in_df(df=data, variables=separate_vars): + def raise_error_if_separate_vars_do_not_exist(self, data, separate_vars, column_name): + if not self._variables_exist_in_df(df=data, variables=separate_vars, column_name=column_name): raise ValueError(f"At least one entry of `separate_vars' does not exist in `self.data' ") @staticmethod def _get_unique_values_from_column_of_df(df: pd.DataFrame, column_name: str) -> List: return list(df[column_name].unique()) - def _variables_exist_in_df(self, df: pd.DataFrame, variables: List[str], column_name: str = 'boot_var'): + def 
_variables_exist_in_df(self, df: pd.DataFrame, variables: List[str], column_name: str): vars_in_df = set(self._get_unique_values_from_column_of_df(df, column_name)) return set(variables).issubset(vars_in_df) - def _plot_all_variables(self): + def _plot_all_variables(self, branch=None): """ """ fig, ax = plt.subplots() - sns.boxplot(x=self._x_name, y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d", + plot_data = self._data if branch is None else self._data[self._data["branch"] == str(branch)] + sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1., palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."}) ax.axhline(y=0, color="grey", linewidth=.5) plt.xticks(rotation=45) - ax.set(ylabel=f"skill score", xlabel="", title="summary of all stations") + ax.set(ylabel=f"skill score", xlabel="", title=self._title) handles, _ = ax.get_legend_handles_labels() ax.legend(handles, self._labels) plt.tight_layout() @@ -1029,8 +1083,6 @@ class PlotTimeSeries: def _plot_obs(self, ax, data): ahead = 1 obs_data = data.sel(type="obs", ahead=ahead).shift(index=ahead) - # index = data.index + np.timedelta64(1, self._sampling) - # ax.plot(index, obs_data.values, color=matplotlib.colors.cnames["green"], label="obs") ax.plot(obs_data, color=matplotlib.colors.cnames["green"], label="obs") @staticmethod diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index c5687e372298f9625794243324c77f2ed6abedb9..4755fff5b1709c688e420ed585e22b1ad9eab124 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -6,6 +6,7 @@ import logging import os import sys from typing import Union, Dict, Any, List, Callable +from dill.source import getsource from mlair.configuration import path_config from mlair import helpers @@ -20,7 +21,9 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \ DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM, \ DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_OVERSAMPLING_BINS, \ - DEFAULT_OVERSAMPLING_RATES_CAP, DEFAULT_OVERSAMPLING_METHOD + DEFAULT_OVERSAMPLING_RATES_CAP, DEFAULT_OVERSAMPLING_METHOD, \ + DEFAULT_MAX_NUMBER_MULTIPROCESSING, \ + DEFAULT_BOOTSTRAP_TYPE, DEFAULT_BOOTSTRAP_METHOD from mlair.data_handler import DefaultDataHandler from mlair.run_modules.run_environment import RunEnvironment from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel @@ -215,11 +218,12 @@ class ExperimentSetup(RunEnvironment): create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None, train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None, extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, - number_of_bootstraps=None, - create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None, + number_of_bootstraps=None, create_new_bootstraps=None, bootstrap_method=None, bootstrap_type=None, + data_path: str = None, batch_path: str = None, login_nodes=None, hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, data_origin: Dict = None, competitors: list = None, competitor_path: str = None, use_multiprocessing: bool = None, 
use_multiprocessing_on_debug: bool = None, + max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None, oversampling_bins=None, oversampling_rates_cap=None, oversampling_method = None, **kwargs): # create run framework @@ -273,6 +277,8 @@ class ExperimentSetup(RunEnvironment): default=DEFAULT_USE_MULTIPROCESSING_ON_DEBUG) else: self._set_param("use_multiprocessing", use_multiprocessing, default=DEFAULT_USE_MULTIPROCESSING) + self._set_param("max_number_multiprocessing", max_number_multiprocessing, + default=DEFAULT_MAX_NUMBER_MULTIPROCESSING) # batch path (temporary) self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data")) @@ -357,6 +363,8 @@ class ExperimentSetup(RunEnvironment): self._set_param("create_new_bootstraps", create_new_bootstraps, scope="general.postprocessing") self._set_param("number_of_bootstraps", number_of_bootstraps, default=DEFAULT_NUMBER_OF_BOOTSTRAPS, scope="general.postprocessing") + self._set_param("bootstrap_method", bootstrap_method, default=DEFAULT_BOOTSTRAP_METHOD) + self._set_param("bootstrap_type", bootstrap_type, default=DEFAULT_BOOTSTRAP_TYPE) self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") self._set_param("neighbors", ["DEBW030"]) # TODO: just for testing @@ -373,6 +381,9 @@ class ExperimentSetup(RunEnvironment): # set model architecture class self._set_param("model_class", model, VanillaModel) + # store starting script if provided + if start_script is not None: + self._store_start_script(start_script, experiment_path) # set remaining kwargs if len(kwargs) > 0: @@ -395,6 +406,18 @@ class ExperimentSetup(RunEnvironment): logging.debug(f"set experiment attribute: {param}({scope})={value}") return value + @staticmethod + def _store_start_script(start_script, store_path): + out_file = os.path.join(store_path, "start_script.txt") + if isinstance(start_script, Callable): + with open(out_file, "w") as fh: + fh.write(getsource(start_script)) + if isinstance(start_script, str): + with open(start_script, 'r') as f: + with open(out_file, "w") as out: + for line in (f.readlines()): + print(line, end='', file=out) + def _compare_variables_and_statistics(self): """ Compare variables and statistics. 
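The new start_script parameter is stored by _store_start_script above: for a callable it writes the function's source (via dill's getsource), for a path it copies the file's content to start_script.txt inside the experiment folder. The run scripts further down in this patch simply pass start_script=__file__; a minimal sketch of that usage follows, where the station list and the remaining workflow arguments are only illustrative placeholders.

from mlair.workflows import DefaultWorkflow


def main():
    # passing the running file itself lets ExperimentSetup copy this script's
    # text to <experiment_path>/start_script.txt for later reference
    workflow = DefaultWorkflow(stations=["DEBW107", "DEBW013"], start_script=__file__)
    workflow.run()


if __name__ == "__main__":
    main()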
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index 8fae430fb48a28bdd8b21f8bfcfc7c569eb24f6c..83f4a2bd96314d6f8c53f8cc9407cbc12e7b9a16 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -12,7 +12,7 @@ import keras import pandas as pd import tensorflow as tf -from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler +from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler from mlair.run_modules.run_environment import RunEnvironment from mlair.configuration import path_config @@ -119,11 +119,14 @@ class ModelSetup(RunEnvironment): """ lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None) hist = HistoryAdvanced() + epo_timing = EpoTimingCallback() self.data_store.set("hist", hist, scope="model") + self.data_store.set("epo_timing", epo_timing, scope="model") callbacks = CallbackHandler() if lr is not None: callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist") + callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing") callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto') self.data_store.set("callbacks", callbacks, self.scope) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 8a594808536ca5552003e88c4dbfd181237bb526..cef2c6510ae283b5ce5ca826b0d721edf6a57e76 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -86,6 +86,7 @@ class PostProcessing(RunEnvironment): self.competitor_path = self.data_store.get("competitor_path") self.competitors = to_list(self.data_store.get_default("competitors", default=[])) self.forecast_indicator = "nn" + self.ahead_dim = "ahead" self._run() def _run(self): @@ -103,7 +104,10 @@ class PostProcessing(RunEnvironment): if self.data_store.get("evaluate_bootstraps", "postprocessing"): with TimeTracking(name="calculate bootstraps"): create_new_bootstraps = self.data_store.get("create_new_bootstraps", "postprocessing") - self.bootstrap_postprocessing(create_new_bootstraps) + bootstrap_method = self.data_store.get("bootstrap_method", "postprocessing") + bootstrap_type = self.data_store.get("bootstrap_type", "postprocessing") + self.bootstrap_postprocessing(create_new_bootstraps, bootstrap_type=bootstrap_type, + bootstrap_method=bootstrap_method) # skill scores and error metrics with TimeTracking(name="calculate skill scores"): @@ -136,7 +140,8 @@ class PostProcessing(RunEnvironment): continue return xr.concat(competing_predictions, "type") if len(competing_predictions) > 0 else None - def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0) -> None: + def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0, bootstrap_type="singleinput", + bootstrap_method="shuffle") -> None: """ Calculate skill scores of bootstrapped data. @@ -149,18 +154,26 @@ class PostProcessing(RunEnvironment): :param _iter: internal counter to reduce unnecessary recursive calls (maximum number is 2, otherwise something went wrong). """ - try: - if create_new_bootstraps: - self.create_bootstrap_forecast() - self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores() - except FileNotFoundError: - if _iter != 0: - raise RuntimeError("bootstrap_postprocessing is called for the 2nd time. 
This means, that calling"
-                                   "manually the reason for the failure.")
-            logging.info("Couldn't load all files, restart bootstrap postprocessing with create_new_bootstraps=True.")
-            self.bootstrap_postprocessing(True, _iter=1)
-
-    def create_bootstrap_forecast(self) -> None:
+        self.bootstrap_skill_scores = {}
+        for boot_type in to_list(bootstrap_type):
+            self.bootstrap_skill_scores[boot_type] = {}
+            for boot_method in to_list(bootstrap_method):
+                try:
+                    if create_new_bootstraps:
+                        self.create_bootstrap_forecast(bootstrap_type=boot_type, bootstrap_method=boot_method)
+                    boot_skill_score = self.calculate_bootstrap_skill_scores(bootstrap_type=boot_type,
+                                                                             bootstrap_method=boot_method)
+                    self.bootstrap_skill_scores[boot_type][boot_method] = boot_skill_score
+                except FileNotFoundError:
+                    if _iter != 0:
+                        raise RuntimeError(f"bootstrap_postprocessing ({boot_type}, {boot_method}) was called for the"
+                                           f" 2nd time. This means that something went wrong internally. Please check"
+                                           f" for possible errors.")
+                    logging.info(f"Could not load all files for bootstrapping ({boot_type}, {boot_method}), restart "
+                                 f"bootstrap postprocessing with create_new_bootstraps=True.")
+                    self.bootstrap_postprocessing(True, _iter=1, bootstrap_type=boot_type, bootstrap_method=boot_method)
+
+    def create_bootstrap_forecast(self, bootstrap_type, bootstrap_method) -> None:
         """
         Create bootstrapped predictions for all stations and variables.
@@ -168,16 +181,15 @@ class PostProcessing(RunEnvironment):
         `bootstraps_labels_{station}.nc`.
         """
         # forecast
-        with TimeTracking(name=inspect.stack()[0].function):
+        with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"):
             # extract all requirements from data store
-            bootstrap_path = self.data_store.get("bootstrap_path")
             forecast_path = self.data_store.get("forecast_path")
             number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
-            dims = ["index", "ahead", "type"]
+            dims = ["index", self.ahead_dim, "type"]
             for station in self.test_data:
-                logging.info(str(station))
                 X, Y = None, None
-                bootstraps = BootStraps(station, number_of_bootstraps)
+                bootstraps = BootStraps(station, number_of_bootstraps, bootstrap_type=bootstrap_type,
+                                        bootstrap_method=bootstrap_method)
                 for boot in bootstraps:
                     X, Y, (index, dimension) = boot
                     # make bootstrap predictions
@@ -188,18 +200,19 @@ class PostProcessing(RunEnvironment):
                         bootstrap_predictions = np.expand_dims(bootstrap_predictions, axis=-1)
                     shape = bootstrap_predictions.shape
                     coords = (range(shape[0]), range(1, shape[1] + 1))
-                    var = f"{index}_{dimension}"
+                    var = f"{index}_{dimension}" if index is not None else str(dimension)
                     tmp = xr.DataArray(bootstrap_predictions, coords=(*coords, [var]), dims=dims)
-                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_{var}.nc")
+                    file_name = os.path.join(forecast_path,
+                                             f"bootstraps_{station}_{var}_{bootstrap_type}_{bootstrap_method}.nc")
                     tmp.to_netcdf(file_name)
                 else:
                     # store also true labels for each station
                     labels = np.expand_dims(Y, axis=-1)
-                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_labels.nc")
+                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_{bootstrap_method}_labels.nc")
                     labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=dims)
                     labels.to_netcdf(file_name)
 
-    def calculate_bootstrap_skill_scores(self) -> Dict[str, xr.DataArray]:
+    def calculate_bootstrap_skill_scores(self, bootstrap_type, bootstrap_method) -> Dict[str, xr.DataArray]:
         """
         Calculate skill score of bootstrapped variables.
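With the nested loop above, bootstrap_postprocessing accepts either a single value or a list for both bootstrap_type and bootstrap_method and collects the results in a nested dict keyed first by type, then by method. A standalone sketch of that nesting is shown below; the placeholder string stands in for the station-wise skill score dict that calculate_bootstrap_skill_scores actually returns, and "singleinput", "shuffle" and "zero_mean" are the values referenced elsewhere in this patch.

from mlair.helpers import to_list

# single strings and lists are both accepted, to_list normalises them
bootstrap_type = "singleinput"
bootstrap_method = ["shuffle", "zero_mean"]

skill_scores = {}
for boot_type in to_list(bootstrap_type):
    skill_scores[boot_type] = {}
    for boot_method in to_list(bootstrap_method):
        # in PostProcessing this value is the dict {station: xr.DataArray(dims=("boot_var", ahead_dim))}
        skill_scores[boot_type][boot_method] = f"result for ({boot_type}, {boot_method})"

print(skill_scores["singleinput"]["zero_mean"])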
@@ -209,53 +222,64 @@ class PostProcessing(RunEnvironment): :return: The result dictionary with station-wise skill scores """ - with TimeTracking(name=inspect.stack()[0].function): + with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"): # extract all requirements from data store - bootstrap_path = self.data_store.get("bootstrap_path") forecast_path = self.data_store.get("forecast_path") number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing") forecast_file = f"forecasts_norm_%s_test.nc" - bootstraps = BootStraps(self.test_data[0], number_of_bootstraps).bootstraps() - skill_scores = statistics.SkillScores(None) + + bootstraps = BootStraps(self.test_data[0], number_of_bootstraps, bootstrap_type=bootstrap_type, + bootstrap_method=bootstrap_method) + number_of_bootstraps = bootstraps.number_of_bootstraps + bootstrap_iter = bootstraps.bootstraps() + skill_scores = statistics.SkillScores(None, ahead_dim=self.ahead_dim) score = {} for station in self.test_data: - logging.info(station) - # get station labels - file_name = os.path.join(forecast_path, f"bootstraps_{str(station)}_labels.nc") - labels = xr.open_dataarray(file_name) + file_name = os.path.join(forecast_path, f"bootstraps_{str(station)}_{bootstrap_method}_labels.nc") + with xr.open_dataarray(file_name) as da: + labels = da.load() shape = labels.shape # get original forecasts orig = self.get_orig_prediction(forecast_path, forecast_file % str(station), number_of_bootstraps) orig = orig.reshape(shape) coords = (range(shape[0]), range(1, shape[1] + 1), ["orig"]) - orig = xr.DataArray(orig, coords=coords, dims=["index", "ahead", "type"]) + orig = xr.DataArray(orig, coords=coords, dims=["index", self.ahead_dim, "type"]) # calculate skill scores for each variable skill = pd.DataFrame(columns=range(1, self.window_lead_time + 1)) - for boot_set in bootstraps: - boot_var = f"{boot_set[0]}_{boot_set[1]}" - file_name = os.path.join(forecast_path, f"bootstraps_{station}_{boot_var}.nc") - boot_data = xr.open_dataarray(file_name) + for boot_set in bootstrap_iter: + boot_var = f"{boot_set[0]}_{boot_set[1]}" if isinstance(boot_set, tuple) else str(boot_set) + file_name = os.path.join(forecast_path, + f"bootstraps_{station}_{boot_var}_{bootstrap_type}_{bootstrap_method}.nc") + with xr.open_dataarray(file_name) as da: + boot_data = da.load() boot_data = boot_data.combine_first(labels).combine_first(orig) boot_scores = [] for ahead in range(1, self.window_lead_time + 1): - data = boot_data.sel(ahead=ahead) + data = boot_data.sel({self.ahead_dim: ahead}) boot_scores.append( skill_scores.general_skill_score(data, forecast_name=boot_var, reference_name="orig")) skill.loc[boot_var] = np.array(boot_scores) # collect all results in single dictionary - score[str(station)] = xr.DataArray(skill, dims=["boot_var", "ahead"]) + score[str(station)] = xr.DataArray(skill, dims=["boot_var", self.ahead_dim]) return score def get_orig_prediction(self, path, file_name, number_of_bootstraps, prediction_name=None): if prediction_name is None: prediction_name = self.forecast_indicator file = os.path.join(path, file_name) - prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze() - vals = np.tile(prediction.data, (number_of_bootstraps, 1)) + with xr.open_dataarray(file) as da: + prediction = da.load().sel(type=prediction_name).squeeze() + return self.repeat_data(prediction, number_of_bootstraps) + + @staticmethod + def repeat_data(data, number_of_repetition): + if isinstance(data, 
xr.DataArray): + data = data.data + vals = np.tile(data, (number_of_repetition, 1)) return vals[~np.isnan(vals).any(axis=1), :] def _get_model_name(self): @@ -335,8 +359,16 @@ class PostProcessing(RunEnvironment): try: if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list): - PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, - model_setup=self.forecast_indicator) + for boot_type, boot_data in self.bootstrap_skill_scores.items(): + for boot_method, boot_skill_score in boot_data.items(): + try: + PlotBootstrapSkillScore(boot_skill_score, plot_folder=self.plot_path, + model_setup=self.forecast_indicator, sampling=self._sampling, + ahead_dim=self.ahead_dim, separate_vars=to_list(self.target_var), + bootstrap_type=boot_type, bootstrap_method=boot_method) + except Exception as e: + logging.error(f"Could not create plot PlotBootstrapSkillScore ({boot_type}, {boot_method}) " + f"due to the following error: {e}") except Exception as e: logging.error(f"Could not create plot PlotBootstrapSkillScore due to the following error: {e}") @@ -486,7 +518,8 @@ class PostProcessing(RunEnvironment): "obs": observation, "ols": ols_prediction} all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]), - time_dimension, **prediction_dict) + time_dimension, ahead_dim=self.ahead_dim, + **prediction_dict) # save all forecasts locally path = self.data_store.get("forecast_path") @@ -512,8 +545,8 @@ class PostProcessing(RunEnvironment): """ path = os.path.join(self.competitor_path, competitor_name) file = os.path.join(path, f"forecasts_{station_name}_test.nc") - data = xr.open_dataarray(file) - # data = data.expand_dims(Stations=[station_name]) # ToDo: remove line + with xr.open_dataarray(file) as da: + data = da.load() forecast = data.sel(type=[self.forecast_indicator]) forecast.coords["type"] = [competitor_name] return forecast @@ -550,7 +583,14 @@ class PostProcessing(RunEnvironment): """ tmp_ols = self.ols_model.predict(input_data) target_shape = ols_prediction.values.shape - ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols + if target_shape != tmp_ols.shape: + if len(target_shape)==2: + new_values = np.swapaxes(tmp_ols,1,0) + else: + new_values = np.swapaxes(tmp_ols, 2, 0) + else: + new_values = tmp_ols + ols_prediction.values = new_values if not normalised: ols_prediction = transformation_func(ols_prediction, "target", inverse=True) return ols_prediction @@ -637,7 +677,8 @@ class PostProcessing(RunEnvironment): return index @staticmethod - def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, **kwargs): + def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, + ahead_dim="ahead", **kwargs): """ Combine different forecast types into single xarray. 
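The repeat_data helper above tiles the original forecast so that its shape matches the stack of bootstrap realisations and afterwards drops every row that contains a NaN. A standalone numpy illustration with made-up values:

import numpy as np

# made-up (n_samples, n_ahead) forecast with one invalid row
data = np.array([[1.0, 2.0],
                 [np.nan, 3.0],
                 [4.0, 5.0]])

number_of_repetition = 2
vals = np.tile(data, (number_of_repetition, 1))  # stack the array twice along the sample axis -> (6, 2)
vals = vals[~np.isnan(vals).any(axis=1), :]      # drop every row that contains a NaN
print(vals.shape)                                # (4, 2): the NaN row is removed in both copies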
@@ -650,7 +691,7 @@ class PostProcessing(RunEnvironment): """ keys = list(kwargs.keys()) res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan), - coords=[index.index, ahead_names, keys], dims=['index', 'ahead', 'type']) + coords=[index.index, ahead_names, keys], dims=['index', ahead_dim, 'type']) for k, v in kwargs.items(): intersection = set(res.index.values) & set(v.indexes[time_dimension].values) match_index = np.array(list(intersection)) @@ -668,7 +709,8 @@ class PostProcessing(RunEnvironment): """ try: file = os.path.join(path, f"forecasts_{str(station)}_train_val.nc") - return xr.open_dataarray(file) + with xr.open_dataarray(file) as da: + return da.load() except (IndexError, KeyError, FileNotFoundError): return None @@ -683,7 +725,8 @@ class PostProcessing(RunEnvironment): """ try: file = os.path.join(path, f"forecasts_{str(station)}_test.nc") - return xr.open_dataarray(file) + with xr.open_dataarray(file) as da: + return da.load() except (IndexError, KeyError, FileNotFoundError): return None @@ -725,14 +768,14 @@ class PostProcessing(RunEnvironment): competitor = self.load_competitors(station) combined = self._combine_forecasts(external_data, competitor, dim="type") model_list = remove_items(list(combined.type.values), "obs") if combined is not None else None - skill_score = statistics.SkillScores(combined, models=model_list) + skill_score = statistics.SkillScores(combined, models=model_list, ahead_dim=self.ahead_dim) if external_data is not None: - skill_score_competitive[station] = skill_score.skill_scores(self.window_lead_time) + skill_score_competitive[station] = skill_score.skill_scores() internal_data = self._get_internal_data(station, path) if internal_data is not None: skill_score_climatological[station] = skill_score.climatological_skill_scores( - internal_data, self.window_lead_time, forecast_name=self.forecast_indicator) + internal_data, forecast_name=self.forecast_indicator) errors.update({"total": self.calculate_average_errors(errors)}) return skill_score_competitive, skill_score_climatological, errors diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index c00655239a6c0da727cee3462595ea959356a73a..3354e78c0c9ee85dad71f15a7a0171248913c0b7 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -284,10 +284,11 @@ class PreProcessing(RunEnvironment): kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name) use_multiprocessing = self.data_store.get("use_multiprocessing") - if multiprocessing.cpu_count() > 1 and use_multiprocessing: # parallel solution + max_process = self.data_store.get("max_number_multiprocessing") + n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process]) # use only physical cpus + if n_process > 1 and use_multiprocessing is True: # parallel solution logging.info("use parallel validate station approach") - pool = multiprocessing.Pool( - min([psutil.cpu_count(logical=False), len(set_stations), 16])) # use only physical cpus + pool = multiprocessing.Pool(n_process) logging.info(f"running {getattr(pool, '_processes')} processes in parallel") output = [ pool.apply_async(f_proc, args=(data_handler, station, set_name, store_processed_data), kwds=kwargs) @@ -309,40 +310,22 @@ class PreProcessing(RunEnvironment): logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). 
Found {len(collection)}/" f"{len(set_stations)} valid stations.") - return collection, valid_stations - - def validate_station_old(self, data_handler: AbstractDataHandler, set_stations, set_name=None, - store_processed_data=True): - """ - Check if all given stations in `all_stations` are valid. - - Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the - loading time are logged in debug mode. - - :return: Corrected list containing only valid station IDs. - """ - t_outer = TimeTracking() - logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}") - # calculate transformation using train data if set_name == "train": - logging.info("setup transformation using train data exclusively") - self.transformation(data_handler, set_stations) - # start station check - collection = DataCollection() - valid_stations = [] - kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name) - for station in set_stations: - try: - dp = data_handler.build(station, name_affix=set_name, store_processed_data=store_processed_data, - **kwargs) - collection.add(dp) - valid_stations.append(station) - except (AttributeError, EmptyQueryResult): - continue - logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/" - f"{len(set_stations)} valid stations.") + self.store_data_handler_attributes(data_handler, collection) return collection, valid_stations + def store_data_handler_attributes(self, data_handler, collection): + store_attributes = data_handler.store_attributes() + if len(store_attributes) > 0: + logging.info("store data requested by the data handler") + attrs = {} + for dh in collection: + station = str(dh) + for k, v in dh.get_store_attributes().items(): + attrs[k] = dict(attrs.get(k, {}), **{station: v}) + for k, v in attrs.items(): + self.data_store.set(k, v) + def transformation(self, data_handler: AbstractDataHandler, stations): if hasattr(data_handler, "transformation"): kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train") @@ -378,10 +361,11 @@ def f_proc(data_handler, station, name_affix, store, **kwargs): """ try: res = data_handler.build(station, name_affix=name_affix, store_processed_data=store, **kwargs) - except (AttributeError, EmptyQueryResult, KeyError, requests.ConnectionError, ValueError) as e: + except (AttributeError, EmptyQueryResult, KeyError, requests.ConnectionError, ValueError, IndexError) as e: formatted_lines = traceback.format_exc().splitlines() logging.info( f"remove station {station} because it raised an error: {e} -> {' | '.join(f_inspect_error(formatted_lines))}") + logging.debug(f"detailed information for removal of station {station}: {traceback.format_exc()}") res = None return res, station diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 5f895b77d53d45bedc255bc7ff051f9d6a8d20a3..00e8eae1581453666d3ca11f48fcdaedf6a24ad0 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -166,7 +166,11 @@ class Training(RunEnvironment): lr = self.callbacks.get_callback_by_name("lr") except IndexError: lr = None - self.save_callbacks_as_json(history, lr) + try: + epo_timing = self.callbacks.get_callback_by_name("epo_timing") + except IndexError: + epo_timing = None + self.save_callbacks_as_json(history, lr, epo_timing) self.load_best_model(checkpoint.filepath) self.create_monitoring_plots(history, lr) @@ -190,7 +194,7 @@ class 
Training(RunEnvironment): except OSError: logging.info('no weights to reload...') - def save_callbacks_as_json(self, history: Callback, lr_sc: Callback) -> None: + def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None: """ Save callbacks (history, learning rate) of training. @@ -207,6 +211,9 @@ class Training(RunEnvironment): if lr_sc: with open(os.path.join(path, "history_lr.json"), "w") as f: json.dump(lr_sc.lr, f) + if epo_timing is not None: + with open(os.path.join(path, "epo_timing.json"), "w") as f: + json.dump(epo_timing.epo_timing, f) def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None: """ diff --git a/run.py b/run.py index 05b43ade453a4eb36952e18ad1c7ebab788dc37d..bd93db698c55bc8bae49c5d39a85f9d26cc49780 100644 --- a/run.py +++ b/run.py @@ -29,7 +29,7 @@ def main(parser_args): evaluate_bootstraps=False, # plot_list=["PlotCompetitiveSkillScore"], competitors=["test_model", "test_model2"], competitor_path=os.path.join(os.getcwd(), "data", "comp_test"), - **parser_args.__dict__) + **parser_args.__dict__, start_script=__file__) workflow.run() diff --git a/run_HPC.py b/run_HPC.py index d6dbb4dc61e88a1e139b3cbe549bc6a3f2f0ab8a..dfa5045bbccf993d2381ff32c5aead90ea6957f3 100644 --- a/run_HPC.py +++ b/run_HPC.py @@ -7,7 +7,7 @@ from mlair.workflows import DefaultWorkflowHPC def main(parser_args): - workflow = DefaultWorkflowHPC(**parser_args.__dict__) + workflow = DefaultWorkflowHPC(**parser_args.__dict__, start_script=__file__) workflow.run() diff --git a/run_hourly.py b/run_hourly.py index 48c7205883eda7e08ee1c14fe3c0a8a9f429e3da..869f8ea16cd4093e04e40f1b05f863ca45ce3c99 100644 --- a/run_hourly.py +++ b/run_hourly.py @@ -22,7 +22,7 @@ def main(parser_args): train_model=False, create_new_model=False, network="UBA", - plot_list=["PlotStationMap"], **parser_args.__dict__) + plot_list=["PlotStationMap"], **parser_args.__dict__, start_script=__file__) workflow.run() diff --git a/run_hourly_kz.py b/run_hourly_kz.py index 5536b56e732d81b84dfee7f34bd68d0d2ba49020..ba2939162c3fd22fc6a611bc7bc21b9334fbfd3b 100644 --- a/run_hourly_kz.py +++ b/run_hourly_kz.py @@ -19,7 +19,7 @@ def main(parser_args): test_end="2011-12-31", stations=["DEBW107", "DEBW013"] ) - workflow = DefaultWorkflow(**args) + workflow = DefaultWorkflow(**args, start_script=__file__) workflow.run() diff --git a/run_mixed_sampling.py b/run_mixed_sampling.py index 6ffb659953157060c39afb5960821e729df555dd..819ef51129854b4539632ef91a55e33a2607eb55 100644 --- a/run_mixed_sampling.py +++ b/run_mixed_sampling.py @@ -36,7 +36,7 @@ def main(parser_args): test_end="2011-12-31", **parser_args.__dict__, ) - workflow = DefaultWorkflow(**args) + workflow = DefaultWorkflow(**args, start_script=__file__) workflow.run() diff --git a/run_zam347.py b/run_zam347.py index 352f04177167441d3636359a9f6ade5f039c12c1..49fce3e7a0c0f2b24691c5b02590ff435300f552 100644 --- a/run_zam347.py +++ b/run_zam347.py @@ -31,7 +31,7 @@ def load_stations(): def main(parser_args): - workflow = DefaultWorkflowHPC(stations=load_stations(), **parser_args.__dict__) + workflow = DefaultWorkflowHPC(stations=load_stations(), **parser_args.__dict__, start_script=__file__) workflow.run() diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py index 922de3599dc7dc40717e0aeb8c7b8158ad21da38..27f38ce67b65c93a465051ab24fac1e8479fea59 100644 --- a/test/test_configuration/test_defaults.py +++ b/test/test_configuration/test_defaults.py @@ -68,4 +68,5 @@ class TestAllDefaults: 
assert DEFAULT_PLOT_LIST == ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore",
                                  "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore",
                                  "PlotConditionalQuantiles", "PlotAvailability", "PlotAvailabilityHistogram",
-                                 "PlotDataHistogram","PlotOversampling","PlotOversamplingContingency"]
+                                 "PlotDataHistogram", "PlotOversampling", "PlotOversamplingContingency",
+                                 "PlotPeriodogram"]
diff --git a/test/test_data_handler/old_t_bootstraps.py b/test/test_data_handler/old_t_bootstraps.py
index 9616ed3f457d74e44e8a9eae5a3ed862fa804011..21c18c6c2d6f6a6a38a41250f00d3d14a29ed457 100644
--- a/test/test_data_handler/old_t_bootstraps.py
+++ b/test/test_data_handler/old_t_bootstraps.py
@@ -160,7 +160,7 @@ class TestCreateShuffledData:

     def test_shuffle(self, shuffled_data_no_creation):
         dummy = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
-        res = shuffled_data_no_creation.shuffle(dummy, chunks=(2, 3)).compute()
+        res = shuffled_data_no_creation.apply_bootstrap_method(dummy, chunks=(2, 3)).compute()
         assert res.shape == dummy.shape
         assert dummy.max() >= res.max()
         assert dummy.min() <= res.min()
diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py
index 2a6553b7f495bb4eb8aeddf7c39f2f2517edc967..7418a435008f06a9016f903fe140b51d0a7c8106 100644
--- a/test/test_data_handler/test_data_handler_mixed_sampling.py
+++ b/test/test_data_handler/test_data_handler_mixed_sampling.py
@@ -2,10 +2,10 @@ __author__ = 'Lukas Leufen'
 __date__ = '2020-12-10'

 from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, \
-    DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithFilter, \
-    DataHandlerMixedSamplingWithFilterSingleStation, DataHandlerSeparationOfScales, \
-    DataHandlerSeparationOfScalesSingleStation
-from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation
+    DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithKzFilter, \
+    DataHandlerMixedSamplingWithKzFilterSingleStation, DataHandlerSeparationOfScales, \
+    DataHandlerSeparationOfScalesSingleStation, DataHandlerMixedSamplingWithFilterSingleStation
+from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation
 from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
 from mlair.helpers import remove_items
 from mlair.configuration.defaults import DEFAULT_INTERPOLATION_METHOD
@@ -86,19 +86,19 @@ class TestDataHandlerMixedSamplingSingleStation:
         pass


-class TestDataHandlerMixedSamplingWithFilter:
+class TestDataHandlerMixedSamplingWithKzFilter:

     def test_data_handler(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
-        assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__
+        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
+        assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__

     def test_data_handler_transformation(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
-        assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__
+        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
+        assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__

     def test_requirements(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
-        req1 = 
object.__new__(DataHandlerMixedSamplingSingleStation) + obj = object.__new__(DataHandlerMixedSamplingWithKzFilter) + req1 = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation) req2 = object.__new__(DataHandlerKzFilterSingleStation) req = list(set(req1.requirements() + req2.requirements())) assert sorted(obj._requirements) == sorted(remove_items(req, "station")) @@ -119,8 +119,8 @@ class TestDataHandlerSeparationOfScales: assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__ def test_requirements(self): - obj = object.__new__(DataHandlerMixedSamplingWithFilter) - req1 = object.__new__(DataHandlerMixedSamplingSingleStation) + obj = object.__new__(DataHandlerMixedSamplingWithKzFilter) + req1 = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation) req2 = object.__new__(DataHandlerKzFilterSingleStation) req = list(set(req1.requirements() + req2.requirements())) assert sorted(obj._requirements) == sorted(remove_items(req, "station")) diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py index 8a7572148869537b505b2bd8e7f16cfdf7af1cdd..7cefd0e58f5b9b0787bafddffe1ad07e4851a068 100644 --- a/test/test_run_modules/test_model_setup.py +++ b/test/test_run_modules/test_model_setup.py @@ -80,7 +80,7 @@ class TestModelSetup: setup._set_callbacks() assert "general.model" in setup.data_store.search_name("callbacks") callbacks = setup.data_store.get("callbacks", "general.model") - assert len(callbacks.get_callbacks()) == 3 + assert len(callbacks.get_callbacks()) == 4 def test_set_callbacks_no_lr_decay(self, setup): setup.data_store.set("lr_decay", None, "general.model") @@ -88,7 +88,7 @@ class TestModelSetup: setup.checkpoint_name = "TestName" setup._set_callbacks() callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model") - assert len(callbacks.get_callbacks()) == 2 + assert len(callbacks.get_callbacks()) == 3 with pytest.raises(IndexError): callbacks.get_callback_by_name("lr_decay") diff --git a/test/test_run_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py index 5ae64bf3d535e72d9361394741ed8b8094091b1d..0f2ee7a10fd2e3190c0b66da558626747d4c03c9 100644 --- a/test/test_run_modules/test_pre_processing.py +++ b/test/test_run_modules/test_pre_processing.py @@ -109,7 +109,7 @@ class TestPreProcessing: assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 ' r'station\(s\). 
Found 5/6 valid stations.'))

-    @mock.patch("multiprocessing.cpu_count", return_value=3)
+    @mock.patch("psutil.cpu_count", return_value=3)
     @mock.patch("multiprocessing.Pool", return_value=multiprocessing.Pool(3))
     def test_validate_station_parallel(self, mock_pool, mock_cpu, caplog, obj_with_exp_setup):
         pre = obj_with_exp_setup
diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py
index c2b58cbd2160bd958c76ba67649ef8caba09fcb4..ed0d8264326f5299403c47deb46859ccde4a85d7 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -13,7 +13,7 @@ from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
 from mlair.helpers import PyTestRegex
 from mlair.model_modules.flatten import flatten_tail
 from mlair.model_modules.inception_model import InceptionModelBase
-from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler
+from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.run_modules.training import Training

@@ -100,6 +100,13 @@ class TestTraining:
         h.model = mock.MagicMock()
         return h

+    @pytest.fixture
+    def epo_timing(self):
+        epo_timing = EpoTimingCallback()
+        epo_timing.epoch = [0, 1]
+        epo_timing.epo_timing = {"epo_timing": [0.1, 0.2]}
+        return epo_timing
+
     @pytest.fixture
     def path(self):
         return os.path.join(os.path.dirname(__file__), "TestExperiment")
@@ -144,9 +151,11 @@ class TestTraining:
     def callbacks(self, path):
         clbk = CallbackHandler()
         hist = HistoryAdvanced()
+        epo_timing = EpoTimingCallback()
         clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
         lr = LearningRateDecay()
         clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
+        clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
         clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
                                      save_best_only=True)
         return clbk, hist, lr
@@ -256,22 +265,22 @@ class TestTraining:
         assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
         assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))

-    def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, model_path):
-        init_without_run.save_callbacks_as_json(history, learning_rate)
+    def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
+        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
         assert "history.json" in os.listdir(model_path)

-    def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, model_path):
-        init_without_run.save_callbacks_as_json(history, learning_rate)
+    def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
+        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
         assert "history_lr.json" in os.listdir(model_path)

-    def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, model_path):
-        init_without_run.save_callbacks_as_json(history, learning_rate)
+    def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, epo_timing, model_path):
+        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
         with open(os.path.join(model_path, "history.json")) as jfile:
             hist = json.load(jfile)
         assert hist == history.history

-    def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, model_path):
-        init_without_run.save_callbacks_as_json(history, learning_rate)
+    def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, epo_timing, model_path):
+        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
         with open(os.path.join(model_path, "history_lr.json")) as jfile:
             lr = json.load(jfile)
         assert lr == learning_rate.lr
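Note on the new epo_timing callback wired into Training.save_callbacks_as_json above: the patch only shows how its epo_timing attribute is dumped to JSON. A minimal sketch of a comparable epoch-timing callback follows; it assumes a tf.keras environment and is an illustration, not the actual EpoTimingCallback from mlair.model_modules.keras_extensions.

import json
import time
from tensorflow.keras.callbacks import Callback  # assumption: tf.keras is available

class EpochTimingSketch(Callback):
    """Record the wall-clock duration of every epoch in a JSON-serialisable dict."""

    def __init__(self):
        super().__init__()
        self.epo_timing = {"epo_timing": []}
        self._start = None

    def on_epoch_begin(self, epoch, logs=None):
        self._start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.epo_timing["epo_timing"].append(time.time() - self._start)

# persisted the same way as in save_callbacks_as_json above
def dump_epo_timing(callback, path="epo_timing.json"):
    with open(path, "w") as f:
        json.dump(callback.epo_timing, f)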
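Note on the pre_processing change above: the worker count is now bounded by the physical core count, the number of stations, and the new max_number_multiprocessing setting. A self-contained sketch of that sizing rule is given below; the function name and example values are illustrative.

import multiprocessing
import psutil

def choose_pool_size(n_stations: int, max_processes: int = 16) -> int:
    # mirrors n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])
    physical = psutil.cpu_count(logical=False) or 1  # cpu_count(logical=False) may return None
    return max(1, min(physical, n_stations, max_processes))

if __name__ == "__main__":
    n_process = choose_pool_size(n_stations=6)
    if n_process > 1:  # fall back to the serial path when only one worker would be used
        with multiprocessing.Pool(n_process) as pool:
            print(pool.map(abs, [-1, -2, -3]))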